netbird/iface/bind.go

216 lines
4.5 KiB
Go
Raw Normal View History

2022-09-04 22:52:52 +02:00
package iface
import (
2022-09-06 20:06:51 +02:00
"errors"
"fmt"
"github.com/pion/stun"
log "github.com/sirupsen/logrus"
2022-09-04 22:52:52 +02:00
"golang.zx2c4.com/wireguard/conn"
"net"
"net/netip"
"sync"
2022-09-06 20:06:51 +02:00
"syscall"
2022-09-04 22:52:52 +02:00
)
2022-09-06 20:44:49 +02:00
type BindMux interface {
HandleSTUNMessage(msg *stun.Message, addr net.Addr) error
2022-09-06 20:44:49 +02:00
Type() string
}
2022-09-06 20:06:51 +02:00
type ICEBind struct {
2022-09-07 18:39:58 +02:00
sharedConn net.PacketConn
udpMux *UniversalUDPMuxDefault
iceHostMux *UDPMuxDefault
2022-09-06 20:44:49 +02:00
endpointMap map[string]net.PacketConn
2022-09-05 02:03:16 +02:00
2022-09-06 20:06:51 +02:00
mu sync.Mutex // protects following fields
2022-09-04 22:52:52 +02:00
}
2022-09-06 20:44:49 +02:00
func (b *ICEBind) GetICEMux() (UniversalUDPMux, error) {
2022-09-06 20:06:51 +02:00
b.mu.Lock()
defer b.mu.Unlock()
2022-09-07 18:39:58 +02:00
if b.udpMux == nil {
2022-09-06 20:06:51 +02:00
return nil, fmt.Errorf("ICEBind has not been initialized yet")
}
2022-09-05 02:03:16 +02:00
2022-09-07 18:39:58 +02:00
return b.udpMux, nil
2022-09-04 22:52:52 +02:00
}
2022-09-06 20:44:49 +02:00
func (b *ICEBind) GetICEHostMux() (UDPMux, error) {
2022-09-06 20:06:51 +02:00
b.mu.Lock()
defer b.mu.Unlock()
2022-09-06 20:44:49 +02:00
if b.iceHostMux == nil {
2022-09-06 20:06:51 +02:00
return nil, fmt.Errorf("ICEBind has not been initialized yet")
}
2022-09-04 22:52:52 +02:00
2022-09-06 20:44:49 +02:00
return b.iceHostMux, nil
2022-09-06 20:06:51 +02:00
}
2022-09-04 22:52:52 +02:00
2022-09-06 20:06:51 +02:00
func (b *ICEBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
b.mu.Lock()
defer b.mu.Unlock()
2022-09-04 22:52:52 +02:00
2022-09-06 20:06:51 +02:00
if b.sharedConn != nil {
return nil, 0, conn.ErrBindAlreadyOpen
}
2022-09-06 20:44:49 +02:00
b.endpointMap = make(map[string]net.PacketConn)
2022-09-04 22:52:52 +02:00
2022-09-06 20:06:51 +02:00
port := int(uport)
ipv4Conn, port, err := listenNet("udp4", port)
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, 0, err
}
b.sharedConn = ipv4Conn
2022-09-07 18:39:58 +02:00
b.udpMux = NewUniversalUDPMuxDefault(UniversalUDPMuxParams{UDPConn: b.sharedConn})
2022-09-04 22:52:52 +02:00
2022-09-07 18:39:58 +02:00
portAddr1, err := netip.ParseAddrPort(ipv4Conn.LocalAddr().String())
2022-09-05 02:03:16 +02:00
if err != nil {
2022-09-06 20:06:51 +02:00
return nil, 0, err
2022-09-04 22:52:52 +02:00
}
2022-09-07 18:39:58 +02:00
log.Infof("opened ICEBind on %s", ipv4Conn.LocalAddr().String())
2022-09-06 20:44:49 +02:00
return []conn.ReceiveFunc{
2022-09-07 18:39:58 +02:00
b.makeReceiveIPv4(b.sharedConn),
2022-09-06 20:44:49 +02:00
},
2022-09-07 18:39:58 +02:00
portAddr1.Port(), nil
2022-09-06 20:06:51 +02:00
}
// listenNet opens a UDP socket on the given network ("udp4", "udp6", ...) and
// port (0 lets the OS choose). It returns the socket and the port it was
// actually bound to. The socket is closed again if the port cannot be
// determined.
func listenNet(network string, port int) (*net.UDPConn, int, error) {
	// Named udpConn rather than conn to avoid shadowing the wireguard conn package.
	udpConn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
	if err != nil {
		return nil, 0, err
	}

	// A UDPConn's local address is a *net.UDPAddr, so the bound port can be
	// read directly instead of round-tripping through a string.
	laddr, ok := udpConn.LocalAddr().(*net.UDPAddr)
	if !ok {
		_ = udpConn.Close() // best effort; the type error below is what matters
		return nil, 0, fmt.Errorf("unexpected local address type %T", udpConn.LocalAddr())
	}
	return udpConn, laddr.Port, nil
}
2022-09-04 22:52:52 +02:00
2022-09-07 19:02:17 +02:00
// parseSTUNMessage decodes raw into a new STUN message. raw is copied first,
// so the caller may reuse its buffer once this returns.
func parseSTUNMessage(raw []byte) (*stun.Message, error) {
	owned := make([]byte, len(raw))
	copy(owned, raw)

	m := new(stun.Message)
	m.Raw = owned
	if err := m.Decode(); err != nil {
		return nil, err
	}
	return m, nil
}
2022-09-07 18:39:58 +02:00
func (b *ICEBind) makeReceiveIPv4(c net.PacketConn) conn.ReceiveFunc {
2022-09-06 20:06:51 +02:00
return func(buff []byte) (int, conn.Endpoint, error) {
n, endpoint, err := c.ReadFrom(buff)
if err != nil {
return 0, nil, err
}
e, err := netip.ParseAddrPort(endpoint.String())
if err != nil {
return 0, nil, err
}
if !stun.IsMessage(buff[:20]) {
2022-09-06 20:06:51 +02:00
// WireGuard traffic
return n, (*conn.StdNetEndpoint)(&net.UDPAddr{
IP: e.Addr().AsSlice(),
Port: int(e.Port()),
Zone: e.Addr().Zone(),
}), nil
}
2022-09-07 18:39:58 +02:00
2022-09-07 19:02:17 +02:00
msg, err := parseSTUNMessage(buff[:n])
if err != nil {
return 0, nil, err
}
b.mu.Lock()
2022-09-06 20:44:49 +02:00
if _, ok := b.endpointMap[e.String()]; !ok {
b.endpointMap[e.String()] = c
2022-09-07 18:39:58 +02:00
log.Infof("added endpoint %s", e.String())
2022-09-06 20:44:49 +02:00
}
b.mu.Unlock()
err = b.udpMux.HandleSTUNMessage(msg, endpoint)
2022-09-06 20:06:51 +02:00
if err != nil {
return 0, nil, err
2022-09-04 22:52:52 +02:00
}
2022-09-06 20:06:51 +02:00
if err != nil {
log.Warnf("failed to handle packet")
}
// discard packets because they are STUN related
return 0, nil, nil //todo proper return
2022-09-04 22:52:52 +02:00
}
2022-09-06 20:06:51 +02:00
}
func (b *ICEBind) Close() error {
b.mu.Lock()
defer b.mu.Unlock()
2022-09-07 18:39:58 +02:00
var err1, err2 error
2022-09-06 20:06:51 +02:00
if b.sharedConn != nil {
c := b.sharedConn
b.sharedConn = nil
err1 = c.Close()
}
2022-09-07 18:39:58 +02:00
if b.udpMux != nil {
m := b.udpMux
b.udpMux = nil
err2 = m.Close()
2022-09-06 20:06:51 +02:00
}
if err1 != nil {
return err1
}
2022-09-06 20:44:49 +02:00
2022-09-07 18:39:58 +02:00
return err2
2022-09-04 22:52:52 +02:00
}
// SetMark sets the mark for each packet sent through this Bind.
// This mark is passed to the kernel as the socket option SO_MARK.
2022-09-06 20:06:51 +02:00
func (b *ICEBind) SetMark(mark uint32) error {
2022-09-04 22:52:52 +02:00
return nil
}
2022-09-06 20:06:51 +02:00
func (b *ICEBind) Send(buff []byte, endpoint conn.Endpoint) error {
2022-09-04 22:52:52 +02:00
nend, ok := endpoint.(*conn.StdNetEndpoint)
if !ok {
return conn.ErrWrongEndpointType
}
2022-09-06 20:44:49 +02:00
b.mu.Lock()
co := b.endpointMap[(*net.UDPAddr)(nend).String()]
2022-09-07 11:17:54 +02:00
2022-09-06 20:44:49 +02:00
if co == nil {
// todo proper handling
2022-09-07 11:17:54 +02:00
// todo without it relayed connections didn't work. investigate
2022-09-06 20:59:19 +02:00
log.Warnf("conn not found for endpoint %s", (*net.UDPAddr)(nend).String())
2022-09-07 11:17:54 +02:00
co = b.sharedConn
b.endpointMap[(*net.UDPAddr)(nend).String()] = b.sharedConn
//return conn.ErrWrongEndpointType
2022-09-06 20:44:49 +02:00
}
2022-09-07 11:17:54 +02:00
b.mu.Unlock()
2022-09-06 20:44:49 +02:00
_, err := co.WriteTo(buff, (*net.UDPAddr)(nend))
2022-09-04 22:52:52 +02:00
return err
}
// ParseEndpoint creates a new endpoint from a string.
2022-09-06 20:06:51 +02:00
func (b *ICEBind) ParseEndpoint(s string) (ep conn.Endpoint, err error) {
2022-09-05 02:03:16 +02:00
e, err := netip.ParseAddrPort(s)
2022-09-04 22:52:52 +02:00
return (*conn.StdNetEndpoint)(&net.UDPAddr{
2022-09-05 02:03:16 +02:00
IP: e.Addr().AsSlice(),
Port: int(e.Port()),
Zone: e.Addr().Zone(),
2022-09-04 22:52:52 +02:00
}), err
}