From 2829cce644b1e57782a2b614b8e0fa30ec037ec6 Mon Sep 17 00:00:00 2001 From: braginini Date: Tue, 6 Sep 2022 20:06:51 +0200 Subject: [PATCH] Implement ICEBind --- client/internal/engine.go | 68 +++------ client/internal/peer/conn.go | 12 +- go.mod | 4 +- iface/bind.go | 175 ++++++++++++++--------- iface/configuration.go | 1 - iface/iface.go | 2 +- iface/udp_mux.go | 262 +++++++++++++++++++++++++++++++++++ iface/udp_mux_universal.go | 246 ++++++++++++++++++++++++++++++++ iface/udp_muxed_conn.go | 246 ++++++++++++++++++++++++++++++++ 9 files changed, 893 insertions(+), 123 deletions(-) create mode 100644 iface/udp_mux.go create mode 100644 iface/udp_mux_universal.go create mode 100644 iface/udp_muxed_conn.go diff --git a/client/internal/engine.go b/client/internal/engine.go index d664d0395..46fff1da7 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -89,10 +89,7 @@ type Engine struct { wgInterface *iface.WGIface - udpMux ice.UDPMux - udpMuxSrflx ice.UniversalUDPMux - udpMuxConn *net.UDPConn - udpMuxConnSrflx *net.UDPConn + iceMux ice.UniversalUDPMux // networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service networkSerial uint64 @@ -153,30 +150,6 @@ func (e *Engine) Stop() error { } } - if e.udpMux != nil { - if err := e.udpMux.Close(); err != nil { - log.Debugf("close udp mux: %v", err) - } - } - - if e.udpMuxSrflx != nil { - if err := e.udpMuxSrflx.Close(); err != nil { - log.Debugf("close server reflexive udp mux: %v", err) - } - } - - if e.udpMuxConn != nil { - if err := e.udpMuxConn.Close(); err != nil { - log.Debugf("close udp mux connection: %v", err) - } - } - - if e.udpMuxConnSrflx != nil { - if err := e.udpMuxConnSrflx.Close(); err != nil { - log.Debugf("close server reflexive udp mux connection: %v", err) - } - } - if !isNil(e.sshServer) { err := e.sshServer.Stop() if err != nil { @@ -209,7 +182,7 @@ func isWebRTC(p []byte, n int) bool { type sharedUDPConn struct { net.PacketConn - bind *iface.UserBind + bind *iface.ICEBind } func (s *sharedUDPConn) ReadFrom(buff []byte) (n int, addr net.Addr, err error) { @@ -226,7 +199,7 @@ func (s *sharedUDPConn) ReadFrom(buff []byte) (n int, addr net.Addr, err error) Port: int(e.Port()), Zone: e.Addr().Zone(), } - s.bind.OnData(bytes, a) + //s.bind.OnData(bytes, a) return 0, a, nil } } @@ -252,39 +225,32 @@ func (e *Engine) Start() error { return err } - e.udpMuxConn, err = net.ListenUDP("udp4", &net.UDPAddr{Port: e.config.UDPMuxPort}) - if err != nil { - log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxPort, err.Error()) - return err - } - s := &sharedUDPConn{PacketConn: e.udpMuxConn} - bind := iface.NewUserBind(s) - s.bind = bind - e.udpMuxConnSrflx, err = net.ListenUDP("udp4", &net.UDPAddr{Port: e.config.UDPMuxSrflxPort}) - if err != nil { - log.Errorf("failed listening on UDP port %d: [%s]", e.config.UDPMuxSrflxPort, err.Error()) - return err - } - e.udpMux = ice.NewUDPMuxDefault(ice.UDPMuxParams{UDPConn: s}) - - e.udpMuxSrflx = ice.NewUniversalUDPMuxDefault(ice.UniversalUDPMuxParams{UDPConn: e.udpMuxConnSrflx}) + bind := &iface.ICEBind{} err = e.wgInterface.CreateNew(bind) if err != nil { log.Errorf("failed creating tunnel interface %s: [%s]", wgIfaceName, err.Error()) return err } - log.Infof("shared sock ------------------> %s", s.LocalAddr().String()) - addrPort, err := netip.ParseAddrPort(s.LocalAddr().String()) + port, err := e.wgInterface.GetListenPort() if err != nil { return err } - err = e.wgInterface.Configure(myPrivateKey.String(), int(addrPort.Port())) + + err = e.wgInterface.Configure(myPrivateKey.String(), *port) if err != nil { log.Errorf("failed configuring Wireguard interface [%s]: %s", wgIfaceName, err.Error()) return err } + iceMux, err := bind.GetICEMux() + if err != nil { + return err + } + e.iceMux = iceMux + + log.Infof("NetBird Engine started listening on WireGuard port %d", *port) + e.receiveSignalEvents() e.receiveManagementEvents() @@ -777,8 +743,8 @@ func (e Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, er StunTurn: stunTurn, InterfaceBlackList: e.config.IFaceBlackList, Timeout: timeout, - UDPMux: e.udpMux, - UDPMuxSrflx: e.udpMuxSrflx, + UDPMux: e.iceMux, + UDPMuxSrflx: e.iceMux, ProxyConfig: proxyConfig, LocalWgPort: e.config.WgPort, } diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 338108bc0..f399f1717 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -146,12 +146,12 @@ func (conn *Conn) reCreateAgent() error { conn.agent, err = ice.NewAgent(&ice.AgentConfig{ MulticastDNSMode: ice.MulticastDNSModeDisabled, NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4}, - //Urls: conn.config.StunTurn, - CandidateTypes: []ice.CandidateType{ice.CandidateTypeHost}, - FailedTimeout: &failedTimeout, - InterfaceFilter: interfaceFilter(conn.config.InterfaceBlackList), - UDPMux: conn.config.UDPMux, - UDPMuxSrflx: conn.config.UDPMuxSrflx, + Urls: conn.config.StunTurn, + CandidateTypes: []ice.CandidateType{ice.CandidateTypeServerReflexive, ice.CandidateTypeHost, ice.CandidateTypeRelay}, + FailedTimeout: &failedTimeout, + InterfaceFilter: interfaceFilter(conn.config.InterfaceBlackList), + UDPMux: conn.config.UDPMux, + UDPMuxSrflx: conn.config.UDPMuxSrflx, }) if err != nil { return err diff --git a/go.mod b/go.mod index 047818ef1..44675eebd 100644 --- a/go.mod +++ b/go.mod @@ -36,7 +36,9 @@ require ( github.com/gliderlabs/ssh v0.3.4 github.com/magiconair/properties v1.8.5 github.com/patrickmn/go-cache v2.1.0+incompatible + github.com/pion/logging v0.2.2 github.com/pion/stun v0.3.5 + github.com/pion/transport v0.13.0 github.com/rs/xid v1.3.0 github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 github.com/stretchr/testify v1.7.1 @@ -79,10 +81,8 @@ require ( github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c // indirect github.com/pegasus-kv/thrift v0.13.0 // indirect github.com/pion/dtls/v2 v2.1.2 // indirect - github.com/pion/logging v0.2.2 // indirect github.com/pion/mdns v0.0.5 // indirect github.com/pion/randutil v0.1.0 // indirect - github.com/pion/transport v0.13.0 // indirect github.com/pion/turn/v2 v2.0.7 // indirect github.com/pion/udp v0.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/iface/bind.go b/iface/bind.go index 5203c7efa..7d3037cb1 100644 --- a/iface/bind.go +++ b/iface/bind.go @@ -1,99 +1,157 @@ package iface import ( + "errors" + "fmt" + "github.com/pion/stun" + log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/conn" "net" "net/netip" "sync" + "syscall" ) -type UserEndpoint struct { - conn.StdNetEndpoint +type ICEBind struct { + sharedConn net.PacketConn + iceMux *UniversalUDPMuxDefault + + mu sync.Mutex // protects following fields } -type packet struct { - buff []byte - addr *net.UDPAddr +func (b *ICEBind) GetSharedConn() (net.PacketConn, error) { + b.mu.Lock() + defer b.mu.Unlock() + if b.sharedConn == nil { + return nil, fmt.Errorf("ICEBind has not been initialized yet") + } + + return b.sharedConn, nil } -type UserBind struct { - endpointsLock sync.RWMutex - endpoints map[netip.AddrPort]*UserEndpoint - sharedConn net.PacketConn +func (b *ICEBind) GetICEMux() (UniversalUDPMux, error) { + b.mu.Lock() + defer b.mu.Unlock() + if b.iceMux == nil { + return nil, fmt.Errorf("ICEBind has not been initialized yet") + } - Packets chan packet - closeSignal chan struct{} + return b.iceMux, nil } -func NewUserBind(sharedConn net.PacketConn) *UserBind { - return &UserBind{sharedConn: sharedConn} -} +func (b *ICEBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) { + b.mu.Lock() + defer b.mu.Unlock() -func (b *UserBind) Open(port uint16) ([]conn.ReceiveFunc, uint16, error) { + if b.sharedConn != nil { + return nil, 0, conn.ErrBindAlreadyOpen + } - b.Packets = make(chan packet, 1000) - b.closeSignal = make(chan struct{}) + port := int(uport) + ipv4Conn, port, err := listenNet("udp4", port) + if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { + return nil, 0, err + } + b.sharedConn = ipv4Conn + b.iceMux = NewUniversalUDPMuxDefault(UniversalUDPMuxParams{UDPConn: b.sharedConn}) - return []conn.ReceiveFunc{b.receive}, port, nil -} - -func (b *UserBind) receive(buff []byte) (int, conn.Endpoint, error) { - - /*n, endpoint, err := b.sharedConn.ReadFrom(buff) + portAddr, err := netip.ParseAddrPort(ipv4Conn.LocalAddr().String()) if err != nil { - return 0, nil, err - } - e, err := netip.ParseAddrPort(endpoint.String()) - if err != nil { - return 0, nil, err - } - return n, (*conn.StdNetEndpoint)(&net.UDPAddr{ - IP: e.addr().AsSlice(), - Port: int(e.Port()), - Zone: e.addr().Zone(), - }), err*/ - - select { - case <-b.closeSignal: - return 0, nil, net.ErrClosed - case pkt := <-b.Packets: - /*log.Infof("received packet %d from %s to copy to buffer %d", binary.Size(pkt.buff), pkt.addr.String(), - len(buff))*/ - return copy(buff, pkt.buff), (*conn.StdNetEndpoint)(pkt.addr), nil + return nil, 0, err } + return []conn.ReceiveFunc{b.makeReceiveIPv4(b.sharedConn)}, portAddr.Port(), nil } -func (b *UserBind) Close() error { - if b.closeSignal != nil { - select { - case <-b.closeSignal: - default: - close(b.closeSignal) +func listenNet(network string, port int) (*net.UDPConn, int, error) { + conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port}) + if err != nil { + return nil, 0, err + } + + // Retrieve port. + laddr := conn.LocalAddr() + uaddr, err := net.ResolveUDPAddr( + laddr.Network(), + laddr.String(), + ) + if err != nil { + return nil, 0, err + } + return conn, uaddr.Port, nil +} + +func (b *ICEBind) makeReceiveIPv4(c net.PacketConn) conn.ReceiveFunc { + return func(buff []byte) (int, conn.Endpoint, error) { + n, endpoint, err := c.ReadFrom(buff) + if err != nil { + return 0, nil, err } + e, err := netip.ParseAddrPort(endpoint.String()) + if err != nil { + return 0, nil, err + } + if !stun.IsMessage(buff[:n]) { + // WireGuard traffic + return n, (*conn.StdNetEndpoint)(&net.UDPAddr{ + IP: e.Addr().AsSlice(), + Port: int(e.Port()), + Zone: e.Addr().Zone(), + }), nil + } + + err = b.iceMux.HandlePacket(buff, n, endpoint) + if err != nil { + return 0, nil, err + } + if err != nil { + log.Warnf("failed to handle packet") + } + + // discard packets because they are STUN related + return 0, nil, nil //todo proper return } - return nil +} + +func (b *ICEBind) Close() error { + b.mu.Lock() + defer b.mu.Unlock() + + var err1, err2 error + if b.sharedConn != nil { + c := b.sharedConn + b.sharedConn = nil + err1 = c.Close() + } + + if b.iceMux != nil { + m := b.iceMux + b.iceMux = nil + err2 = m.Close() + } + + if err1 != nil { + return err1 + } + return err2 } // SetMark sets the mark for each packet sent through this Bind. // This mark is passed to the kernel as the socket option SO_MARK. -func (b *UserBind) SetMark(mark uint32) error { +func (b *ICEBind) SetMark(mark uint32) error { return nil } -func (b *UserBind) Send(buff []byte, endpoint conn.Endpoint) error { +func (b *ICEBind) Send(buff []byte, endpoint conn.Endpoint) error { nend, ok := endpoint.(*conn.StdNetEndpoint) if !ok { return conn.ErrWrongEndpointType } - - //log.Infof("sending packet %d from %s to %s", binary.Size(buff), b.sharedConn.LocalAddr().String(), (*net.UDPAddr)(nend).String()) - _, err := b.sharedConn.WriteTo(buff, (*net.UDPAddr)(nend)) return err } // ParseEndpoint creates a new endpoint from a string. -func (b *UserBind) ParseEndpoint(s string) (ep conn.Endpoint, err error) { +func (b *ICEBind) ParseEndpoint(s string) (ep conn.Endpoint, err error) { e, err := netip.ParseAddrPort(s) return (*conn.StdNetEndpoint)(&net.UDPAddr{ IP: e.Addr().AsSlice(), @@ -101,10 +159,3 @@ func (b *UserBind) ParseEndpoint(s string) (ep conn.Endpoint, err error) { Zone: e.Addr().Zone(), }), err } - -func (b *UserBind) OnData(buff []byte, addr *net.UDPAddr) { - b.Packets <- packet{ - buff: buff, - addr: addr, - } -} diff --git a/iface/configuration.go b/iface/configuration.go index 9f49cf6ee..b862e43b0 100644 --- a/iface/configuration.go +++ b/iface/configuration.go @@ -45,7 +45,6 @@ func (w *WGIface) Configure(privateKey string, port int) error { PrivateKey: &key, ReplacePeers: true, FirewallMark: &fwmark, - ListenPort: &port, } err = w.configureDevice(config) diff --git a/iface/iface.go b/iface/iface.go index 5e4c5c7a6..e4b4dd013 100644 --- a/iface/iface.go +++ b/iface/iface.go @@ -21,7 +21,7 @@ type WGIface struct { Address WGAddress Interface NetInterface mu sync.Mutex - Bind *UserBind + Bind *ICEBind } // WGAddress Wireguard parsed address diff --git a/iface/udp_mux.go b/iface/udp_mux.go new file mode 100644 index 000000000..ba917aeb4 --- /dev/null +++ b/iface/udp_mux.go @@ -0,0 +1,262 @@ +package iface + +import ( + "fmt" + log "github.com/sirupsen/logrus" + "io" + "net" + "strings" + "sync" + + "github.com/pion/logging" + "github.com/pion/stun" +) + +const receiveMTU = 8192 + +// UDPMux allows multiple connections to go over a single UDP port +type UDPMux interface { + io.Closer + GetConn(ufrag string) (net.PacketConn, error) + RemoveConnByUfrag(ufrag string) +} + +// UDPMuxDefault is an implementation of the interface +type UDPMuxDefault struct { + params UDPMuxParams + + closedChan chan struct{} + closeOnce sync.Once + + // conns is a map of all udpMuxedConn indexed by ufrag|network|candidateType + conns map[string]*udpMuxedConn + + addressMapMu sync.RWMutex + addressMap map[string]*udpMuxedConn + + // buffer pool to recycle buffers for net.UDPAddr encodes/decodes + pool *sync.Pool + + mu sync.Mutex +} + +const maxAddrSize = 512 + +// UDPMuxParams are parameters for UDPMux. +type UDPMuxParams struct { + Logger logging.LeveledLogger + UDPConn net.PacketConn +} + +// NewUDPMuxDefault creates an implementation of UDPMux +func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { + if params.Logger == nil { + params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice") + } + + return &UDPMuxDefault{ + addressMap: map[string]*udpMuxedConn{}, + params: params, + conns: make(map[string]*udpMuxedConn), + closedChan: make(chan struct{}, 1), + pool: &sync.Pool{ + New: func() interface{} { + // big enough buffer to fit both packet and address + return newBufferHolder(receiveMTU + maxAddrSize) + }, + }, + } +} + +func (m *UDPMuxDefault) HandlePacket(p []byte, n int, addr net.Addr) error { + + udpAddr, ok := addr.(*net.UDPAddr) + if !ok { + return fmt.Errorf("underlying PacketConn did not return a UDPAddr") + } + + // If we have already seen this address dispatch to the appropriate destination + m.addressMapMu.Lock() + destinationConn := m.addressMap[addr.String()] + m.addressMapMu.Unlock() + + // If we haven't seen this address before but is a STUN packet lookup by ufrag + if destinationConn == nil && stun.IsMessage(p[:n]) { + msg := &stun.Message{ + Raw: append([]byte{}, p[:n]...), + } + + if err := msg.Decode(); err != nil { + log.Warnf("Failed to handle decode ICE from %s: %v\n", addr.String(), err) + return err + } + + attr, stunAttrErr := msg.Get(stun.AttrUsername) + if stunAttrErr != nil { + log.Warnf("No Username attribute in STUN message from %s\n", addr.String()) + return stunAttrErr + } + + ufrag := strings.Split(string(attr), ":")[0] + + m.mu.Lock() + destinationConn = m.conns[ufrag] + m.mu.Unlock() + } + + if destinationConn == nil { + log.Tracef("dropping packet from %s, addr: %s", udpAddr.String(), addr.String()) + return nil + } + + if err := destinationConn.writePacket(p[:n], udpAddr); err != nil { + log.Errorf("could not write packet: %v", err) + } + + return nil +} + +// LocalAddr returns the listening address of this UDPMuxDefault +func (m *UDPMuxDefault) LocalAddr() net.Addr { + return m.params.UDPConn.LocalAddr() +} + +// GetConn returns a PacketConn given the connection's ufrag and network +// creates the connection if an existing one can't be found +func (m *UDPMuxDefault) GetConn(ufrag string) (net.PacketConn, error) { + m.mu.Lock() + defer m.mu.Unlock() + + if m.IsClosed() { + return nil, io.ErrClosedPipe + } + + if c, ok := m.conns[ufrag]; ok { + return c, nil + } + + c := m.createMuxedConn(ufrag) + go func() { + <-c.CloseChannel() + m.removeConn(ufrag) + }() + m.conns[ufrag] = c + return c, nil +} + +// RemoveConnByUfrag stops and removes the muxed packet connection +func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) { + m.mu.Lock() + removedConns := make([]*udpMuxedConn, 0) + for key := range m.conns { + if key != ufrag { + continue + } + + c := m.conns[key] + delete(m.conns, key) + if c != nil { + removedConns = append(removedConns, c) + } + } + // keep lock section small to avoid deadlock with conn lock + m.mu.Unlock() + + m.addressMapMu.Lock() + defer m.addressMapMu.Unlock() + + for _, c := range removedConns { + addresses := c.getAddresses() + for _, addr := range addresses { + delete(m.addressMap, addr) + } + } +} + +// IsClosed returns true if the mux had been closed +func (m *UDPMuxDefault) IsClosed() bool { + select { + case <-m.closedChan: + return true + default: + return false + } +} + +// Close the mux, no further connections could be created +func (m *UDPMuxDefault) Close() error { + var err error + m.closeOnce.Do(func() { + m.mu.Lock() + defer m.mu.Unlock() + + for _, c := range m.conns { + _ = c.Close() + } + m.conns = make(map[string]*udpMuxedConn) + close(m.closedChan) + }) + return err +} + +func (m *UDPMuxDefault) removeConn(key string) { + m.mu.Lock() + c := m.conns[key] + delete(m.conns, key) + // keep lock section small to avoid deadlock with conn lock + m.mu.Unlock() + + if c == nil { + return + } + + m.addressMapMu.Lock() + defer m.addressMapMu.Unlock() + + addresses := c.getAddresses() + for _, addr := range addresses { + delete(m.addressMap, addr) + } +} + +func (m *UDPMuxDefault) writeTo(buf []byte, raddr net.Addr) (n int, err error) { + return m.params.UDPConn.WriteTo(buf, raddr) +} + +func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) { + if m.IsClosed() { + return + } + + m.addressMapMu.Lock() + defer m.addressMapMu.Unlock() + + existing, ok := m.addressMap[addr] + if ok { + existing.removeAddress(addr) + } + m.addressMap[addr] = conn + + m.params.Logger.Debugf("Registered %s for %s", addr, conn.params.Key) +} + +func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn { + c := newUDPMuxedConn(&udpMuxedConnParams{ + Mux: m, + Key: key, + AddrPool: m.pool, + LocalAddr: m.LocalAddr(), + Logger: m.params.Logger, + }) + return c +} + +type bufferHolder struct { + buffer []byte +} + +func newBufferHolder(size int) *bufferHolder { + return &bufferHolder{ + buffer: make([]byte, size), + } +} diff --git a/iface/udp_mux_universal.go b/iface/udp_mux_universal.go new file mode 100644 index 000000000..e2ff55f68 --- /dev/null +++ b/iface/udp_mux_universal.go @@ -0,0 +1,246 @@ +package iface + +import ( + "errors" + "fmt" + log "github.com/sirupsen/logrus" + "net" + "time" + + "github.com/pion/logging" + "github.com/pion/stun" +) + +// UniversalUDPMux allows multiple connections to go over a single UDP port for +// host, server reflexive and relayed candidates. +// Actual connection muxing is happening in the UDPMux. +type UniversalUDPMux interface { + UDPMux + GetXORMappedAddr(stunAddr net.Addr, deadline time.Duration) (*stun.XORMappedAddress, error) + GetRelayedAddr(turnAddr net.Addr, deadline time.Duration) (*net.Addr, error) + GetConnForURL(ufrag string, url string) (net.PacketConn, error) +} + +// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn overriding ReadFrom. +// It the passes packets to the UDPMux that does the actual connection muxing. +type UniversalUDPMuxDefault struct { + *UDPMuxDefault + params UniversalUDPMuxParams + + // since we have a shared socket, for srflx candidates it makes sense to have a shared mapped address across all the agents + // stun.XORMappedAddress indexed by the STUN server addr + xorMappedMap map[string]*xorMapped +} + +// UniversalUDPMuxParams are parameters for UniversalUDPMux server reflexive. +type UniversalUDPMuxParams struct { + Logger logging.LeveledLogger + UDPConn net.PacketConn + XORMappedAddrCacheTTL time.Duration +} + +// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux +func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDefault { + if params.Logger == nil { + params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice") + } + if params.XORMappedAddrCacheTTL == 0 { + params.XORMappedAddrCacheTTL = time.Second * 25 + } + + m := &UniversalUDPMuxDefault{ + params: params, + xorMappedMap: make(map[string]*xorMapped), + } + + // embed UDPMux + udpMuxParams := UDPMuxParams{ + Logger: params.Logger, + UDPConn: m.params.UDPConn, + } + m.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams) + + return m +} + +// GetRelayedAddr creates relayed connection to the given TURN service and returns the relayed addr. +// Not implemented yet. +func (m *UniversalUDPMuxDefault) GetRelayedAddr(turnAddr net.Addr, deadline time.Duration) (*net.Addr, error) { + return nil, errors.New("not implemented yet") +} + +// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers +// and return a unique connection per server. +func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string) (net.PacketConn, error) { + return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url)) +} + +func (m *UniversalUDPMuxDefault) HandlePacket(p []byte, n int, addr net.Addr) error { + if stun.IsMessage(p[:n]) { + msg := &stun.Message{ + Raw: append([]byte{}, p[:n]...), + } + + if err := msg.Decode(); err != nil { + log.Warnf("Failed to handle decode ICE from %s: %v\n", addr.String(), err) + // todo proper error + return nil + } + + udpAddr, ok := addr.(*net.UDPAddr) + if !ok { + // message about this err will be logged in the UDPMux + return nil + } + + if m.isXORMappedResponse(msg, udpAddr.String()) { + err := m.handleXORMappedResponse(udpAddr, msg) + if err != nil { + log.Debugf("%w: %v", errors.New("failed to get XOR-MAPPED-ADDRESS response"), err) + return nil + } + return nil + } + } + return m.UDPMuxDefault.HandlePacket(p, n, addr) +} + +// isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server. +func (m *UniversalUDPMuxDefault) isXORMappedResponse(msg *stun.Message, stunAddr string) bool { + m.mu.Lock() + defer m.mu.Unlock() + // check first if it is a STUN server address because remote peer can also send similar messages but as a BindingSuccess + _, ok := m.xorMappedMap[stunAddr] + _, err := msg.Get(stun.AttrXORMappedAddress) + return err == nil && ok +} + +// handleXORMappedResponse parses response from the STUN server, extracts XORMappedAddress attribute +// and set the mapped address for the server +func (m *UniversalUDPMuxDefault) handleXORMappedResponse(stunAddr *net.UDPAddr, msg *stun.Message) error { + m.mu.Lock() + defer m.mu.Unlock() + + mappedAddr, ok := m.xorMappedMap[stunAddr.String()] + if !ok { + return errors.New("no address mapping") + } + + var addr stun.XORMappedAddress + if err := addr.GetFrom(msg); err != nil { + return err + } + + m.xorMappedMap[stunAddr.String()] = mappedAddr + mappedAddr.SetAddr(&addr) + + return nil +} + +// GetXORMappedAddr returns *stun.XORMappedAddress if already present for a given STUN server. +// Makes a STUN binding request to discover mapped address otherwise. +// Blocks until the stun.XORMappedAddress has been discovered or deadline. +// Method is safe for concurrent use. +func (m *UniversalUDPMuxDefault) GetXORMappedAddr(serverAddr net.Addr, deadline time.Duration) (*stun.XORMappedAddress, error) { + m.mu.Lock() + mappedAddr, ok := m.xorMappedMap[serverAddr.String()] + // if we already have a mapping for this STUN server (address already received) + // and if it is not too old we return it without making a new request to STUN server + if ok { + if mappedAddr.expired() { + mappedAddr.closeWaiters() + delete(m.xorMappedMap, serverAddr.String()) + ok = false + } else if mappedAddr.pending() { + ok = false + } + } + m.mu.Unlock() + if ok { + return mappedAddr.addr, nil + } + + // otherwise, make a STUN request to discover the address + // or wait for already sent request to complete + waitAddrReceived, err := m.sendStun(serverAddr) + if err != nil { + return nil, errors.New("failed to send STUN packet") + } + + // block until response was handled by the connWorker routine and XORMappedAddress was updated + select { + case <-waitAddrReceived: + // when channel closed, addr was obtained + m.mu.Lock() + mappedAddr := *m.xorMappedMap[serverAddr.String()] + m.mu.Unlock() + if mappedAddr.addr == nil { + return nil, errors.New("no address mapping") + } + return mappedAddr.addr, nil + case <-time.After(deadline): + return nil, errors.New("timeout while waiting for XORMappedAddr") + } +} + +// sendStun sends a STUN request via UDP conn. +// +// The returned channel is closed when the STUN response has been received. +// Method is safe for concurrent use. +func (m *UniversalUDPMuxDefault) sendStun(serverAddr net.Addr) (chan struct{}, error) { + m.mu.Lock() + defer m.mu.Unlock() + + // if record present in the map, we already sent a STUN request, + // just wait when waitAddrReceived will be closed + addrMap, ok := m.xorMappedMap[serverAddr.String()] + if !ok { + addrMap = &xorMapped{ + expiresAt: time.Now().Add(m.params.XORMappedAddrCacheTTL), + waitAddrReceived: make(chan struct{}), + } + m.xorMappedMap[serverAddr.String()] = addrMap + } + + req, err := stun.Build(stun.BindingRequest, stun.TransactionID) + if err != nil { + return nil, err + } + + if _, err = m.params.UDPConn.WriteTo(req.Raw, serverAddr); err != nil { + return nil, err + } + + return addrMap.waitAddrReceived, nil +} + +type xorMapped struct { + addr *stun.XORMappedAddress + waitAddrReceived chan struct{} + expiresAt time.Time +} + +func (a *xorMapped) closeWaiters() { + select { + case <-a.waitAddrReceived: + // notify was close, ok, that means we received duplicate response + // just exit + break + default: + // notify tha twe have a new addr + close(a.waitAddrReceived) + } +} + +func (a *xorMapped) pending() bool { + return a.addr == nil +} + +func (a *xorMapped) expired() bool { + return a.expiresAt.Before(time.Now()) +} + +func (a *xorMapped) SetAddr(addr *stun.XORMappedAddress) { + a.addr = addr + a.closeWaiters() +} diff --git a/iface/udp_muxed_conn.go b/iface/udp_muxed_conn.go new file mode 100644 index 000000000..97491e65f --- /dev/null +++ b/iface/udp_muxed_conn.go @@ -0,0 +1,246 @@ +package iface + +import ( + "encoding/binary" + "io" + "net" + "sync" + "time" + + "github.com/pion/logging" + "github.com/pion/transport/packetio" +) + +type udpMuxedConnParams struct { + Mux *UDPMuxDefault + AddrPool *sync.Pool + Key string + LocalAddr net.Addr + Logger logging.LeveledLogger +} + +// udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag +type udpMuxedConn struct { + params *udpMuxedConnParams + // remote addresses that we have sent to on this conn + addresses []string + + // channel holding incoming packets + buffer *packetio.Buffer + closedChan chan struct{} + closeOnce sync.Once + mu sync.Mutex +} + +func newUDPMuxedConn(params *udpMuxedConnParams) *udpMuxedConn { + p := &udpMuxedConn{ + params: params, + buffer: packetio.NewBuffer(), + closedChan: make(chan struct{}), + } + + return p +} + +func (c *udpMuxedConn) ReadFrom(b []byte) (n int, raddr net.Addr, err error) { + buf := c.params.AddrPool.Get().(*bufferHolder) + defer c.params.AddrPool.Put(buf) + + // read address + total, err := c.buffer.Read(buf.buffer) + if err != nil { + return 0, nil, err + } + + dataLen := int(binary.LittleEndian.Uint16(buf.buffer[:2])) + if dataLen > total || dataLen > len(b) { + return 0, nil, io.ErrShortBuffer + } + + // read data and then address + offset := 2 + copy(b, buf.buffer[offset:offset+dataLen]) + offset += dataLen + + // read address len & decode address + addrLen := int(binary.LittleEndian.Uint16(buf.buffer[offset : offset+2])) + offset += 2 + + if raddr, err = decodeUDPAddr(buf.buffer[offset : offset+addrLen]); err != nil { + return 0, nil, err + } + + return dataLen, raddr, nil +} + +func (c *udpMuxedConn) WriteTo(buf []byte, raddr net.Addr) (n int, err error) { + if c.isClosed() { + return 0, io.ErrClosedPipe + } + // each time we write to a new address, we'll register it with the mux + addr := raddr.String() + if !c.containsAddress(addr) { + c.addAddress(addr) + } + + return c.params.Mux.writeTo(buf, raddr) +} + +func (c *udpMuxedConn) LocalAddr() net.Addr { + return c.params.LocalAddr +} + +func (c *udpMuxedConn) SetDeadline(tm time.Time) error { + return nil +} + +func (c *udpMuxedConn) SetReadDeadline(tm time.Time) error { + return nil +} + +func (c *udpMuxedConn) SetWriteDeadline(tm time.Time) error { + return nil +} + +func (c *udpMuxedConn) CloseChannel() <-chan struct{} { + return c.closedChan +} + +func (c *udpMuxedConn) Close() error { + var err error + c.closeOnce.Do(func() { + err = c.buffer.Close() + close(c.closedChan) + }) + c.mu.Lock() + defer c.mu.Unlock() + c.addresses = nil + return err +} + +func (c *udpMuxedConn) isClosed() bool { + select { + case <-c.closedChan: + return true + default: + return false + } +} + +func (c *udpMuxedConn) getAddresses() []string { + c.mu.Lock() + defer c.mu.Unlock() + addresses := make([]string, len(c.addresses)) + copy(addresses, c.addresses) + return addresses +} + +func (c *udpMuxedConn) addAddress(addr string) { + c.mu.Lock() + c.addresses = append(c.addresses, addr) + c.mu.Unlock() + + // map it on mux + c.params.Mux.registerConnForAddress(c, addr) +} + +func (c *udpMuxedConn) removeAddress(addr string) { + c.mu.Lock() + defer c.mu.Unlock() + + newAddresses := make([]string, 0, len(c.addresses)) + for _, a := range c.addresses { + if a != addr { + newAddresses = append(newAddresses, a) + } + } + + c.addresses = newAddresses +} + +func (c *udpMuxedConn) containsAddress(addr string) bool { + c.mu.Lock() + defer c.mu.Unlock() + for _, a := range c.addresses { + if addr == a { + return true + } + } + return false +} + +func (c *udpMuxedConn) writePacket(data []byte, addr *net.UDPAddr) error { + // write two packets, address and data + buf := c.params.AddrPool.Get().(*bufferHolder) + defer c.params.AddrPool.Put(buf) + + // format of buffer | data len | data bytes | addr len | addr bytes | + if len(buf.buffer) < len(data)+maxAddrSize { + return io.ErrShortBuffer + } + // data len + binary.LittleEndian.PutUint16(buf.buffer, uint16(len(data))) + offset := 2 + + // data + copy(buf.buffer[offset:], data) + offset += len(data) + + // write address first, leaving room for its length + n, err := encodeUDPAddr(addr, buf.buffer[offset+2:]) + if err != nil { + return nil + } + total := offset + n + 2 + + // address len + binary.LittleEndian.PutUint16(buf.buffer[offset:], uint16(n)) + + if _, err := c.buffer.Write(buf.buffer[:total]); err != nil { + return err + } + return nil +} + +func encodeUDPAddr(addr *net.UDPAddr, buf []byte) (int, error) { + ipdata, err := addr.IP.MarshalText() + if err != nil { + return 0, err + } + total := 2 + len(ipdata) + 2 + len(addr.Zone) + if total > len(buf) { + return 0, io.ErrShortBuffer + } + + binary.LittleEndian.PutUint16(buf, uint16(len(ipdata))) + offset := 2 + n := copy(buf[offset:], ipdata) + offset += n + binary.LittleEndian.PutUint16(buf[offset:], uint16(addr.Port)) + offset += 2 + copy(buf[offset:], addr.Zone) + return total, nil +} + +func decodeUDPAddr(buf []byte) (*net.UDPAddr, error) { + addr := net.UDPAddr{} + + offset := 0 + ipLen := int(binary.LittleEndian.Uint16(buf[:2])) + offset += 2 + // basic bounds checking + if ipLen+offset > len(buf) { + return nil, io.ErrShortBuffer + } + if err := addr.IP.UnmarshalText(buf[offset : offset+ipLen]); err != nil { + return nil, err + } + offset += ipLen + addr.Port = int(binary.LittleEndian.Uint16(buf[offset : offset+2])) + offset += 2 + zone := make([]byte, len(buf[offset:])) + copy(zone, buf[offset:]) + addr.Zone = string(zone) + + return &addr, nil +}