Apply new receiver functions

This commit is contained in:
Zoltán Papp 2024-04-16 16:01:25 +02:00
parent 28a9a2ef87
commit b5c4802bb9
12 changed files with 130 additions and 55 deletions

View File

@ -200,7 +200,6 @@ func NewEngineWithProbes(
networkSerial: 0, networkSerial: 0,
sshServerFunc: nbssh.DefaultSSHServer, sshServerFunc: nbssh.DefaultSSHServer,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
wgProxyFactory: wgproxy.NewFactory(config.WgPort),
mgmProbe: mgmProbe, mgmProbe: mgmProbe,
signalProbe: signalProbe, signalProbe: signalProbe,
relayProbe: relayProbe, relayProbe: relayProbe,
@ -499,6 +498,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return fmt.Errorf("faile to open turn relay: %w", err) return fmt.Errorf("faile to open turn relay: %w", err)
} }
e.turnRelay = turnRelay e.turnRelay = turnRelay
e.wgInterface.SetRelayConn(e.turnRelay.RelayConn())
// todo update signal // todo update signal
} }
@ -620,7 +620,6 @@ func (e *Engine) updateSTUNs(stuns []*mgmProto.HostConfig) error {
var newSTUNs []*stun.URI var newSTUNs []*stun.URI
log.Debugf("got STUNs update from Management Service, updating") log.Debugf("got STUNs update from Management Service, updating")
for _, s := range stuns { for _, s := range stuns {
log.Debugf("-----updated TURN: %s", s.Uri)
url, err := stun.ParseURI(s.Uri) url, err := stun.ParseURI(s.Uri)
if err != nil { if err != nil {
return err return err

View File

@ -345,21 +345,28 @@ func (conn *Conn) Open() error {
log.Warnf("error while updating the state of peer %s,err: %v", conn.config.Key, err) log.Warnf("error while updating the state of peer %s,err: %v", conn.config.Key, err)
} }
isControlling := conn.config.LocalKey > conn.config.Key isControlling := conn.config.LocalKey < conn.config.Key
if isControlling { if isControlling {
log.Debugf("---- use this peer's tunr connection")
err = conn.turnRelay.PunchHole(remoteOfferAnswer.RemoteAddr) err = conn.turnRelay.PunchHole(remoteOfferAnswer.RemoteAddr)
if err != nil { if err != nil {
log.Errorf("failed to punch hole: %v", err) log.Errorf("failed to punch hole: %v", err)
} }
} else { addr, ok := remoteOfferAnswer.RemoteAddr.(*net.UDPAddr)
/* if !ok {
remoteConn, err := net.Dial("udp", remoteOfferAnswer.RemoteAddr.String()) return fmt.Errorf("failed to cast addr to udp addr")
if err != nil {
log.Errorf("failed to dial remote peer %s: %v", conn.config.Key, err)
} }
addr.Port = remoteOfferAnswer.WgListenPort
*/ err := conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, addr, conn.config.WgConfig.PreSharedKey)
if err != nil {
if conn.wgProxy != nil {
_ = conn.wgProxy.CloseConn()
}
// todo close
return err
}
} else {
log.Debugf("---- use remote peer tunr connection")
addr, ok := remoteOfferAnswer.RelayedAddr.(*net.UDPAddr) addr, ok := remoteOfferAnswer.RelayedAddr.(*net.UDPAddr)
if !ok { if !ok {
return fmt.Errorf("failed to cast addr to udp addr") return fmt.Errorf("failed to cast addr to udp addr")
@ -414,13 +421,7 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
defer conn.mu.Unlock() defer conn.mu.Unlock()
var endpoint net.Addr var endpoint net.Addr
log.Debugf("setup relay connection") endpoint = remoteConn.RemoteAddr()
conn.wgProxy = conn.wgProxyFactory.GetProxy()
endpoint, err := conn.wgProxy.AddTurnConn(remoteConn)
if err != nil {
return nil, err
}
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String()) endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
conn.remoteEndpoint = endpointUdpAddr conn.remoteEndpoint = endpointUdpAddr
log.Debugf("Conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP) log.Debugf("Conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP)
@ -432,7 +433,7 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
} }
} }
err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey) err := conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey)
if err != nil { if err != nil {
if conn.wgProxy != nil { if conn.wgProxy != nil {
_ = conn.wgProxy.CloseConn() _ = conn.wgProxy.CloseConn()

View File

@ -78,6 +78,15 @@ func (r *PermanentTurn) SrvRefAddr() net.Addr {
return r.srvReflexiveAddress return r.srvReflexiveAddress
} }
func (r *PermanentTurn) PunchHole(mappedAddr net.Addr) error {
_, err := r.relayConn.WriteTo([]byte("Hello"), mappedAddr)
return err
}
func (r *PermanentTurn) RelayConn() net.PacketConn {
return r.relayConn
}
func (r *PermanentTurn) discoverPublicIP() (*net.UDPAddr, error) { func (r *PermanentTurn) discoverPublicIP() (*net.UDPAddr, error) {
addr, err := r.turnClient.SendBindingRequest() addr, err := r.turnClient.SendBindingRequest()
if err != nil { if err != nil {
@ -119,11 +128,6 @@ func (r *PermanentTurn) listen() {
}() }()
} }
func (r *PermanentTurn) PunchHole(mappedAddr net.Addr) error {
_, err := r.relayConn.WriteTo([]byte("Hello"), mappedAddr)
return err
}
func toURL(uri *stun.URI) string { func toURL(uri *stun.URI) string {
return fmt.Sprintf("%s:%d", uri.Host, uri.Port) return fmt.Sprintf("%s:%d", uri.Host, uri.Port)
} }

2
go.mod
View File

@ -172,7 +172,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2023
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed replace golang.zx2c4.com/wireguard => /home/pzoli/go/src/github.com/netbirdio/wireguard-go
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6

View File

@ -20,6 +20,8 @@ type ICEBind struct {
transportNet transport.Net transportNet transport.Net
udpMux *UniversalUDPMuxDefault udpMux *UniversalUDPMuxDefault
receiverCreator *receiverCreator
} }
func NewICEBind(transportNet transport.Net) *ICEBind { func NewICEBind(transportNet transport.Net) *ICEBind {
@ -28,6 +30,7 @@ func NewICEBind(transportNet transport.Net) *ICEBind {
} }
rc := newReceiverCreator(ib) rc := newReceiverCreator(ib)
ib.receiverCreator = rc
ib.StdNetBind = wgConn.NewStdNetBindWithReceiverCreator(rc) ib.StdNetBind = wgConn.NewStdNetBindWithReceiverCreator(rc)
return ib return ib
@ -44,16 +47,22 @@ func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
return s.udpMux, nil return s.udpMux, nil
} }
func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc { func (s *ICEBind) SetTurnConn(conn interface{}) {
s.receiverCreator.setTurnConn(conn)
}
func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn, netConn net.PacketConn) wgConn.ReceiveFunc {
s.muUDPMux.Lock() s.muUDPMux.Lock()
defer s.muUDPMux.Unlock() defer s.muUDPMux.Unlock()
if conn != nil {
s.udpMux = NewUniversalUDPMuxDefault( s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{ UniversalUDPMuxParams{
UDPConn: conn, UDPConn: conn,
Net: s.transportNet, Net: s.transportNet,
}, },
) )
}
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) { return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
msgs := ipv4MsgsPool.Get().(*[]ipv4.Message) msgs := ipv4MsgsPool.Get().(*[]ipv4.Message)
defer ipv4MsgsPool.Put(msgs) defer ipv4MsgsPool.Put(msgs)
@ -62,10 +71,23 @@ func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketC
} }
var numMsgs int var numMsgs int
if runtime.GOOS == "linux" { if runtime.GOOS == "linux" {
if netConn != nil {
log.Debugf("----read from turn conn...")
msg := &(*msgs)[0]
msg.N, msg.Addr, err = netConn.ReadFrom(msg.Buffers[0])
if err != nil {
log.Debugf("read err from turn server: %v", err)
return 0, err
}
log.Debugf("----msg address is: %s, size: %d", msg.Addr.String(), msg.N)
numMsgs = 1
} else {
log.Debugf("----read from pc...")
numMsgs, err = pc.ReadBatch(*msgs, 0) numMsgs, err = pc.ReadBatch(*msgs, 0)
if err != nil { if err != nil {
return 0, err return 0, err
} }
}
} else { } else {
msg := &(*msgs)[0] msg := &(*msgs)[0]
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
@ -86,7 +108,10 @@ func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketC
} }
addrPort := msg.Addr.(*net.UDPAddr).AddrPort() addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation ep := &wgConn.StdNetEndpoint{
AddrPort: addrPort,
Conn: netConn,
}
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep) wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
eps[i] = ep eps[i] = ep
} }

