mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-12 18:00:49 +01:00
2eeed55c18
This PR adds support for the WireGuard userspace implementation using the Bind interface from wireguard-go. The newly introduced ICEBind struct implements Bind with UDPMux-based structs from pion/ice to handle hole punching using ICE. The core implementation was taken from StdBind of wireguard-go. The result is a single WireGuard port that is used for host and server reflexive candidates. Relay candidates are still handled separately and will be integrated in the following PRs. ICEBind inspects each incoming packet to determine whether it is a STUN packet or a WireGuard packet and routes it to UDPMux (to handle hole punching) or to WireGuard, respectively.
234 lines
5.0 KiB
Go
234 lines
5.0 KiB
Go
package bind
|
|
|
|
/*
|
|
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements
|
|
*/
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"io"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/pion/logging"
|
|
"github.com/pion/transport/v2/packetio"
|
|
)
|
|
|
|
type udpMuxedConnParams struct {
|
|
Mux *UDPMuxDefault
|
|
AddrPool *sync.Pool
|
|
Key string
|
|
LocalAddr net.Addr
|
|
Logger logging.LeveledLogger
|
|
}
|
|
|
|
// udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag
|
|
type udpMuxedConn struct {
|
|
params *udpMuxedConnParams
|
|
// remote addresses that we have sent to on this conn
|
|
addresses []string
|
|
|
|
// channel holding incoming packets
|
|
buf *packetio.Buffer
|
|
closedChan chan struct{}
|
|
closeOnce sync.Once
|
|
mu sync.Mutex
|
|
}
|
|
|
|
func newUDPMuxedConn(params *udpMuxedConnParams) *udpMuxedConn {
|
|
p := &udpMuxedConn{
|
|
params: params,
|
|
buf: packetio.NewBuffer(),
|
|
closedChan: make(chan struct{}),
|
|
}
|
|
|
|
return p
|
|
}
|
|
|
|
func (c *udpMuxedConn) ReadFrom(b []byte) (n int, rAddr net.Addr, err error) {
|
|
buf := c.params.AddrPool.Get().(*bufferHolder) //nolint:forcetypeassert
|
|
defer c.params.AddrPool.Put(buf)
|
|
|
|
// read address
|
|
total, err := c.buf.Read(buf.buf)
|
|
if err != nil {
|
|
return 0, nil, err
|
|
}
|
|
|
|
dataLen := int(binary.LittleEndian.Uint16(buf.buf[:2]))
|
|
if dataLen > total || dataLen > len(b) {
|
|
return 0, nil, io.ErrShortBuffer
|
|
}
|
|
|
|
// read data and then address
|
|
offset := 2
|
|
copy(b, buf.buf[offset:offset+dataLen])
|
|
offset += dataLen
|
|
|
|
// read address len & decode address
|
|
addrLen := int(binary.LittleEndian.Uint16(buf.buf[offset : offset+2]))
|
|
offset += 2
|
|
|
|
if rAddr, err = decodeUDPAddr(buf.buf[offset : offset+addrLen]); err != nil {
|
|
return 0, nil, err
|
|
}
|
|
|
|
return dataLen, rAddr, nil
|
|
}
|
|
|
|
func (c *udpMuxedConn) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) {
|
|
if c.isClosed() {
|
|
return 0, io.ErrClosedPipe
|
|
}
|
|
// each time we write to a new address, we'll register it with the mux
|
|
addr := rAddr.String()
|
|
if !c.containsAddress(addr) {
|
|
c.addAddress(addr)
|
|
}
|
|
|
|
return c.params.Mux.writeTo(buf, rAddr)
|
|
}
|
|
|
|
func (c *udpMuxedConn) LocalAddr() net.Addr {
|
|
return c.params.LocalAddr
|
|
}
|
|
|
|
func (c *udpMuxedConn) SetDeadline(tm time.Time) error {
|
|
return nil
|
|
}
|
|
|
|
func (c *udpMuxedConn) SetReadDeadline(tm time.Time) error {
|
|
return nil
|
|
}
|
|
|
|
func (c *udpMuxedConn) SetWriteDeadline(tm time.Time) error {
|
|
return nil
|
|
}
|
|
|
|
func (c *udpMuxedConn) CloseChannel() <-chan struct{} {
|
|
return c.closedChan
|
|
}
|
|
|
|
func (c *udpMuxedConn) Close() error {
|
|
var err error
|
|
c.closeOnce.Do(func() {
|
|
err = c.buf.Close()
|
|
close(c.closedChan)
|
|
})
|
|
return err
|
|
}
|
|
|
|
func (c *udpMuxedConn) isClosed() bool {
|
|
select {
|
|
case <-c.closedChan:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
func (c *udpMuxedConn) getAddresses() []string {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
addresses := make([]string, len(c.addresses))
|
|
copy(addresses, c.addresses)
|
|
return addresses
|
|
}
|
|
|
|
func (c *udpMuxedConn) addAddress(addr string) {
|
|
c.mu.Lock()
|
|
c.addresses = append(c.addresses, addr)
|
|
c.mu.Unlock()
|
|
|
|
// map it on mux
|
|
c.params.Mux.registerConnForAddress(c, addr)
|
|
}
|
|
|
|
func (c *udpMuxedConn) containsAddress(addr string) bool {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
for _, a := range c.addresses {
|
|
if addr == a {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (c *udpMuxedConn) writePacket(data []byte, addr *net.UDPAddr) error {
|
|
// write two packets, address and data
|
|
buf := c.params.AddrPool.Get().(*bufferHolder) //nolint:forcetypeassert
|
|
defer c.params.AddrPool.Put(buf)
|
|
|
|
// format of buffer | data len | data bytes | addr len | addr bytes |
|
|
if len(buf.buf) < len(data)+maxAddrSize {
|
|
return io.ErrShortBuffer
|
|
}
|
|
// data len
|
|
binary.LittleEndian.PutUint16(buf.buf, uint16(len(data)))
|
|
offset := 2
|
|
|
|
// data
|
|
copy(buf.buf[offset:], data)
|
|
offset += len(data)
|
|
|
|
// write address first, leaving room for its length
|
|
n, err := encodeUDPAddr(addr, buf.buf[offset+2:])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
total := offset + n + 2
|
|
|
|
// address len
|
|
binary.LittleEndian.PutUint16(buf.buf[offset:], uint16(n))
|
|
|
|
if _, err := c.buf.Write(buf.buf[:total]); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// encodeUDPAddr serializes addr into buf as
//
//	| ip len (2 bytes, LE) | ip text | port (2 bytes, LE) | zone |
//
// and returns the number of bytes written. The zone takes up the rest of the
// encoding and therefore carries no length prefix. io.ErrShortBuffer is
// returned when buf is too small, and the error from marshaling the IP is
// returned unchanged.
func encodeUDPAddr(addr *net.UDPAddr, buf []byte) (int, error) {
	ipText, err := addr.IP.MarshalText()
	if err != nil {
		return 0, err
	}

	const u16 = 2
	portAt := u16 + len(ipText)
	zoneAt := portAt + u16
	size := zoneAt + len(addr.Zone)
	if len(buf) < size {
		return 0, io.ErrShortBuffer
	}

	binary.LittleEndian.PutUint16(buf[:u16], uint16(len(ipText)))
	copy(buf[u16:portAt], ipText)
	binary.LittleEndian.PutUint16(buf[portAt:zoneAt], uint16(addr.Port))
	copy(buf[zoneAt:size], addr.Zone)

	return size, nil
}
|
|
|
|
// decodeUDPAddr parses an address produced by encodeUDPAddr:
//
//	| ip len (2 bytes, LE) | ip text | port (2 bytes, LE) | zone |
//
// Everything after the port is taken as the zone. io.ErrShortBuffer is
// returned when buf is too short to contain the IP length, the IP text or the
// port; the error from parsing the IP text is returned unchanged.
func decodeUDPAddr(buf []byte) (*net.UDPAddr, error) {
	const u16 = 2

	// The IP length prefix itself must be present before it can be read.
	if len(buf) < u16 {
		return nil, io.ErrShortBuffer
	}
	ipLen := int(binary.LittleEndian.Uint16(buf))

	// Both the IP text and the fixed-size port that follows it must fit.
	ipEnd := u16 + ipLen
	portEnd := ipEnd + u16
	if portEnd > len(buf) {
		return nil, io.ErrShortBuffer
	}

	var addr net.UDPAddr
	if err := addr.IP.UnmarshalText(buf[u16:ipEnd]); err != nil {
		return nil, err
	}
	addr.Port = int(binary.LittleEndian.Uint16(buf[ipEnd:portEnd]))
	// The string conversion copies, so the zone does not alias buf.
	addr.Zone = string(buf[portEnd:])

	return &addr, nil
}
|