mirror of
synced 2025-03-27 08:07:51 +01:00
This PR adds support for the WireGuard userspace implementation using the Bind interface from wireguard-go. The newly introduced ICEBind struct implements Bind with UDPMux-based structs from pion/ice to handle hole punching using ICE. The core implementation was taken from StdBind of wireguard-go. The result is a single WireGuard port that is used for host and server reflexive candidates. Relay candidates are still handled separately and will be integrated in the following PRs. ICEBind checks whether each incoming packet is a STUN or a WireGuard packet and routes it to UDPMux (to handle hole punching) or to WireGuard respectively.
234 lines
5.0 KiB
234 lines
5.0 KiB
package bind
// Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements.
import (
type udpMuxedConnParams struct {
Mux *UDPMuxDefault
AddrPool *sync.Pool
Key string
LocalAddr net.Addr
Logger logging.LeveledLogger
// udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag
type udpMuxedConn struct {
params *udpMuxedConnParams
// remote addresses that we have sent to on this conn
addresses []string
// channel holding incoming packets
buf *packetio.Buffer
closedChan chan struct{}
closeOnce sync.Once
mu sync.Mutex
func newUDPMuxedConn(params *udpMuxedConnParams) *udpMuxedConn {
p := &udpMuxedConn{
params: params,
buf: packetio.NewBuffer(),
closedChan: make(chan struct{}),
return p
func (c *udpMuxedConn) ReadFrom(b []byte) (n int, rAddr net.Addr, err error) {
buf := c.params.AddrPool.Get().(*bufferHolder) //nolint:forcetypeassert
defer c.params.AddrPool.Put(buf)
// read address
total, err := c.buf.Read(buf.buf)
if err != nil {
return 0, nil, err
dataLen := int(binary.LittleEndian.Uint16(buf.buf[:2]))
if dataLen > total || dataLen > len(b) {
return 0, nil, io.ErrShortBuffer
// read data and then address
offset := 2
copy(b, buf.buf[offset:offset+dataLen])
offset += dataLen
// read address len & decode address
addrLen := int(binary.LittleEndian.Uint16(buf.buf[offset : offset+2]))
offset += 2
if rAddr, err = decodeUDPAddr(buf.buf[offset : offset+addrLen]); err != nil {
return 0, nil, err
return dataLen, rAddr, nil
func (c *udpMuxedConn) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) {
if c.isClosed() {
return 0, io.ErrClosedPipe
// each time we write to a new address, we'll register it with the mux
addr := rAddr.String()
if !c.containsAddress(addr) {
return c.params.Mux.writeTo(buf, rAddr)
func (c *udpMuxedConn) LocalAddr() net.Addr {
return c.params.LocalAddr
func (c *udpMuxedConn) SetDeadline(tm time.Time) error {
return nil
func (c *udpMuxedConn) SetReadDeadline(tm time.Time) error {
return nil
func (c *udpMuxedConn) SetWriteDeadline(tm time.Time) error {
return nil
func (c *udpMuxedConn) CloseChannel() <-chan struct{} {
return c.closedChan
func (c *udpMuxedConn) Close() error {
var err error
c.closeOnce.Do(func() {
err = c.buf.Close()
return err
func (c *udpMuxedConn) isClosed() bool {
select {
case <-c.closedChan:
return true
return false
func (c *udpMuxedConn) getAddresses() []string {
defer c.mu.Unlock()
addresses := make([]string, len(c.addresses))
copy(addresses, c.addresses)
return addresses
func (c *udpMuxedConn) addAddress(addr string) {
c.addresses = append(c.addresses, addr)
// map it on mux
c.params.Mux.registerConnForAddress(c, addr)
func (c *udpMuxedConn) containsAddress(addr string) bool {
defer c.mu.Unlock()
for _, a := range c.addresses {
if addr == a {
return true
return false
func (c *udpMuxedConn) writePacket(data []byte, addr *net.UDPAddr) error {
// write two packets, address and data
buf := c.params.AddrPool.Get().(*bufferHolder) //nolint:forcetypeassert
defer c.params.AddrPool.Put(buf)
// format of buffer | data len | data bytes | addr len | addr bytes |
if len(buf.buf) < len(data)+maxAddrSize {
return io.ErrShortBuffer
// data len
binary.LittleEndian.PutUint16(buf.buf, uint16(len(data)))
offset := 2
// data
copy(buf.buf[offset:], data)
offset += len(data)
// write address first, leaving room for its length
n, err := encodeUDPAddr(addr, buf.buf[offset+2:])
if err != nil {
return err
total := offset + n + 2
// address len
binary.LittleEndian.PutUint16(buf.buf[offset:], uint16(n))
if _, err := c.buf.Write(buf.buf[:total]); err != nil {
return err
return nil
// encodeUDPAddr serializes addr into buf and returns the number of bytes
// written.
//
// Layout:
//
//	| ip len (2 bytes, LE) | ip (text form) | port (2 bytes, LE) | zone |
//
// The zone carries no length prefix: it runs to the end of the encoding, so
// the decoder must be handed exactly the returned number of bytes.
//
// It returns io.ErrShortBuffer when buf is too small, or the error from
// marshalling the IP.
func encodeUDPAddr(addr *net.UDPAddr, buf []byte) (int, error) {
	ipData, err := addr.IP.MarshalText()
	if err != nil {
		return 0, err
	}

	total := 2 + len(ipData) + 2 + len(addr.Zone)
	if total > len(buf) {
		return 0, io.ErrShortBuffer
	}

	binary.LittleEndian.PutUint16(buf, uint16(len(ipData)))
	offset := 2
	offset += copy(buf[offset:], ipData)

	binary.LittleEndian.PutUint16(buf[offset:], uint16(addr.Port))
	offset += 2

	copy(buf[offset:], addr.Zone)
	return total, nil
}
// decodeUDPAddr parses an address produced by encodeUDPAddr.
//
// Expected layout:
//
//	| ip len (2 bytes, LE) | ip (text form) | port (2 bytes, LE) | zone |
//
// Everything after the port is taken as the zone, so buf must contain exactly
// one encoded address. It returns io.ErrShortBuffer when buf is too short to
// hold the IP length, the IP and the port, or the error from parsing the IP
// text.
func decodeUDPAddr(buf []byte) (*net.UDPAddr, error) {
	// the 2-byte IP length must be present before it can be read
	if len(buf) < 2 {
		return nil, io.ErrShortBuffer
	}
	ipLen := int(binary.LittleEndian.Uint16(buf[:2]))
	offset := 2

	// bounds check covers both the IP text and the 2-byte port after it
	if offset+ipLen+2 > len(buf) {
		return nil, io.ErrShortBuffer
	}

	addr := net.UDPAddr{}
	if err := addr.IP.UnmarshalText(buf[offset : offset+ipLen]); err != nil {
		return nil, err
	}
	offset += ipLen

	addr.Port = int(binary.LittleEndian.Uint16(buf[offset : offset+2]))
	offset += 2

	// the string conversion copies, so the result does not alias buf
	addr.Zone = string(buf[offset:])

	return &addr, nil
}