Single Mux

This commit is contained in:
braginini 2022-09-07 18:39:58 +02:00
parent 48b7c6ec3c
commit eaf985624d
3 changed files with 99 additions and 77 deletions

View File

@ -210,11 +210,11 @@ func (e *Engine) Start() error {
}
e.iceMux = iceMux
iceHostMux, err := bind.GetICEHostMux()
/*iceHostMux, err := bind.GetICEHostMux()
if err != nil {
return err
}
e.iceHostMux = iceHostMux
}*/
e.iceHostMux = iceMux
log.Infof("NetBird Engine started listening on WireGuard port %d", *port)

View File

@ -18,10 +18,9 @@ type BindMux interface {
}
type ICEBind struct {
sharedConn net.PacketConn
sharedConnHost net.PacketConn
iceSrflxMux *UniversalUDPMuxDefault
iceHostMux *UDPMuxDefault
sharedConn net.PacketConn
udpMux *UniversalUDPMuxDefault
iceHostMux *UDPMuxDefault
endpointMap map[string]net.PacketConn
@ -31,11 +30,11 @@ type ICEBind struct {
func (b *ICEBind) GetICEMux() (UniversalUDPMux, error) {
b.mu.Lock()
defer b.mu.Unlock()
if b.iceSrflxMux == nil {
if b.udpMux == nil {
return nil, fmt.Errorf("ICEBind has not been initialized yet")
}
return b.iceSrflxMux, nil
return b.udpMux, nil
}
func (b *ICEBind) GetICEHostMux() (UDPMux, error) {
@ -55,9 +54,6 @@ func (b *ICEBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
if b.sharedConn != nil {
return nil, 0, conn.ErrBindAlreadyOpen
}
if b.sharedConnHost != nil {
return nil, 0, conn.ErrBindAlreadyOpen
}
b.endpointMap = make(map[string]net.PacketConn)
@ -66,24 +62,20 @@ func (b *ICEBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, 0, err
}
ipv4ConnHost, port, err := listenNet("udp4", 0)
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, 0, err
}
b.sharedConn = ipv4Conn
b.sharedConnHost = ipv4ConnHost
b.iceSrflxMux = NewUniversalUDPMuxDefault(UniversalUDPMuxParams{UDPConn: b.sharedConn})
b.iceHostMux = NewUDPMuxDefault(UDPMuxParams{UDPConn: b.sharedConnHost})
b.udpMux = NewUniversalUDPMuxDefault(UniversalUDPMuxParams{UDPConn: b.sharedConn})
portAddr, err := netip.ParseAddrPort(ipv4Conn.LocalAddr().String())
portAddr1, err := netip.ParseAddrPort(ipv4Conn.LocalAddr().String())
if err != nil {
return nil, 0, err
}
log.Infof("opened ICEBind on %s", ipv4Conn.LocalAddr().String())
return []conn.ReceiveFunc{
b.makeReceiveIPv4(b.sharedConn, b.iceSrflxMux),
b.makeReceiveIPv4(b.sharedConnHost, b.iceHostMux),
b.makeReceiveIPv4(b.sharedConn),
},
portAddr.Port(), nil
portAddr1.Port(), nil
}
func listenNet(network string, port int) (*net.UDPConn, int, error) {
@ -104,7 +96,7 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) {
return conn, uaddr.Port, nil
}
func (b *ICEBind) makeReceiveIPv4(c net.PacketConn, bindMux BindMux) conn.ReceiveFunc {
func (b *ICEBind) makeReceiveIPv4(c net.PacketConn) conn.ReceiveFunc {
return func(buff []byte) (int, conn.Endpoint, error) {
n, endpoint, err := c.ReadFrom(buff)
if err != nil {
@ -122,15 +114,37 @@ func (b *ICEBind) makeReceiveIPv4(c net.PacketConn, bindMux BindMux) conn.Receiv
Zone: e.Addr().Zone(),
}), nil
}
b.mu.Lock()
/* msg := &stun.Message{
Raw: append([]byte{}, buff[:n]...),
}
if err := msg.Decode(); err != nil {
return 0, nil, err
}
strAttrs := []string{}
for _, attribute := range msg.Attributes {
strAttrs = append(strAttrs, attribute.String())
}
xorMapped := "EMPTY"
_, err = msg.Get(stun.AttrXORMappedAddress)
if err == nil {
var addr stun.XORMappedAddress
if err := addr.GetFrom(msg); err == nil {
xorMapped = addr.String()
}
}
log.Printf("endpoint %s XORMAPPED %s mux type %s msg type %s, attributes %s", endpoint.String(), xorMapped, bindMux.Type(), msg.Type.String(), strings.Join(strAttrs[:], ";"))
*/
if _, ok := b.endpointMap[e.String()]; !ok {
b.endpointMap[e.String()] = c
log.Infof("added %s endpoint %s", bindMux.Type(), e.String())
log.Infof("added endpoint %s", e.String())
}
b.mu.Unlock()
err = bindMux.HandlePacket(buff, n, endpoint)
err = b.udpMux.HandlePacket(buff, n, endpoint)
if err != nil {
return 0, nil, err
}
@ -147,43 +161,24 @@ func (b *ICEBind) Close() error {
b.mu.Lock()
defer b.mu.Unlock()
var err1, err2, err3, err4 error
var err1, err2 error
if b.sharedConn != nil {
c := b.sharedConn
b.sharedConn = nil
err1 = c.Close()
}
if b.sharedConnHost != nil {
c := b.sharedConnHost
b.sharedConnHost = nil
err2 = c.Close()
}
if b.iceSrflxMux != nil {
m := b.iceSrflxMux
b.iceSrflxMux = nil
err3 = m.Close()
if b.udpMux != nil {
m := b.udpMux
b.udpMux = nil
err2 = m.Close()
}
if b.iceHostMux != nil {
m := b.iceHostMux
b.iceHostMux = nil
err4 = m.Close()
}
//todo close iceSrflxMux
if err1 != nil {
return err1
}
if err2 != nil {
return err2
}
if err3 != nil {
return err3
}
return err4
return err2
}
// SetMark sets the mark for each packet sent through this Bind.

View File

@ -32,7 +32,7 @@ type UDPMuxDefault struct {
conns map[string]*udpMuxedConn
addressMapMu sync.RWMutex
addressMap map[string]*udpMuxedConn
addressMap map[string][]*udpMuxedConn
// buffer pool to recycle buffers for net.UDPAddr encodes/decodes
pool *sync.Pool
@ -55,7 +55,7 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
}
return &UDPMuxDefault{
addressMap: map[string]*udpMuxedConn{},
addressMap: map[string][]*udpMuxedConn{},
params: params,
conns: make(map[string]*udpMuxedConn),
closedChan: make(chan struct{}, 1),
@ -81,11 +81,19 @@ func (m *UDPMuxDefault) HandlePacket(p []byte, n int, addr net.Addr) error {
// If we have already seen this address dispatch to the appropriate destination
m.addressMapMu.Lock()
destinationConn := m.addressMap[addr.String()]
var destinationConnList []*udpMuxedConn
if storedConns, ok := m.addressMap[addr.String()]; ok {
for _, conn := range storedConns {
destinationConnList = append(destinationConnList, conn)
}
}
m.addressMapMu.Unlock()
// If we haven't seen this address before but is a STUN packet lookup by ufrag
if destinationConn == nil && stun.IsMessage(p[:20]) {
if stun.IsMessage(p[:20]) {
// This block is needed to discover Peer Reflexive Candidates for which we don't know the Endpoint upfront.
// However, we can take a username attribute from the STUN message which contains ufrag.
// We can use ufrag to identify the destination conn to route packet to.
msg := &stun.Message{
Raw: append([]byte{}, p[:n]...),
}
@ -96,25 +104,32 @@ func (m *UDPMuxDefault) HandlePacket(p []byte, n int, addr net.Addr) error {
}
attr, stunAttrErr := msg.Get(stun.AttrUsername)
if stunAttrErr != nil {
log.Warnf("No Username attribute in STUN message from %s\n", addr.String())
return stunAttrErr
if stunAttrErr == nil {
ufrag := strings.Split(string(attr), ":")[0]
m.mu.Lock()
if destinationConn, ok := m.conns[ufrag]; ok {
exists := false
for _, conn := range destinationConnList {
if conn.params.Key == destinationConn.params.Key {
exists = true
break
}
}
if !exists {
destinationConnList = append(destinationConnList, destinationConn)
}
}
m.mu.Unlock()
} else {
//log.Warnf("No Username attribute in STUN message from %s\n", addr.String())
}
ufrag := strings.Split(string(attr), ":")[0]
m.mu.Lock()
destinationConn = m.conns[ufrag]
m.mu.Unlock()
}
if destinationConn == nil {
log.Tracef("dropping packet from %s, addr: %s", udpAddr.String(), addr.String())
return nil
}
if err := destinationConn.writePacket(p[:n], udpAddr); err != nil {
log.Errorf("could not write packet: %v", err)
for _, conn := range destinationConnList {
if err := conn.writePacket(p[:n], udpAddr); err != nil {
log.Errorf("could not write packet: %v", err)
}
}
return nil
@ -131,6 +146,8 @@ func (m *UDPMuxDefault) GetConn(ufrag string) (net.PacketConn, error) {
m.mu.Lock()
defer m.mu.Unlock()
log.Debugf("ICE %s: getting muxed connection for %s", m.Type(), ufrag)
if m.IsClosed() {
return nil, io.ErrClosedPipe
}
@ -219,7 +236,15 @@ func (m *UDPMuxDefault) removeConn(key string) {
addresses := c.getAddresses()
for _, addr := range addresses {
delete(m.addressMap, addr)
if connList, ok := m.addressMap[addr]; ok {
var newList []*udpMuxedConn
for _, conn := range connList {
if conn.params.Key != key {
newList = append(newList, conn)
}
}
m.addressMap[addr] = newList
}
}
}
@ -236,12 +261,13 @@ func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string)
defer m.addressMapMu.Unlock()
existing, ok := m.addressMap[addr]
if ok {
existing.removeAddress(addr)
if !ok {
existing = []*udpMuxedConn{}
}
m.addressMap[addr] = conn
existing = append(existing, conn)
m.addressMap[addr] = existing
m.params.Logger.Debugf("Registered %s for %s", addr, conn.params.Key)
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
}
func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn {
@ -252,6 +278,7 @@ func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn {
LocalAddr: m.LocalAddr(),
Logger: m.params.Logger,
})
log.Debugf("ICE: created muxed connection %s for %s", c.LocalAddr().String(), key)
return c
}