mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-12 18:00:49 +01:00
Single Mux
This commit is contained in:
parent
48b7c6ec3c
commit
eaf985624d
@ -210,11 +210,11 @@ func (e *Engine) Start() error {
|
||||
}
|
||||
e.iceMux = iceMux
|
||||
|
||||
iceHostMux, err := bind.GetICEHostMux()
|
||||
/*iceHostMux, err := bind.GetICEHostMux()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
e.iceHostMux = iceHostMux
|
||||
}*/
|
||||
e.iceHostMux = iceMux
|
||||
|
||||
log.Infof("NetBird Engine started listening on WireGuard port %d", *port)
|
||||
|
||||
|
@ -18,10 +18,9 @@ type BindMux interface {
|
||||
}
|
||||
|
||||
type ICEBind struct {
|
||||
sharedConn net.PacketConn
|
||||
sharedConnHost net.PacketConn
|
||||
iceSrflxMux *UniversalUDPMuxDefault
|
||||
iceHostMux *UDPMuxDefault
|
||||
sharedConn net.PacketConn
|
||||
udpMux *UniversalUDPMuxDefault
|
||||
iceHostMux *UDPMuxDefault
|
||||
|
||||
endpointMap map[string]net.PacketConn
|
||||
|
||||
@ -31,11 +30,11 @@ type ICEBind struct {
|
||||
func (b *ICEBind) GetICEMux() (UniversalUDPMux, error) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
if b.iceSrflxMux == nil {
|
||||
if b.udpMux == nil {
|
||||
return nil, fmt.Errorf("ICEBind has not been initialized yet")
|
||||
}
|
||||
|
||||
return b.iceSrflxMux, nil
|
||||
return b.udpMux, nil
|
||||
}
|
||||
|
||||
func (b *ICEBind) GetICEHostMux() (UDPMux, error) {
|
||||
@ -55,9 +54,6 @@ func (b *ICEBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
|
||||
if b.sharedConn != nil {
|
||||
return nil, 0, conn.ErrBindAlreadyOpen
|
||||
}
|
||||
if b.sharedConnHost != nil {
|
||||
return nil, 0, conn.ErrBindAlreadyOpen
|
||||
}
|
||||
|
||||
b.endpointMap = make(map[string]net.PacketConn)
|
||||
|
||||
@ -66,24 +62,20 @@ func (b *ICEBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
|
||||
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||
return nil, 0, err
|
||||
}
|
||||
ipv4ConnHost, port, err := listenNet("udp4", 0)
|
||||
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||
return nil, 0, err
|
||||
}
|
||||
b.sharedConn = ipv4Conn
|
||||
b.sharedConnHost = ipv4ConnHost
|
||||
b.iceSrflxMux = NewUniversalUDPMuxDefault(UniversalUDPMuxParams{UDPConn: b.sharedConn})
|
||||
b.iceHostMux = NewUDPMuxDefault(UDPMuxParams{UDPConn: b.sharedConnHost})
|
||||
b.udpMux = NewUniversalUDPMuxDefault(UniversalUDPMuxParams{UDPConn: b.sharedConn})
|
||||
|
||||
portAddr, err := netip.ParseAddrPort(ipv4Conn.LocalAddr().String())
|
||||
portAddr1, err := netip.ParseAddrPort(ipv4Conn.LocalAddr().String())
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
log.Infof("opened ICEBind on %s", ipv4Conn.LocalAddr().String())
|
||||
|
||||
return []conn.ReceiveFunc{
|
||||
b.makeReceiveIPv4(b.sharedConn, b.iceSrflxMux),
|
||||
b.makeReceiveIPv4(b.sharedConnHost, b.iceHostMux),
|
||||
b.makeReceiveIPv4(b.sharedConn),
|
||||
},
|
||||
portAddr.Port(), nil
|
||||
portAddr1.Port(), nil
|
||||
}
|
||||
|
||||
func listenNet(network string, port int) (*net.UDPConn, int, error) {
|
||||
@ -104,7 +96,7 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) {
|
||||
return conn, uaddr.Port, nil
|
||||
}
|
||||
|
||||
func (b *ICEBind) makeReceiveIPv4(c net.PacketConn, bindMux BindMux) conn.ReceiveFunc {
|
||||
func (b *ICEBind) makeReceiveIPv4(c net.PacketConn) conn.ReceiveFunc {
|
||||
return func(buff []byte) (int, conn.Endpoint, error) {
|
||||
n, endpoint, err := c.ReadFrom(buff)
|
||||
if err != nil {
|
||||
@ -122,15 +114,37 @@ func (b *ICEBind) makeReceiveIPv4(c net.PacketConn, bindMux BindMux) conn.Receiv
|
||||
Zone: e.Addr().Zone(),
|
||||
}), nil
|
||||
}
|
||||
|
||||
b.mu.Lock()
|
||||
|
||||
/* msg := &stun.Message{
|
||||
Raw: append([]byte{}, buff[:n]...),
|
||||
}
|
||||
if err := msg.Decode(); err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
strAttrs := []string{}
|
||||
for _, attribute := range msg.Attributes {
|
||||
strAttrs = append(strAttrs, attribute.String())
|
||||
}
|
||||
|
||||
xorMapped := "EMPTY"
|
||||
_, err = msg.Get(stun.AttrXORMappedAddress)
|
||||
if err == nil {
|
||||
var addr stun.XORMappedAddress
|
||||
if err := addr.GetFrom(msg); err == nil {
|
||||
xorMapped = addr.String()
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("endpoint %s XORMAPPED %s mux type %s msg type %s, attributes %s", endpoint.String(), xorMapped, bindMux.Type(), msg.Type.String(), strings.Join(strAttrs[:], ";"))
|
||||
*/
|
||||
if _, ok := b.endpointMap[e.String()]; !ok {
|
||||
b.endpointMap[e.String()] = c
|
||||
log.Infof("added %s endpoint %s", bindMux.Type(), e.String())
|
||||
log.Infof("added endpoint %s", e.String())
|
||||
}
|
||||
b.mu.Unlock()
|
||||
|
||||
err = bindMux.HandlePacket(buff, n, endpoint)
|
||||
err = b.udpMux.HandlePacket(buff, n, endpoint)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
@ -147,43 +161,24 @@ func (b *ICEBind) Close() error {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
var err1, err2, err3, err4 error
|
||||
var err1, err2 error
|
||||
if b.sharedConn != nil {
|
||||
c := b.sharedConn
|
||||
b.sharedConn = nil
|
||||
err1 = c.Close()
|
||||
}
|
||||
if b.sharedConnHost != nil {
|
||||
c := b.sharedConnHost
|
||||
b.sharedConnHost = nil
|
||||
err2 = c.Close()
|
||||
}
|
||||
|
||||
if b.iceSrflxMux != nil {
|
||||
m := b.iceSrflxMux
|
||||
b.iceSrflxMux = nil
|
||||
err3 = m.Close()
|
||||
if b.udpMux != nil {
|
||||
m := b.udpMux
|
||||
b.udpMux = nil
|
||||
err2 = m.Close()
|
||||
}
|
||||
|
||||
if b.iceHostMux != nil {
|
||||
m := b.iceHostMux
|
||||
b.iceHostMux = nil
|
||||
err4 = m.Close()
|
||||
}
|
||||
|
||||
//todo close iceSrflxMux
|
||||
|
||||
if err1 != nil {
|
||||
return err1
|
||||
}
|
||||
if err2 != nil {
|
||||
return err2
|
||||
}
|
||||
if err3 != nil {
|
||||
return err3
|
||||
}
|
||||
|
||||
return err4
|
||||
return err2
|
||||
}
|
||||
|
||||
// SetMark sets the mark for each packet sent through this Bind.
|
||||
|
@ -32,7 +32,7 @@ type UDPMuxDefault struct {
|
||||
conns map[string]*udpMuxedConn
|
||||
|
||||
addressMapMu sync.RWMutex
|
||||
addressMap map[string]*udpMuxedConn
|
||||
addressMap map[string][]*udpMuxedConn
|
||||
|
||||
// buffer pool to recycle buffers for net.UDPAddr encodes/decodes
|
||||
pool *sync.Pool
|
||||
@ -55,7 +55,7 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
|
||||
}
|
||||
|
||||
return &UDPMuxDefault{
|
||||
addressMap: map[string]*udpMuxedConn{},
|
||||
addressMap: map[string][]*udpMuxedConn{},
|
||||
params: params,
|
||||
conns: make(map[string]*udpMuxedConn),
|
||||
closedChan: make(chan struct{}, 1),
|
||||
@ -81,11 +81,19 @@ func (m *UDPMuxDefault) HandlePacket(p []byte, n int, addr net.Addr) error {
|
||||
|
||||
// If we have already seen this address dispatch to the appropriate destination
|
||||
m.addressMapMu.Lock()
|
||||
destinationConn := m.addressMap[addr.String()]
|
||||
var destinationConnList []*udpMuxedConn
|
||||
if storedConns, ok := m.addressMap[addr.String()]; ok {
|
||||
for _, conn := range storedConns {
|
||||
destinationConnList = append(destinationConnList, conn)
|
||||
}
|
||||
}
|
||||
m.addressMapMu.Unlock()
|
||||
|
||||
// If we haven't seen this address before but is a STUN packet lookup by ufrag
|
||||
if destinationConn == nil && stun.IsMessage(p[:20]) {
|
||||
if stun.IsMessage(p[:20]) {
|
||||
// This block is needed to discover Peer Reflexive Candidates for which we don't know the Endpoint upfront.
|
||||
// However, we can take a username attribute from the STUN message which contains ufrag.
|
||||
// We can use ufrag to identify the destination conn to route packet to.
|
||||
msg := &stun.Message{
|
||||
Raw: append([]byte{}, p[:n]...),
|
||||
}
|
||||
@ -96,25 +104,32 @@ func (m *UDPMuxDefault) HandlePacket(p []byte, n int, addr net.Addr) error {
|
||||
}
|
||||
|
||||
attr, stunAttrErr := msg.Get(stun.AttrUsername)
|
||||
if stunAttrErr != nil {
|
||||
log.Warnf("No Username attribute in STUN message from %s\n", addr.String())
|
||||
return stunAttrErr
|
||||
if stunAttrErr == nil {
|
||||
ufrag := strings.Split(string(attr), ":")[0]
|
||||
|
||||
m.mu.Lock()
|
||||
if destinationConn, ok := m.conns[ufrag]; ok {
|
||||
exists := false
|
||||
for _, conn := range destinationConnList {
|
||||
if conn.params.Key == destinationConn.params.Key {
|
||||
exists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !exists {
|
||||
destinationConnList = append(destinationConnList, destinationConn)
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
} else {
|
||||
//log.Warnf("No Username attribute in STUN message from %s\n", addr.String())
|
||||
}
|
||||
|
||||
ufrag := strings.Split(string(attr), ":")[0]
|
||||
|
||||
m.mu.Lock()
|
||||
destinationConn = m.conns[ufrag]
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
if destinationConn == nil {
|
||||
log.Tracef("dropping packet from %s, addr: %s", udpAddr.String(), addr.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := destinationConn.writePacket(p[:n], udpAddr); err != nil {
|
||||
log.Errorf("could not write packet: %v", err)
|
||||
for _, conn := range destinationConnList {
|
||||
if err := conn.writePacket(p[:n], udpAddr); err != nil {
|
||||
log.Errorf("could not write packet: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -131,6 +146,8 @@ func (m *UDPMuxDefault) GetConn(ufrag string) (net.PacketConn, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
log.Debugf("ICE %s: getting muxed connection for %s", m.Type(), ufrag)
|
||||
|
||||
if m.IsClosed() {
|
||||
return nil, io.ErrClosedPipe
|
||||
}
|
||||
@ -219,7 +236,15 @@ func (m *UDPMuxDefault) removeConn(key string) {
|
||||
|
||||
addresses := c.getAddresses()
|
||||
for _, addr := range addresses {
|
||||
delete(m.addressMap, addr)
|
||||
if connList, ok := m.addressMap[addr]; ok {
|
||||
var newList []*udpMuxedConn
|
||||
for _, conn := range connList {
|
||||
if conn.params.Key != key {
|
||||
newList = append(newList, conn)
|
||||
}
|
||||
}
|
||||
m.addressMap[addr] = newList
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -236,12 +261,13 @@ func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string)
|
||||
defer m.addressMapMu.Unlock()
|
||||
|
||||
existing, ok := m.addressMap[addr]
|
||||
if ok {
|
||||
existing.removeAddress(addr)
|
||||
if !ok {
|
||||
existing = []*udpMuxedConn{}
|
||||
}
|
||||
m.addressMap[addr] = conn
|
||||
existing = append(existing, conn)
|
||||
m.addressMap[addr] = existing
|
||||
|
||||
m.params.Logger.Debugf("Registered %s for %s", addr, conn.params.Key)
|
||||
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
|
||||
}
|
||||
|
||||
func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn {
|
||||
@ -252,6 +278,7 @@ func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn {
|
||||
LocalAddr: m.LocalAddr(),
|
||||
Logger: m.params.Logger,
|
||||
})
|
||||
log.Debugf("ICE: created muxed connection %s for %s", c.LocalAddr().String(), key)
|
||||
return c
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user