mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-25 17:43:38 +01:00
234 lines
5.0 KiB
Go
234 lines
5.0 KiB
Go
|
package bind
|
||
|
|
||
|
/*
|
||
|
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements
|
||
|
*/
|
||
|
|
||
|
import (
|
||
|
"encoding/binary"
|
||
|
"io"
|
||
|
"net"
|
||
|
"sync"
|
||
|
"time"
|
||
|
|
||
|
"github.com/pion/logging"
|
||
|
"github.com/pion/transport/v2/packetio"
|
||
|
)
|
||
|
|
||
|
type udpMuxedConnParams struct {
|
||
|
Mux *UDPMuxDefault
|
||
|
AddrPool *sync.Pool
|
||
|
Key string
|
||
|
LocalAddr net.Addr
|
||
|
Logger logging.LeveledLogger
|
||
|
}
|
||
|
|
||
|
// udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag
|
||
|
type udpMuxedConn struct {
|
||
|
params *udpMuxedConnParams
|
||
|
// remote addresses that we have sent to on this conn
|
||
|
addresses []string
|
||
|
|
||
|
// channel holding incoming packets
|
||
|
buf *packetio.Buffer
|
||
|
closedChan chan struct{}
|
||
|
closeOnce sync.Once
|
||
|
mu sync.Mutex
|
||
|
}
|
||
|
|
||
|
func newUDPMuxedConn(params *udpMuxedConnParams) *udpMuxedConn {
|
||
|
p := &udpMuxedConn{
|
||
|
params: params,
|
||
|
buf: packetio.NewBuffer(),
|
||
|
closedChan: make(chan struct{}),
|
||
|
}
|
||
|
|
||
|
return p
|
||
|
}
|
||
|
|
||
|
func (c *udpMuxedConn) ReadFrom(b []byte) (n int, rAddr net.Addr, err error) {
|
||
|
buf := c.params.AddrPool.Get().(*bufferHolder) //nolint:forcetypeassert
|
||
|
defer c.params.AddrPool.Put(buf)
|
||
|
|
||
|
// read address
|
||
|
total, err := c.buf.Read(buf.buf)
|
||
|
if err != nil {
|
||
|
return 0, nil, err
|
||
|
}
|
||
|
|
||
|
dataLen := int(binary.LittleEndian.Uint16(buf.buf[:2]))
|
||
|
if dataLen > total || dataLen > len(b) {
|
||
|
return 0, nil, io.ErrShortBuffer
|
||
|
}
|
||
|
|
||
|
// read data and then address
|
||
|
offset := 2
|
||
|
copy(b, buf.buf[offset:offset+dataLen])
|
||
|
offset += dataLen
|
||
|
|
||
|
// read address len & decode address
|
||
|
addrLen := int(binary.LittleEndian.Uint16(buf.buf[offset : offset+2]))
|
||
|
offset += 2
|
||
|
|
||
|
if rAddr, err = decodeUDPAddr(buf.buf[offset : offset+addrLen]); err != nil {
|
||
|
return 0, nil, err
|
||
|
}
|
||
|
|
||
|
return dataLen, rAddr, nil
|
||
|
}
|
||
|
|
||
|
func (c *udpMuxedConn) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) {
|
||
|
if c.isClosed() {
|
||
|
return 0, io.ErrClosedPipe
|
||
|
}
|
||
|
// each time we write to a new address, we'll register it with the mux
|
||
|
addr := rAddr.String()
|
||
|
if !c.containsAddress(addr) {
|
||
|
c.addAddress(addr)
|
||
|
}
|
||
|
|
||
|
return c.params.Mux.writeTo(buf, rAddr)
|
||
|
}
|
||
|
|
||
|
func (c *udpMuxedConn) LocalAddr() net.Addr {
|
||
|
return c.params.LocalAddr
|
||
|
}
|
||
|
|
||
|
func (c *udpMuxedConn) SetDeadline(tm time.Time) error {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (c *udpMuxedConn) SetReadDeadline(tm time.Time) error {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (c *udpMuxedConn) SetWriteDeadline(tm time.Time) error {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (c *udpMuxedConn) CloseChannel() <-chan struct{} {
|
||
|
return c.closedChan
|
||
|
}
|
||
|
|
||
|
func (c *udpMuxedConn) Close() error {
|
||
|
var err error
|
||
|
c.closeOnce.Do(func() {
|
||
|
err = c.buf.Close()
|
||
|
close(c.closedChan)
|
||
|
})
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
func (c *udpMuxedConn) isClosed() bool {
|
||
|
select {
|
||
|
case <-c.closedChan:
|
||
|
return true
|
||
|
default:
|
||
|
return false
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (c *udpMuxedConn) getAddresses() []string {
|
||
|
c.mu.Lock()
|
||
|
defer c.mu.Unlock()
|
||
|
addresses := make([]string, len(c.addresses))
|
||
|
copy(addresses, c.addresses)
|
||
|
return addresses
|
||
|
}
|
||
|
|
||
|
func (c *udpMuxedConn) addAddress(addr string) {
|
||
|
c.mu.Lock()
|
||
|
c.addresses = append(c.addresses, addr)
|
||
|
c.mu.Unlock()
|
||
|
|
||
|
// map it on mux
|
||
|
c.params.Mux.registerConnForAddress(c, addr)
|
||
|
}
|
||
|
|
||
|
func (c *udpMuxedConn) containsAddress(addr string) bool {
|
||
|
c.mu.Lock()
|
||
|
defer c.mu.Unlock()
|
||
|
for _, a := range c.addresses {
|
||
|
if addr == a {
|
||
|
return true
|
||
|
}
|
||
|
}
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
func (c *udpMuxedConn) writePacket(data []byte, addr *net.UDPAddr) error {
|
||
|
// write two packets, address and data
|
||
|
buf := c.params.AddrPool.Get().(*bufferHolder) //nolint:forcetypeassert
|
||
|
defer c.params.AddrPool.Put(buf)
|
||
|
|
||
|
// format of buffer | data len | data bytes | addr len | addr bytes |
|
||
|
if len(buf.buf) < len(data)+maxAddrSize {
|
||
|
return io.ErrShortBuffer
|
||
|
}
|
||
|
// data len
|
||
|
binary.LittleEndian.PutUint16(buf.buf, uint16(len(data)))
|
||
|
offset := 2
|
||
|
|
||
|
// data
|
||
|
copy(buf.buf[offset:], data)
|
||
|
offset += len(data)
|
||
|
|
||
|
// write address first, leaving room for its length
|
||
|
n, err := encodeUDPAddr(addr, buf.buf[offset+2:])
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
total := offset + n + 2
|
||
|
|
||
|
// address len
|
||
|
binary.LittleEndian.PutUint16(buf.buf[offset:], uint16(n))
|
||
|
|
||
|
if _, err := c.buf.Write(buf.buf[:total]); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// encodeUDPAddr serializes addr into buf and returns the number of bytes used.
// The encoding is
//
//	| ip len (2) | ip text | port (2) | zone |
//
// where the IP is in its textual form, both numeric fields are little-endian
// uint16 values and the zone takes up the remainder. io.ErrShortBuffer is
// returned when buf is too small; an error from marshaling the IP is returned
// unchanged.
func encodeUDPAddr(addr *net.UDPAddr, buf []byte) (int, error) {
	ipText, err := addr.IP.MarshalText()
	if err != nil {
		return 0, err
	}

	const (
		lenFieldSize  = 2
		portFieldSize = 2
	)

	need := lenFieldSize + len(ipText) + portFieldSize + len(addr.Zone)
	if len(buf) < need {
		return 0, io.ErrShortBuffer
	}

	binary.LittleEndian.PutUint16(buf[0:], uint16(len(ipText)))
	pos := lenFieldSize
	pos += copy(buf[pos:], ipText)

	binary.LittleEndian.PutUint16(buf[pos:], uint16(addr.Port))
	pos += portFieldSize

	copy(buf[pos:], addr.Zone)
	return need, nil
}
|
||
|
|
||
|
// decodeUDPAddr parses an address produced by encodeUDPAddr. The expected
// layout is
//
//	| ip len (2) | ip text | port (2) | zone |
//
// where both numeric fields are little-endian uint16 values and the zone is
// whatever follows the port. io.ErrShortBuffer is returned when buf is too
// short to hold the length field, the announced IP text or the port; an error
// from parsing the IP text is returned unchanged.
func decodeUDPAddr(buf []byte) (*net.UDPAddr, error) {
	const (
		lenFieldSize  = 2
		portFieldSize = 2
	)

	if len(buf) < lenFieldSize {
		return nil, io.ErrShortBuffer
	}
	ipLen := int(binary.LittleEndian.Uint16(buf))
	offset := lenFieldSize

	// The IP text must be followed by a complete port field. Checking against
	// len(buf) keeps the reads below from panicking or from silently reaching
	// into spare capacity behind the slice.
	if offset+ipLen+portFieldSize > len(buf) {
		return nil, io.ErrShortBuffer
	}

	addr := &net.UDPAddr{}
	if err := addr.IP.UnmarshalText(buf[offset : offset+ipLen]); err != nil {
		return nil, err
	}
	offset += ipLen

	addr.Port = int(binary.LittleEndian.Uint16(buf[offset:]))
	offset += portFieldSize

	// The string conversion copies the bytes, so the zone does not alias buf.
	addr.Zone = string(buf[offset:])

	return addr, nil
}
|