mirror of
https://github.com/netbirdio/netbird.git
synced 2025-01-05 21:49:03 +01:00
4be826450b
Improve the performance on Linux and Android in case of P2P connections
304 lines
7.7 KiB
Go
304 lines
7.7 KiB
Go
package bind
|
|
|
|
import (
|
|
"fmt"
|
|
"net"
|
|
"net/netip"
|
|
"runtime"
|
|
"strings"
|
|
"sync"
|
|
|
|
"github.com/pion/stun/v2"
|
|
"github.com/pion/transport/v3"
|
|
log "github.com/sirupsen/logrus"
|
|
"golang.org/x/net/ipv4"
|
|
"golang.org/x/net/ipv6"
|
|
wgConn "golang.zx2c4.com/wireguard/conn"
|
|
)
|
|
|
|
// RecvMessage is the unit delivered through ICEBind.RecvChan: a packet payload together with the endpoint it is
// attributed to.
type RecvMessage struct {
	// Endpoint is reported by receiveRelayed as the endpoint of the packet.
	Endpoint *Endpoint
	// Buffer holds the packet bytes; receiveRelayed copies them into the buffer supplied by its caller.
	Buffer []byte
}
|
|
|
|
// receiverCreator wraps an ICEBind so that it can be handed to wgConn.NewStdNetBindWithReceiverCreator.
type receiverCreator struct {
	// iceBind is the bind that actually builds the receive function.
	iceBind *ICEBind
}
|
|
|
|
// CreateIPv4ReceiverFn returns the IPv4 receive function by delegating, with unchanged arguments, to the wrapped
// ICEBind's createIPv4ReceiverFn.
func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
	return rc.iceBind.createIPv4ReceiverFn(pc, conn, rxOffload, msgPool)
}
|
|
|
|
// ICEBind is a bind implementation with two main features:
//  1. filter out STUN messages and handle them
//  2. forward the packets received from the relayed connection to the WireGuard interface
//
// The endpoints field is a map that stores the connection for each relayed peer, keyed by a fake address. A fake
// address is just an IP address without a port, in the format 127.1.x.x where x.x are the last two octets of the peer
// address. The port is deliberately not used because the wgConn.Endpoint passed to Send does not export the port
// information.
type ICEBind struct {
	*wgConn.StdNetBind

	// RecvChan carries relayed packets; receiveRelayed reads from it.
	RecvChan chan RecvMessage

	// transportNet and filterFn are passed to the UniversalUDPMuxDefault built in createIPv4ReceiverFn.
	transportNet transport.Net
	filterFn     FilterFn

	// endpoints holds the relayed connections keyed by fake address; access is guarded by endpointsMu.
	endpoints   map[netip.Addr]net.Conn
	endpointsMu sync.Mutex

	// Every time Close() is called (i.e. on BindUpdate()) receiveRelayed has to exit, and a new closed channel has to
	// be created afterwards. With closedChanMu the channel can be replaced safely while it may still be read.
	closedChan   chan struct{}
	closedChanMu sync.RWMutex // protects the closedChan recreation against concurrent reads of it.
	// closed is true until Open is called and again after Close; Close uses it to return early.
	closed bool

	// muUDPMux guards udpMux, which is assigned in createIPv4ReceiverFn and returned by GetICEMux.
	muUDPMux sync.Mutex
	udpMux   *UniversalUDPMuxDefault
}
|
|
|
|
func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
|
|
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
|
|
ib := &ICEBind{
|
|
StdNetBind: b,
|
|
RecvChan: make(chan RecvMessage, 1),
|
|
transportNet: transportNet,
|
|
filterFn: filterFn,
|
|
endpoints: make(map[netip.Addr]net.Conn),
|
|
closedChan: make(chan struct{}),
|
|
closed: true,
|
|
}
|
|
|
|
rc := receiverCreator{
|
|
ib,
|
|
}
|
|
ib.StdNetBind = wgConn.NewStdNetBindWithReceiverCreator(rc)
|
|
return ib
|
|
}
|
|
|
|
func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
|
|
s.closed = false
|
|
s.closedChanMu.Lock()
|
|
s.closedChan = make(chan struct{})
|
|
s.closedChanMu.Unlock()
|
|
fns, port, err := s.StdNetBind.Open(uport)
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
fns = append(fns, s.receiveRelayed)
|
|
return fns, port, nil
|
|
}
|
|
|
|
// Close shuts the bind down. It does nothing when the bind is already closed; otherwise it marks the bind as closed,
// closes closedChan so that a receiveRelayed call waiting on it returns net.ErrClosed, and finally closes the embedded
// StdNetBind, whose error is returned.
func (s *ICEBind) Close() error {
	if s.closed {
		return nil
	}
	s.closed = true

	// NOTE: receiveRelayed holds closedChanMu.RLock while it waits on this channel, so the channel is closed here
	// without taking the write lock.
	close(s.closedChan)

	return s.StdNetBind.Close()
}
|
|
|
|
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
|
|
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
|
|
s.muUDPMux.Lock()
|
|
defer s.muUDPMux.Unlock()
|
|
if s.udpMux == nil {
|
|
return nil, fmt.Errorf("ICEBind has not been initialized yet")
|
|
}
|
|
|
|
return s.udpMux, nil
|
|
}
|
|
|
|
func (b *ICEBind) SetEndpoint(peerAddress *net.UDPAddr, conn net.Conn) (*net.UDPAddr, error) {
|
|
fakeUDPAddr, err := fakeAddress(peerAddress)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// force IPv4
|
|
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
|
|
if !ok {
|
|
return nil, fmt.Errorf("failed to convert IP to netip.Addr")
|
|
}
|
|
|
|
b.endpointsMu.Lock()
|
|
b.endpoints[fakeAddr] = conn
|
|
b.endpointsMu.Unlock()
|
|
|
|
return fakeUDPAddr, nil
|
|
}
|
|
|
|
func (b *ICEBind) RemoveEndpoint(fakeUDPAddr *net.UDPAddr) {
|
|
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
|
|
if !ok {
|
|
log.Warnf("failed to convert IP to netip.Addr")
|
|
return
|
|
}
|
|
|
|
b.endpointsMu.Lock()
|
|
defer b.endpointsMu.Unlock()
|
|
delete(b.endpoints, fakeAddr)
|
|
}
|
|
|
|
func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
|
|
b.endpointsMu.Lock()
|
|
conn, ok := b.endpoints[ep.DstIP()]
|
|
b.endpointsMu.Unlock()
|
|
if !ok {
|
|
return b.StdNetBind.Send(bufs, ep)
|
|
}
|
|
|
|
for _, buf := range bufs {
|
|
if _, err := conn.Write(buf); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
|
|
s.muUDPMux.Lock()
|
|
defer s.muUDPMux.Unlock()
|
|
|
|
s.udpMux = NewUniversalUDPMuxDefault(
|
|
UniversalUDPMuxParams{
|
|
UDPConn: conn,
|
|
Net: s.transportNet,
|
|
FilterFn: s.filterFn,
|
|
},
|
|
)
|
|
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
|
msgs := getMessages(msgsPool)
|
|
for i := range bufs {
|
|
(*msgs)[i].Buffers[0] = bufs[i]
|
|
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
|
|
}
|
|
defer putMessages(msgs, msgsPool)
|
|
var numMsgs int
|
|
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
|
if rxOffload {
|
|
readAt := len(*msgs) - (wgConn.IdealBatchSize / wgConn.UdpSegmentMaxDatagrams)
|
|
//nolint
|
|
numMsgs, err = pc.ReadBatch((*msgs)[readAt:], 0)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
numMsgs, err = wgConn.SplitCoalescedMessages(*msgs, readAt, wgConn.GetGSOSize)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
} else {
|
|
numMsgs, err = pc.ReadBatch(*msgs, 0)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
}
|
|
} else {
|
|
msg := &(*msgs)[0]
|
|
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
numMsgs = 1
|
|
}
|
|
for i := 0; i < numMsgs; i++ {
|
|
msg := &(*msgs)[i]
|
|
|
|
// todo: handle err
|
|
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
|
|
if ok {
|
|
continue
|
|
}
|
|
sizes[i] = msg.N
|
|
if sizes[i] == 0 {
|
|
continue
|
|
}
|
|
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
|
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
|
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
|
|
eps[i] = ep
|
|
}
|
|
return numMsgs, nil
|
|
}
|
|
}
|
|
|
|
func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) {
|
|
for i := range buffers {
|
|
if !stun.IsMessage(buffers[i]) {
|
|
continue
|
|
}
|
|
|
|
msg, err := s.parseSTUNMessage(buffers[i][:n])
|
|
if err != nil {
|
|
buffers[i] = []byte{}
|
|
return true, err
|
|
}
|
|
|
|
muxErr := s.udpMux.HandleSTUNMessage(msg, addr)
|
|
if muxErr != nil {
|
|
log.Warnf("failed to handle STUN packet")
|
|
}
|
|
|
|
buffers[i] = []byte{}
|
|
return true, nil
|
|
}
|
|
return false, nil
|
|
}
|
|
|
|
func (s *ICEBind) parseSTUNMessage(raw []byte) (*stun.Message, error) {
|
|
msg := &stun.Message{
|
|
Raw: raw,
|
|
}
|
|
if err := msg.Decode(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return msg, nil
|
|
}
|
|
|
|
// receiveRelayed is a receive function that is used to receive packets from the relayed connection and forward to the
|
|
// WireGuard. Critical part is do not block if the Closed() has been called.
|
|
func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
|
|
c.closedChanMu.RLock()
|
|
defer c.closedChanMu.RUnlock()
|
|
|
|
select {
|
|
case <-c.closedChan:
|
|
return 0, net.ErrClosed
|
|
case msg, ok := <-c.RecvChan:
|
|
if !ok {
|
|
return 0, net.ErrClosed
|
|
}
|
|
copy(buffs[0], msg.Buffer)
|
|
sizes[0] = len(msg.Buffer)
|
|
eps[0] = wgConn.Endpoint(msg.Endpoint)
|
|
return 1, nil
|
|
}
|
|
}
|
|
|
|
// fakeAddress derives the fake address that is used as the identifier of a peer.
// The result has the format 127.1.x.x, where x.x are the last two octets of the peer address, and carries the port of
// the peer address. An address whose textual form does not consist of four dot-separated parts is rejected.
func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) {
	parts := strings.Split(peerAddress.IP.String(), ".")
	if len(parts) != 4 {
		return nil, fmt.Errorf("invalid IP format")
	}

	loopback := strings.Join([]string{"127", "1", parts[2], parts[3]}, ".")
	return &net.UDPAddr{IP: net.ParseIP(loopback), Port: peerAddress.Port}, nil
}
|
|
|
|
func getMessages(msgsPool *sync.Pool) *[]ipv6.Message {
|
|
return msgsPool.Get().(*[]ipv6.Message)
|
|
}
|
|
|
|
func putMessages(msgs *[]ipv6.Message, msgsPool *sync.Pool) {
|
|
for i := range *msgs {
|
|
(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
|
|
(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
|
|
}
|
|
msgsPool.Put(msgs)
|
|
}
|