diff --git a/client/iface/bind/ice_bind.go b/client/iface/bind/ice_bind.go index a9c25950d..d889a684b 100644 --- a/client/iface/bind/ice_bind.go +++ b/client/iface/bind/ice_bind.go @@ -1,6 +1,7 @@ package bind import ( + "errors" "fmt" "net" "net/netip" @@ -12,6 +13,8 @@ import ( "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + "golang.org/x/sys/unix" wgConn "golang.zx2c4.com/wireguard/conn" ) @@ -24,8 +27,12 @@ type receiverCreator struct { iceBind *ICEBind } -func (rc receiverCreator) CreateIPv4ReceiverFn(msgPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc { - return rc.iceBind.createIPv4ReceiverFn(msgPool, pc, conn) +const ( + udpSegmentMaxDatagrams = 64 +) + +func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) wgConn.ReceiveFunc { + return rc.iceBind.createIPv4ReceiverFn(pc, conn, rxOffload) } // ICEBind is a bind implementation with two main features: @@ -51,6 +58,7 @@ type ICEBind struct { muUDPMux sync.Mutex udpMux *UniversalUDPMuxDefault + msgsPool sync.Pool } func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind { @@ -63,11 +71,24 @@ func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind { endpoints: make(map[netip.Addr]net.Conn), closedChan: make(chan struct{}), closed: true, + msgsPool: sync.Pool{ + New: func() any { + // ipv6.Message and ipv4.Message are interchangeable as they are + // both aliases for x/net/internal/socket.Message. + msgs := make([]ipv6.Message, wgConn.IdealBatchSize) + for i := range msgs { + msgs[i].Buffers = make(net.Buffers, 1) + msgs[i].OOB = make([]byte, 0, wgConn.StickyControlSize+unix.CmsgSpace(2)) + } + return &msgs + }, + }, } rc := receiverCreator{ ib, } + ib.StdNetBind = wgConn.NewStdNetBindWithReceiverCreator(rc) return ib } @@ -154,7 +175,7 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error { return nil } -func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc { +func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) wgConn.ReceiveFunc { s.muUDPMux.Lock() defer s.muUDPMux.Unlock() @@ -165,44 +186,83 @@ func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketC FilterFn: s.filterFn, }, ) + return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) { - msgs := ipv4MsgsPool.Get().(*[]ipv4.Message) - defer ipv4MsgsPool.Put(msgs) - for i := range bufs { - (*msgs)[i].Buffers[0] = bufs[i] + return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps) + } +} + +func (s *ICEBind) receiveIP( + br batchReader, + conn *net.UDPConn, + rxOffload bool, + bufs [][]byte, + sizes []int, + eps []wgConn.Endpoint) (n int, err error) { + + msgs := s.msgsPool.Get().(*[]ipv6.Message) + for i := range bufs { + (*msgs)[i].Buffers[0] = bufs[i] + (*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)] + } + defer func() { + for i := range *msgs { + (*msgs)[i].OOB = (*msgs)[i].OOB[:0] + (*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB} } - var numMsgs int - if runtime.GOOS == "linux" { - numMsgs, err = pc.ReadBatch(*msgs, 0) + s.msgsPool.Put(msgs) + }() + var numMsgs int + if runtime.GOOS == "linux" || runtime.GOOS == "android" { + if rxOffload { + readAt := len(*msgs) - (wgConn.IdealBatchSize / udpSegmentMaxDatagrams) + numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0) + if err != nil { + return 0, err + } + numMsgs, err = splitCoalescedMessages(*msgs, readAt, wgConn.GetGSOSize) if err != nil { return 0, err } } else { - msg := &(*msgs)[0] - msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) + numMsgs, err = br.ReadBatch(*msgs, 0) if err != nil { return 0, err } - numMsgs = 1 } - for i := 0; i < numMsgs; i++ { - msg := &(*msgs)[i] - - // todo: handle err - ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr) - if ok { - sizes[i] = 0 - } else { - sizes[i] = msg.N - } - - addrPort := msg.Addr.(*net.UDPAddr).AddrPort() - ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation - wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep) - eps[i] = ep + } else { + msg := &(*msgs)[0] + msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) + if err != nil { + return 0, err } - return numMsgs, nil + numMsgs = 1 } + for i := 0; i < numMsgs; i++ { + msg := &(*msgs)[i] + + // todo: handle err + ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr) + if ok { + sizes[i] = 0 + } else { + sizes[i] = msg.N + } + + sizes[i] = msg.N + if sizes[i] == 0 { + continue + } + addrPort := msg.Addr.(*net.UDPAddr).AddrPort() + ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation + wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep) + eps[i] = ep + } + return numMsgs, nil +} + +type batchReader interface { + ReadBatch([]ipv6.Message, int) (int, error) } func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) { @@ -273,3 +333,49 @@ func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) { } return newAddr, nil } + +type getGSOFunc func(control []byte) (int, error) + +func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) { + for i := firstMsgAt; i < len(msgs); i++ { + msg := &msgs[i] + if msg.N == 0 { + return n, err + } + var ( + gsoSize int + start int + end = msg.N + numToSplit = 1 + ) + gsoSize, err = getGSO(msg.OOB[:msg.NN]) + if err != nil { + return n, err + } + if gsoSize > 0 { + numToSplit = (msg.N + gsoSize - 1) / gsoSize + end = gsoSize + } + for j := 0; j < numToSplit; j++ { + if n > i { + return n, errors.New("splitting coalesced packet resulted in overflow") + } + copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end]) + msgs[n].N = copied + msgs[n].Addr = msg.Addr + start = end + end += gsoSize + if end > msg.N { + end = msg.N + } + n++ + } + if i != n-1 { + // It is legal for bytes to move within msg.Buffers[0] as a result + // of splitting, so we only zero the source msg len when it is not + // the destination of the last split operation above. + msg.N = 0 + } + } + return n, nil +} diff --git a/go.mod b/go.mod index 7223a446b..21a65bafe 100644 --- a/go.mod +++ b/go.mod @@ -239,7 +239,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024 replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 -replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed +replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241030151402-7d9b34e13fb8 replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 diff --git a/go.sum b/go.sum index 5cd703bc8..54a9ceb2f 100644 --- a/go.sum +++ b/go.sum @@ -527,8 +527,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= -github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs= -github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= +github.com/netbirdio/wireguard-go v0.0.0-20241030151402-7d9b34e13fb8 h1:xalM0+6jQ8mhUzkfTg9I505misQSH83J/IgK5oPQUSg= +github.com/netbirdio/wireguard-go v0.0.0-20241030151402-7d9b34e13fb8/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM= github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=