View File

@ -4,20 +4,35 @@ import (
"net" "net"
"sync" "sync"
log "github.com/sirupsen/logrus"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
wgConn "golang.zx2c4.com/wireguard/conn" wgConn "golang.zx2c4.com/wireguard/conn"
) )
type receiverCreator struct { type receiverCreator struct {
iceBind *ICEBind iceBind *ICEBind
relayConn net.PacketConn
} }
func newReceiverCreator(iceBind *ICEBind) receiverCreator { func newReceiverCreator(iceBind *ICEBind) *receiverCreator {
return receiverCreator{ return &receiverCreator{
iceBind: iceBind, iceBind: iceBind,
} }
} }
func (rc receiverCreator) CreateIPv4ReceiverFn(msgPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc { func (rc *receiverCreator) CreateIPv4ReceiverFn(msgPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
return rc.iceBind.createIPv4ReceiverFn(msgPool, pc, conn) return rc.iceBind.createIPv4ReceiverFn(msgPool, pc, conn, nil)
}
func (rc *receiverCreator) CreateRelayReceiverFn(msgPool *sync.Pool) wgConn.ReceiveFunc {
if rc.relayConn == nil {
log.Debugf("-------rc.conn is nil")
return nil
}
return rc.iceBind.createIPv4ReceiverFn(msgPool, nil, nil, rc.relayConn)
}
func (rc *receiverCreator) setTurnConn(relayConn interface{}) {
log.Debug("------ SET TURN CONN")
rc.relayConn = relayConn.(net.PacketConn)
} }

View File

@ -150,3 +150,10 @@ func (w *WGIface) GetDevice() *DeviceWrapper {
func (w *WGIface) GetStats(peerKey string) (WGStats, error) { func (w *WGIface) GetStats(peerKey string) (WGStats, error) {
return w.configurer.getStats(peerKey) return w.configurer.getStats(peerKey)
} }
func (w *WGIface) SetRelayConn(conn interface{}) {
w.mu.Lock()
defer w.mu.Unlock()
w.tun.SetTurnConn(conn)
}

View File

@ -85,7 +85,9 @@ func tunModuleIsLoaded() bool {
// WireGuardModuleIsLoaded check if we can load WireGuard mod (linux only) // WireGuardModuleIsLoaded check if we can load WireGuard mod (linux only)
func WireGuardModuleIsLoaded() bool { func WireGuardModuleIsLoaded() bool {
return false
/*
if os.Getenv(envDisableWireGuardKernel) == "true" { if os.Getenv(envDisableWireGuardKernel) == "true" {
log.Debugf("WireGuard kernel module disabled because the %s env is set to true", envDisableWireGuardKernel) log.Debugf("WireGuard kernel module disabled because the %s env is set to true", envDisableWireGuardKernel)
return false return false
@ -102,6 +104,8 @@ func WireGuardModuleIsLoaded() bool {
} }
return loaded return loaded
*/
} }
func canCreateFakeWireGuardInterface() bool { func canCreateFakeWireGuardInterface() bool {

View File

@ -15,4 +15,5 @@ type wgTunDevice interface {
DeviceName() string DeviceName() string
Close() error Close() error
Wrapper() *DeviceWrapper // todo eliminate this function Wrapper() *DeviceWrapper // todo eliminate this function
SetTurnConn(conn interface{})
} }

View File

@ -31,6 +31,11 @@ type tunKernelDevice struct {
udpMux *bind.UniversalUDPMuxDefault udpMux *bind.UniversalUDPMuxDefault
} }
func (t *tunKernelDevice) SetTurnConn(interface{}) {
//TODO implement me
panic("implement me")
}
func newTunDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) wgTunDevice { func newTunDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) wgTunDevice {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
return &tunKernelDevice{ return &tunKernelDevice{

View File

@ -30,6 +30,11 @@ type tunNetstackDevice struct {
configurer wgConfigurer configurer wgConfigurer
} }
func (t *tunNetstackDevice) SetTurnConn(interface{}) {
//TODO implement me
panic("implement me")
}
func newTunNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string) wgTunDevice { func newTunNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string) wgTunDevice {
return &tunNetstackDevice{ return &tunNetstackDevice{
name: name, name: name,

View File

@ -54,7 +54,7 @@ func (t *tunUSPDevice) Create() (wgConfigurer, error) {
t.device = device.NewDevice( t.device = device.NewDevice(
t.wrapper, t.wrapper,
t.iceBind, t.iceBind,
device.NewLogger(device.LogLevelSilent, "[netbird] "), device.NewLogger(device.LogLevelVerbose, "[netbird] "),
) )
err = t.assignAddr() err = t.assignAddr()
@ -70,6 +70,7 @@ func (t *tunUSPDevice) Create() (wgConfigurer, error) {
t.configurer.close() t.configurer.close()
return nil, err return nil, err
} }
log.Debugf("configuration done")
return t.configurer, nil return t.configurer, nil
} }
@ -125,6 +126,14 @@ func (t *tunUSPDevice) Wrapper() *DeviceWrapper {
return t.wrapper return t.wrapper
} }
func (t *tunUSPDevice) SetTurnConn(conn interface{}) {
t.iceBind.SetTurnConn(conn)
err := t.device.BindUpdate()
if err != nil {
log.Errorf("failed to update bind: %v", err)
}
}
// assignAddr Adds IP address to the tunnel interface // assignAddr Adds IP address to the tunnel interface
func (t *tunUSPDevice) assignAddr() error { func (t *tunUSPDevice) assignAddr() error {
link := newWGLink(t.name) link := newWGLink(t.name)