//go:build linux && !android

// Inspired by
// Jason Donenfeld (https://git.zx2c4.com/wireguard-tools/tree/contrib/nat-hole-punching/nat-punch-client.c#n96)
// and @stv0g in https://github.com/stv0g/cunicu/tree/ebpf-poc/ebpf_poc

package sharedsock

import (
	"context"
	"fmt"
	"net"
	"sync"
	"time"

	"github.com/google/gopacket"
	"github.com/google/gopacket/layers"
	"github.com/google/gopacket/routing"
	"github.com/libp2p/go-netroute"
	"github.com/mdlayher/socket"
	log "github.com/sirupsen/logrus"
	"golang.org/x/sync/errgroup"
	"golang.org/x/sys/unix"

	nbnet "github.com/netbirdio/netbird/util/net"
)

// ErrSharedSockStopped is returned by (*SharedSocket).ReadFrom once the
// socket's context has been cancelled, i.e. after Close.
//
// NOTE(review): the message reads "socked"; this is likely a typo for "socket".
// It is left unchanged in case anything matches on the text; callers should
// compare with errors.Is instead.
var ErrSharedSockStopped = fmt.Errorf("shared socked stopped")

// SharedSocket is a net.PacketConn built on two raw UDP sockets: IPv4, and
// IPv6 when it can be created.
//
// Incoming UDP packets are filtered in the kernel by BPF instructions, for
// example IncomingSTUNFilter, which passes only STUN packets. Matching packets
// are handed to callers through ReadFrom. Outgoing datagrams are assembled in
// user space by WriteTo.
//
// It is meant to be used when sharing a UDP port with some other process.
type SharedSocket struct { ctx context.Context conn4 *socket.Conn conn6 *socket.Conn port int routerMux sync.RWMutex router routing.Router packetDemux chan rcvdPacket cancel context.CancelFunc } type rcvdPacket struct { n int addr unix.Sockaddr buf []byte err error } type receiver func(ctx context.Context, p []byte, flags int) (int, unix.Sockaddr, error) var writeSerializerOptions = gopacket.SerializeOptions{ ComputeChecksums: true, FixLengths: true, } // Listen creates an IPv4 and IPv6 raw sockets, starts a reader and routing table routines func Listen(port int, filter BPFFilter) (_ net.PacketConn, err error) { ctx, cancel := context.WithCancel(context.Background()) rawSock := &SharedSocket{ ctx: ctx, cancel: cancel, port: port, packetDemux: make(chan rcvdPacket), } defer func() { if err != nil { if closeErr := rawSock.Close(); closeErr != nil { log.Errorf("Failed to close raw socket: %v", closeErr) } } }() rawSock.router, err = netroute.New() if err != nil { return nil, fmt.Errorf("failed to create raw socket router: %w", err) } rawSock.conn4, err = socket.Socket(unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp4", nil) if err != nil { return nil, fmt.Errorf("failed to create ipv4 raw socket: %w", err) } if err = nbnet.SetSocketMark(rawSock.conn4); err != nil { return nil, fmt.Errorf("failed to set SO_MARK on ipv4 socket: %w", err) } var sockErr error rawSock.conn6, sockErr = socket.Socket(unix.AF_INET6, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp6", nil) if sockErr != nil { log.Errorf("Failed to create ipv6 raw socket: %v", err) } else { if err = nbnet.SetSocketMark(rawSock.conn6); err != nil { return nil, fmt.Errorf("failed to set SO_MARK on ipv6 socket: %w", err) } } ipv4Instructions, ipv6Instructions, err := filter.GetInstructions(uint32(rawSock.port)) if err != nil { return nil, fmt.Errorf("getBPFInstructions failed with: %w", err) } err = rawSock.conn4.SetBPF(ipv4Instructions) if err != nil { return nil, fmt.Errorf("socket4.SetBPF failed with: %w", err) 
} if rawSock.conn6 != nil { err = rawSock.conn6.SetBPF(ipv6Instructions) if err != nil { return nil, fmt.Errorf("socket6.SetBPF failed with: %w", err) } } go rawSock.read(rawSock.conn4.Recvfrom) if rawSock.conn6 != nil { go rawSock.read(rawSock.conn6.Recvfrom) } go rawSock.updateRouter() return rawSock, nil } // updateRouter updates the listener routing table client // this is needed to avoid outdated information across different client networks func (s *SharedSocket) updateRouter() { ticker := time.NewTicker(15 * time.Second) defer ticker.Stop() for { select { case <-s.ctx.Done(): return case <-ticker.C: router, err := netroute.New() if err != nil { log.Errorf("Failed to create and update packet router for stunListener: %s", err) continue } s.routerMux.Lock() s.router = router s.routerMux.Unlock() } } } // LocalAddr returns an IPv4 address using the supplied port func (s *SharedSocket) LocalAddr() net.Addr { // todo check impact on ipv6 discovery return &net.UDPAddr{ IP: net.IPv4zero, Port: s.port, } } // SetDeadline sets both the read and write deadlines associated with the ipv4 and ipv6 Conn sockets func (s *SharedSocket) SetDeadline(t time.Time) error { err := s.conn4.SetDeadline(t) if err != nil { return fmt.Errorf("s.conn4.SetDeadline error: %w", err) } if s.conn6 == nil { return nil } err = s.conn6.SetDeadline(t) if err != nil { return fmt.Errorf("s.conn6.SetDeadline error: %w", err) } return nil } // SetReadDeadline sets the read deadline associated with the ipv4 and ipv6 Conn sockets func (s *SharedSocket) SetReadDeadline(t time.Time) error { err := s.conn4.SetReadDeadline(t) if err != nil { return fmt.Errorf("s.conn4.SetReadDeadline error: %w", err) } if s.conn6 == nil { return nil } err = s.conn6.SetReadDeadline(t) if err != nil { return fmt.Errorf("s.conn6.SetReadDeadline error: %w", err) } return nil } // SetWriteDeadline sets the write deadline associated with the ipv4 and ipv6 Conn sockets func (s *SharedSocket) SetWriteDeadline(t time.Time) error { 
err := s.conn4.SetWriteDeadline(t) if err != nil { return fmt.Errorf("s.conn4.SetWriteDeadline error: %w", err) } if s.conn6 == nil { return nil } err = s.conn6.SetWriteDeadline(t) if err != nil { return fmt.Errorf("s.conn6.SetWriteDeadline error: %w", err) } return nil } // Close closes the underlying ipv4 and ipv6 conn sockets func (s *SharedSocket) Close() error { s.cancel() errGrp := errgroup.Group{} if s.conn4 != nil { errGrp.Go(s.conn4.Close) } if s.conn6 != nil { errGrp.Go(s.conn6.Close) } return errGrp.Wait() } // read start a read loop for a specific receiver and sends the packet to the packetDemux channel func (s *SharedSocket) read(receiver receiver) { for { buf := make([]byte, 1500) n, addr, err := receiver(s.ctx, buf, 0) select { case <-s.ctx.Done(): return case s.packetDemux <- rcvdPacket{n, addr, buf[:n], err}: } } } // ReadFrom reads packets received in the packetDemux channel func (s *SharedSocket) ReadFrom(b []byte) (n int, addr net.Addr, err error) { var pkt rcvdPacket select { case <-s.ctx.Done(): return -1, nil, ErrSharedSockStopped case pkt = <-s.packetDemux: } if pkt.err != nil { return -1, nil, pkt.err } var ip4layer layers.IPv4 var udp layers.UDP var payload gopacket.Payload var parser *gopacket.DecodingLayerParser var ip net.IP if sa, isIPv4 := pkt.addr.(*unix.SockaddrInet4); isIPv4 { ip = sa.Addr[:] parser = gopacket.NewDecodingLayerParser(layers.LayerTypeIPv4, &ip4layer, &udp, &payload) } else if sa, isIPv6 := pkt.addr.(*unix.SockaddrInet6); isIPv6 { ip = sa.Addr[:] parser = gopacket.NewDecodingLayerParser(layers.LayerTypeUDP, &udp, &payload) } else { return -1, nil, fmt.Errorf("received invalid address family") } decodedLayers := make([]gopacket.LayerType, 0, 3) err = parser.DecodeLayers(pkt.buf, &decodedLayers) if err != nil { return 0, nil, err } remoteAddr := &net.UDPAddr{ IP: ip, Port: int(udp.SrcPort), } copy(b, payload) return int(udp.Length), remoteAddr, nil } // WriteTo builds a UDP packet and writes it using the specific IP 
version writer func (s *SharedSocket) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) { rUDPAddr, ok := rAddr.(*net.UDPAddr) if !ok { return -1, fmt.Errorf("invalid address type") } buffer := gopacket.NewSerializeBuffer() payload := gopacket.Payload(buf) udp := &layers.UDP{ SrcPort: layers.UDPPort(s.port), DstPort: layers.UDPPort(rUDPAddr.Port), } s.routerMux.RLock() defer s.routerMux.RUnlock() _, _, src, err := s.router.Route(rUDPAddr.IP) if err != nil { return 0, fmt.Errorf("got an error while checking route, err: %w", err) } rSockAddr, conn, nwLayer := s.getWriterObjects(src, rUDPAddr.IP) if err := udp.SetNetworkLayerForChecksum(nwLayer); err != nil { return -1, fmt.Errorf("failed to set network layer for checksum: %w", err) } if err := gopacket.SerializeLayers(buffer, writeSerializerOptions, udp, payload); err != nil { return -1, fmt.Errorf("failed serialize rcvdPacket: %w", err) } bufser := buffer.Bytes() return 0, conn.Sendto(context.TODO(), bufser, 0, rSockAddr) } // getWriterObjects returns the specific IP version objects that are used to build a packet and send it using the raw socket func (s *SharedSocket) getWriterObjects(src, dest net.IP) (sa unix.Sockaddr, conn *socket.Conn, layer gopacket.NetworkLayer) { if dest.To4() == nil { sa = &unix.SockaddrInet6{} copy(sa.(*unix.SockaddrInet6).Addr[:], dest.To16()) conn = s.conn6 layer = &layers.IPv6{ SrcIP: src, DstIP: dest, } } else { sa = &unix.SockaddrInet4{} copy(sa.(*unix.SockaddrInet4).Addr[:], dest.To4()) conn = s.conn4 layer = &layers.IPv4{ Version: 4, TTL: 64, Protocol: layers.IPProtocolUDP, SrcIP: src, DstIP: dest, } } return sa, conn, layer }