mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-12 01:40:40 +01:00
127 lines
2.8 KiB
Go
127 lines
2.8 KiB
Go
|
package rosenpass
|
||
|
|
||
|
import (
|
||
|
"fmt"
|
||
|
"log/slog"
|
||
|
|
||
|
rp "cunicu.li/go-rosenpass"
|
||
|
log "github.com/sirupsen/logrus"
|
||
|
|
||
|
"golang.zx2c4.com/wireguard/wgctrl"
|
||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||
|
)
|
||
|
|
||
|
type wireGuardPeer struct {
|
||
|
Interface string
|
||
|
PublicKey rp.Key
|
||
|
}
|
||
|
|
||
|
type NetbirdHandler struct {
|
||
|
ifaceName string
|
||
|
client *wgctrl.Client
|
||
|
peers map[rp.PeerID]wireGuardPeer
|
||
|
presharedKey [32]byte
|
||
|
}
|
||
|
|
||
|
func NewNetbirdHandler(preSharedKey *[32]byte, wgIfaceName string) (hdlr *NetbirdHandler, err error) {
|
||
|
hdlr = &NetbirdHandler{
|
||
|
ifaceName: wgIfaceName,
|
||
|
peers: map[rp.PeerID]wireGuardPeer{},
|
||
|
}
|
||
|
|
||
|
if preSharedKey != nil {
|
||
|
hdlr.presharedKey = *preSharedKey
|
||
|
}
|
||
|
|
||
|
if hdlr.client, err = wgctrl.New(); err != nil {
|
||
|
return nil, fmt.Errorf("failed to creat WireGuard client: %w", err)
|
||
|
}
|
||
|
|
||
|
return hdlr, nil
|
||
|
}
|
||
|
|
||
|
func (h *NetbirdHandler) AddPeer(pid rp.PeerID, intf string, pk rp.Key) {
|
||
|
h.peers[pid] = wireGuardPeer{
|
||
|
Interface: intf,
|
||
|
PublicKey: pk,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (h *NetbirdHandler) RemovePeer(pid rp.PeerID) {
|
||
|
delete(h.peers, pid)
|
||
|
}
|
||
|
|
||
|
func (h *NetbirdHandler) HandshakeCompleted(pid rp.PeerID, key rp.Key) {
|
||
|
log.Debug("Handshake complete")
|
||
|
h.outputKey(rp.KeyOutputReasonStale, pid, key)
|
||
|
}
|
||
|
|
||
|
func (h *NetbirdHandler) HandshakeExpired(pid rp.PeerID) {
|
||
|
key, _ := rp.GeneratePresharedKey()
|
||
|
log.Debug("Handshake expired")
|
||
|
h.outputKey(rp.KeyOutputReasonStale, pid, key)
|
||
|
}
|
||
|
|
||
|
func (h *NetbirdHandler) outputKey(_ rp.KeyOutputReason, pid rp.PeerID, psk rp.Key) {
|
||
|
wg, ok := h.peers[pid]
|
||
|
if !ok {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
device, err := h.client.Device(h.ifaceName)
|
||
|
if err != nil {
|
||
|
log.Errorf("Failed to get WireGuard device: %v", err)
|
||
|
return
|
||
|
}
|
||
|
config := []wgtypes.PeerConfig{
|
||
|
{
|
||
|
UpdateOnly: true,
|
||
|
PublicKey: wgtypes.Key(wg.PublicKey),
|
||
|
PresharedKey: (*wgtypes.Key)(&psk),
|
||
|
},
|
||
|
}
|
||
|
for _, peer := range device.Peers {
|
||
|
if peer.PublicKey == wgtypes.Key(wg.PublicKey) {
|
||
|
if publicKeyEmpty(peer.PresharedKey) || peer.PresharedKey == h.presharedKey {
|
||
|
log.Debugf("Restart wireguard connection to peer %s", peer.PublicKey)
|
||
|
config = []wgtypes.PeerConfig{
|
||
|
{
|
||
|
PublicKey: wgtypes.Key(wg.PublicKey),
|
||
|
PresharedKey: (*wgtypes.Key)(&psk),
|
||
|
Endpoint: peer.Endpoint,
|
||
|
AllowedIPs: peer.AllowedIPs,
|
||
|
},
|
||
|
}
|
||
|
err = h.client.ConfigureDevice(wg.Interface, wgtypes.Config{
|
||
|
Peers: []wgtypes.PeerConfig{
|
||
|
{
|
||
|
Remove: true,
|
||
|
PublicKey: wgtypes.Key(wg.PublicKey),
|
||
|
},
|
||
|
},
|
||
|
})
|
||
|
if err != nil {
|
||
|
slog.Debug("Failed to remove peer")
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if err = h.client.ConfigureDevice(wg.Interface, wgtypes.Config{
|
||
|
Peers: config,
|
||
|
}); err != nil {
|
||
|
log.Errorf("Failed to apply rosenpass key: %v", err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func publicKeyEmpty(key wgtypes.Key) bool {
|
||
|
for _, b := range key {
|
||
|
if b != 0 {
|
||
|
return false
|
||
|
}
|
||
|
}
|
||
|
return true
|
||
|
}
|