From b5c4802bb9c6ee910710797403ef962264ff2258 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Zolt=C3=A1n=20Papp?= Date: Tue, 16 Apr 2024 16:01:25 +0200 Subject: [PATCH] Apply new receiver functions --- client/internal/engine.go | 3 +-- client/internal/peer/conn.go | 35 +++++++++++++------------ client/internal/relay/turn.go | 14 ++++++---- go.mod | 2 +- iface/bind/bind.go | 47 ++++++++++++++++++++++++++-------- iface/bind/receiver_creator.go | 25 ++++++++++++++---- iface/iface.go | 7 +++++ iface/module_linux.go | 30 ++++++++++++---------- iface/tun.go | 1 + iface/tun_kernel_linux.go | 5 ++++ iface/tun_netstack.go | 5 ++++ iface/tun_usp_linux.go | 11 +++++++- 12 files changed, 130 insertions(+), 55 deletions(-) diff --git a/client/internal/engine.go b/client/internal/engine.go index a9003f85d..07ffd4f93 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -200,7 +200,6 @@ func NewEngineWithProbes( networkSerial: 0, sshServerFunc: nbssh.DefaultSSHServer, statusRecorder: statusRecorder, - wgProxyFactory: wgproxy.NewFactory(config.WgPort), mgmProbe: mgmProbe, signalProbe: signalProbe, relayProbe: relayProbe, @@ -499,6 +498,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { return fmt.Errorf("faile to open turn relay: %w", err) } e.turnRelay = turnRelay + e.wgInterface.SetRelayConn(e.turnRelay.RelayConn()) // todo update signal } @@ -620,7 +620,6 @@ func (e *Engine) updateSTUNs(stuns []*mgmProto.HostConfig) error { var newSTUNs []*stun.URI log.Debugf("got STUNs update from Management Service, updating") for _, s := range stuns { - log.Debugf("-----updated TURN: %s", s.Uri) url, err := stun.ParseURI(s.Uri) if err != nil { return err diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index d7d10874f..32fa51147 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -345,21 +345,28 @@ func (conn *Conn) Open() error { log.Warnf("error while updating the state of peer %s,err: %v", conn.config.Key, err) } - isControlling := conn.config.LocalKey > conn.config.Key + isControlling := conn.config.LocalKey < conn.config.Key if isControlling { + log.Debugf("---- use this peer's tunr connection") err = conn.turnRelay.PunchHole(remoteOfferAnswer.RemoteAddr) if err != nil { log.Errorf("failed to punch hole: %v", err) } - } else { - /* - remoteConn, err := net.Dial("udp", remoteOfferAnswer.RemoteAddr.String()) - if err != nil { - log.Errorf("failed to dial remote peer %s: %v", conn.config.Key, err) + addr, ok := remoteOfferAnswer.RemoteAddr.(*net.UDPAddr) + if !ok { + return fmt.Errorf("failed to cast addr to udp addr") + } + addr.Port = remoteOfferAnswer.WgListenPort + err := conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, addr, conn.config.WgConfig.PreSharedKey) + if err != nil { + if conn.wgProxy != nil { + _ = conn.wgProxy.CloseConn() } - - */ - + // todo close + return err + } + } else { + log.Debugf("---- use remote peer tunr connection") addr, ok := remoteOfferAnswer.RelayedAddr.(*net.UDPAddr) if !ok { return fmt.Errorf("failed to cast addr to udp addr") @@ -414,13 +421,7 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem defer conn.mu.Unlock() var endpoint net.Addr - log.Debugf("setup relay connection") - conn.wgProxy = conn.wgProxyFactory.GetProxy() - endpoint, err := conn.wgProxy.AddTurnConn(remoteConn) - if err != nil { - return nil, err - } - + endpoint = remoteConn.RemoteAddr() endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String()) conn.remoteEndpoint = endpointUdpAddr log.Debugf("Conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP) @@ -432,7 +433,7 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem } } - err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey) + err := conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey) if err != nil { if conn.wgProxy != nil { _ = conn.wgProxy.CloseConn() diff --git a/client/internal/relay/turn.go b/client/internal/relay/turn.go index 75b53135d..771593282 100644 --- a/client/internal/relay/turn.go +++ b/client/internal/relay/turn.go @@ -78,6 +78,15 @@ func (r *PermanentTurn) SrvRefAddr() net.Addr { return r.srvReflexiveAddress } +func (r *PermanentTurn) PunchHole(mappedAddr net.Addr) error { + _, err := r.relayConn.WriteTo([]byte("Hello"), mappedAddr) + return err +} + +func (r *PermanentTurn) RelayConn() net.PacketConn { + return r.relayConn +} + func (r *PermanentTurn) discoverPublicIP() (*net.UDPAddr, error) { addr, err := r.turnClient.SendBindingRequest() if err != nil { @@ -119,11 +128,6 @@ func (r *PermanentTurn) listen() { }() } -func (r *PermanentTurn) PunchHole(mappedAddr net.Addr) error { - _, err := r.relayConn.WriteTo([]byte("Hello"), mappedAddr) - return err -} - func toURL(uri *stun.URI) string { return fmt.Sprintf("%s:%d", uri.Host, uri.Port) } diff --git a/go.mod b/go.mod index 29a1570c8..1284945b5 100644 --- a/go.mod +++ b/go.mod @@ -172,7 +172,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2023 replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 -replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed +replace golang.zx2c4.com/wireguard => /home/pzoli/go/src/github.com/netbirdio/wireguard-go replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 diff --git a/iface/bind/bind.go b/iface/bind/bind.go index 35ba27bc0..79eae1a31 100644 --- a/iface/bind/bind.go +++ b/iface/bind/bind.go @@ -20,6 +20,8 @@ type ICEBind struct { transportNet transport.Net udpMux *UniversalUDPMuxDefault + + receiverCreator *receiverCreator } func NewICEBind(transportNet transport.Net) *ICEBind { @@ -28,6 +30,7 @@ func NewICEBind(transportNet transport.Net) *ICEBind { } rc := newReceiverCreator(ib) + ib.receiverCreator = rc ib.StdNetBind = wgConn.NewStdNetBindWithReceiverCreator(rc) return ib @@ -44,16 +47,22 @@ func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) { return s.udpMux, nil } -func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc { +func (s *ICEBind) SetTurnConn(conn interface{}) { + s.receiverCreator.setTurnConn(conn) +} + +func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn, netConn net.PacketConn) wgConn.ReceiveFunc { s.muUDPMux.Lock() defer s.muUDPMux.Unlock() - s.udpMux = NewUniversalUDPMuxDefault( - UniversalUDPMuxParams{ - UDPConn: conn, - Net: s.transportNet, - }, - ) + if conn != nil { + s.udpMux = NewUniversalUDPMuxDefault( + UniversalUDPMuxParams{ + UDPConn: conn, + Net: s.transportNet, + }, + ) + } return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) { msgs := ipv4MsgsPool.Get().(*[]ipv4.Message) defer ipv4MsgsPool.Put(msgs) @@ -62,9 +71,22 @@ func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketC } var numMsgs int if runtime.GOOS == "linux" { - numMsgs, err = pc.ReadBatch(*msgs, 0) - if err != nil { - return 0, err + if netConn != nil { + log.Debugf("----read from turn conn...") + msg := &(*msgs)[0] + msg.N, msg.Addr, err = netConn.ReadFrom(msg.Buffers[0]) + if err != nil { + log.Debugf("read err from turn server: %v", err) + return 0, err + } + log.Debugf("----msg address is: %s, size: %d", msg.Addr.String(), msg.N) + numMsgs = 1 + } else { + log.Debugf("----read from pc...") + numMsgs, err = pc.ReadBatch(*msgs, 0) + if err != nil { + return 0, err + } } } else { msg := &(*msgs)[0] @@ -86,7 +108,10 @@ func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketC } addrPort := msg.Addr.(*net.UDPAddr).AddrPort() - ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation + ep := &wgConn.StdNetEndpoint{ + AddrPort: addrPort, + Conn: netConn, + } wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep) eps[i] = ep } diff --git a/iface/bind/receiver_creator.go b/iface/bind/receiver_creator.go index b205c4d5f..767b0d541 100644 --- a/iface/bind/receiver_creator.go +++ b/iface/bind/receiver_creator.go @@ -4,20 +4,35 @@ import ( "net" "sync" + log "github.com/sirupsen/logrus" "golang.org/x/net/ipv4" wgConn "golang.zx2c4.com/wireguard/conn" ) type receiverCreator struct { - iceBind *ICEBind + iceBind *ICEBind + relayConn net.PacketConn } -func newReceiverCreator(iceBind *ICEBind) receiverCreator { - return receiverCreator{ +func newReceiverCreator(iceBind *ICEBind) *receiverCreator { + return &receiverCreator{ iceBind: iceBind, } } -func (rc receiverCreator) CreateIPv4ReceiverFn(msgPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc { - return rc.iceBind.createIPv4ReceiverFn(msgPool, pc, conn) +func (rc *receiverCreator) CreateIPv4ReceiverFn(msgPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc { + return rc.iceBind.createIPv4ReceiverFn(msgPool, pc, conn, nil) +} + +func (rc *receiverCreator) CreateRelayReceiverFn(msgPool *sync.Pool) wgConn.ReceiveFunc { + if rc.relayConn == nil { + log.Debugf("-------rc.conn is nil") + return nil + } + return rc.iceBind.createIPv4ReceiverFn(msgPool, nil, nil, rc.relayConn) +} + +func (rc *receiverCreator) setTurnConn(relayConn interface{}) { + log.Debug("------ SET TURN CONN") + rc.relayConn = relayConn.(net.PacketConn) } diff --git a/iface/iface.go b/iface/iface.go index 3ae40ad4c..701805cf0 100644 --- a/iface/iface.go +++ b/iface/iface.go @@ -150,3 +150,10 @@ func (w *WGIface) GetDevice() *DeviceWrapper { func (w *WGIface) GetStats(peerKey string) (WGStats, error) { return w.configurer.getStats(peerKey) } + +func (w *WGIface) SetRelayConn(conn interface{}) { + w.mu.Lock() + defer w.mu.Unlock() + + w.tun.SetTurnConn(conn) +} diff --git a/iface/module_linux.go b/iface/module_linux.go index 11c0482d5..51bbb0f11 100644 --- a/iface/module_linux.go +++ b/iface/module_linux.go @@ -85,23 +85,27 @@ func tunModuleIsLoaded() bool { // WireGuardModuleIsLoaded check if we can load WireGuard mod (linux only) func WireGuardModuleIsLoaded() bool { + return false - if os.Getenv(envDisableWireGuardKernel) == "true" { - log.Debugf("WireGuard kernel module disabled because the %s env is set to true", envDisableWireGuardKernel) - return false - } + /* + if os.Getenv(envDisableWireGuardKernel) == "true" { + log.Debugf("WireGuard kernel module disabled because the %s env is set to true", envDisableWireGuardKernel) + return false + } - if canCreateFakeWireGuardInterface() { - return true - } + if canCreateFakeWireGuardInterface() { + return true + } - loaded, err := tryToLoadModule("wireguard") - if err != nil { - log.Info(err) - return false - } + loaded, err := tryToLoadModule("wireguard") + if err != nil { + log.Info(err) + return false + } - return loaded + return loaded + + */ } func canCreateFakeWireGuardInterface() bool { diff --git a/iface/tun.go b/iface/tun.go index b3c0f9d80..9e49bb4e3 100644 --- a/iface/tun.go +++ b/iface/tun.go @@ -15,4 +15,5 @@ type wgTunDevice interface { DeviceName() string Close() error Wrapper() *DeviceWrapper // todo eliminate this function + SetTurnConn(conn interface{}) } diff --git a/iface/tun_kernel_linux.go b/iface/tun_kernel_linux.go index 12adcdf73..19675f472 100644 --- a/iface/tun_kernel_linux.go +++ b/iface/tun_kernel_linux.go @@ -31,6 +31,11 @@ type tunKernelDevice struct { udpMux *bind.UniversalUDPMuxDefault } +func (t *tunKernelDevice) SetTurnConn(interface{}) { + //TODO implement me + panic("implement me") +} + func newTunDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) wgTunDevice { ctx, cancel := context.WithCancel(context.Background()) return &tunKernelDevice{ diff --git a/iface/tun_netstack.go b/iface/tun_netstack.go index e1d01ecc9..e9dfa9715 100644 --- a/iface/tun_netstack.go +++ b/iface/tun_netstack.go @@ -30,6 +30,11 @@ type tunNetstackDevice struct { configurer wgConfigurer } +func (t *tunNetstackDevice) SetTurnConn(interface{}) { + //TODO implement me + panic("implement me") +} + func newTunNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string) wgTunDevice { return &tunNetstackDevice{ name: name, diff --git a/iface/tun_usp_linux.go b/iface/tun_usp_linux.go index 9f0210228..8a23ef4f0 100644 --- a/iface/tun_usp_linux.go +++ b/iface/tun_usp_linux.go @@ -54,7 +54,7 @@ func (t *tunUSPDevice) Create() (wgConfigurer, error) { t.device = device.NewDevice( t.wrapper, t.iceBind, - device.NewLogger(device.LogLevelSilent, "[netbird] "), + device.NewLogger(device.LogLevelVerbose, "[netbird] "), ) err = t.assignAddr() @@ -70,6 +70,7 @@ func (t *tunUSPDevice) Create() (wgConfigurer, error) { t.configurer.close() return nil, err } + log.Debugf("configuration done") return t.configurer, nil } @@ -125,6 +126,14 @@ func (t *tunUSPDevice) Wrapper() *DeviceWrapper { return t.wrapper } +func (t *tunUSPDevice) SetTurnConn(conn interface{}) { + t.iceBind.SetTurnConn(conn) + err := t.device.BindUpdate() + if err != nil { + log.Errorf("failed to update bind: %v", err) + } +} + // assignAddr Adds IP address to the tunnel interface func (t *tunUSPDevice) assignAddr() error { link := newWGLink(t.name)