Support client default routes for Linux (#1667)

All routes are now installed in a custom netbird routing table.
Management and wireguard traffic is now marked with a custom fwmark.
When the mark is present the traffic is routed via the main routing table, bypassing the VPN.
When the mark is absent the traffic is routed via the netbird routing table, if:
- there's no match in the main routing table
- it would match the default route in the routing table

IPv6 traffic is blocked when a default route IPv4 route is configured to avoid leakage.
This commit is contained in:
Viktor Liu
2024-03-21 16:49:28 +01:00
committed by GitHub
parent 846871913d
commit 2475473227
41 changed files with 1656 additions and 376 deletions

View File

@ -21,6 +21,8 @@ import (
log "github.com/sirupsen/logrus"
"golang.org/x/sync/errgroup"
"golang.org/x/sys/unix"
nbnet "github.com/netbirdio/netbird/util/net"
)
// ErrSharedSockStopped indicates that shared socket has been stopped
@ -55,8 +57,7 @@ var writeSerializerOptions = gopacket.SerializeOptions{
}
// Listen creates an IPv4 and IPv6 raw sockets, starts a reader and routing table routines
func Listen(port int, filter BPFFilter) (net.PacketConn, error) {
var err error
func Listen(port int, filter BPFFilter) (_ net.PacketConn, err error) {
ctx, cancel := context.WithCancel(context.Background())
rawSock := &SharedSocket{
ctx: ctx,
@ -65,37 +66,51 @@ func Listen(port int, filter BPFFilter) (net.PacketConn, error) {
packetDemux: make(chan rcvdPacket),
}
defer func() {
if err != nil {
if closeErr := rawSock.Close(); closeErr != nil {
log.Errorf("Failed to close raw socket: %v", closeErr)
}
}
}()
rawSock.router, err = netroute.New()
if err != nil {
return nil, fmt.Errorf("failed to create raw socket router: %v", err)
return nil, fmt.Errorf("failed to create raw socket router: %w", err)
}
rawSock.conn4, err = socket.Socket(unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp4", nil)
if err != nil {
return nil, fmt.Errorf("failed to create ipv4 raw socket: %v", err)
return nil, fmt.Errorf("failed to create ipv4 raw socket: %w", err)
}
rawSock.conn6, err = socket.Socket(unix.AF_INET6, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp6", nil)
if err != nil {
log.Errorf("failed to create ipv6 raw socket: %v", err)
if err = nbnet.SetSocketMark(rawSock.conn4); err != nil {
return nil, fmt.Errorf("failed to set SO_MARK on ipv4 socket: %w", err)
}
var sockErr error
rawSock.conn6, sockErr = socket.Socket(unix.AF_INET6, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp6", nil)
if sockErr != nil {
log.Errorf("Failed to create ipv6 raw socket: %v", err)
} else {
if err = nbnet.SetSocketMark(rawSock.conn6); err != nil {
return nil, fmt.Errorf("failed to set SO_MARK on ipv6 socket: %w", err)
}
}
ipv4Instructions, ipv6Instructions, err := filter.GetInstructions(uint32(rawSock.port))
if err != nil {
_ = rawSock.Close()
return nil, fmt.Errorf("getBPFInstructions failed with: %rawSock", err)
return nil, fmt.Errorf("getBPFInstructions failed with: %w", err)
}
err = rawSock.conn4.SetBPF(ipv4Instructions)
if err != nil {
_ = rawSock.Close()
return nil, fmt.Errorf("socket4.SetBPF failed with: %rawSock", err)
return nil, fmt.Errorf("socket4.SetBPF failed with: %w", err)
}
if rawSock.conn6 != nil {
err = rawSock.conn6.SetBPF(ipv6Instructions)
if err != nil {
_ = rawSock.Close()
return nil, fmt.Errorf("socket6.SetBPF failed with: %rawSock", err)
return nil, fmt.Errorf("socket6.SetBPF failed with: %w", err)
}
}
@ -121,7 +136,7 @@ func (s *SharedSocket) updateRouter() {
case <-ticker.C:
router, err := netroute.New()
if err != nil {
log.Errorf("failed to create and update packet router for stunListener: %s", err)
log.Errorf("Failed to create and update packet router for stunListener: %s", err)
continue
}
s.routerMux.Lock()
@ -144,7 +159,7 @@ func (s *SharedSocket) LocalAddr() net.Addr {
func (s *SharedSocket) SetDeadline(t time.Time) error {
err := s.conn4.SetDeadline(t)
if err != nil {
return fmt.Errorf("s.conn4.SetDeadline error: %s", err)
return fmt.Errorf("s.conn4.SetDeadline error: %w", err)
}
if s.conn6 == nil {
return nil
@ -152,7 +167,7 @@ func (s *SharedSocket) SetDeadline(t time.Time) error {
err = s.conn6.SetDeadline(t)
if err != nil {
return fmt.Errorf("s.conn6.SetDeadline error: %s", err)
return fmt.Errorf("s.conn6.SetDeadline error: %w", err)
}
return nil
}
@ -161,7 +176,7 @@ func (s *SharedSocket) SetDeadline(t time.Time) error {
func (s *SharedSocket) SetReadDeadline(t time.Time) error {
err := s.conn4.SetReadDeadline(t)
if err != nil {
return fmt.Errorf("s.conn4.SetReadDeadline error: %s", err)
return fmt.Errorf("s.conn4.SetReadDeadline error: %w", err)
}
if s.conn6 == nil {
return nil
@ -169,7 +184,7 @@ func (s *SharedSocket) SetReadDeadline(t time.Time) error {
err = s.conn6.SetReadDeadline(t)
if err != nil {
return fmt.Errorf("s.conn6.SetReadDeadline error: %s", err)
return fmt.Errorf("s.conn6.SetReadDeadline error: %w", err)
}
return nil
}
@ -178,7 +193,7 @@ func (s *SharedSocket) SetReadDeadline(t time.Time) error {
func (s *SharedSocket) SetWriteDeadline(t time.Time) error {
err := s.conn4.SetWriteDeadline(t)
if err != nil {
return fmt.Errorf("s.conn4.SetWriteDeadline error: %s", err)
return fmt.Errorf("s.conn4.SetWriteDeadline error: %w", err)
}
if s.conn6 == nil {
return nil
@ -186,7 +201,7 @@ func (s *SharedSocket) SetWriteDeadline(t time.Time) error {
err = s.conn6.SetWriteDeadline(t)
if err != nil {
return fmt.Errorf("s.conn6.SetWriteDeadline error: %s", err)
return fmt.Errorf("s.conn6.SetWriteDeadline error: %w", err)
}
return nil
}
@ -282,7 +297,7 @@ func (s *SharedSocket) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) {
_, _, src, err := s.router.Route(rUDPAddr.IP)
if err != nil {
return 0, fmt.Errorf("got an error while checking route, err: %s", err)
return 0, fmt.Errorf("got an error while checking route, err: %w", err)
}
rSockAddr, conn, nwLayer := s.getWriterObjects(src, rUDPAddr.IP)
@ -292,7 +307,7 @@ func (s *SharedSocket) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) {
}
if err := gopacket.SerializeLayers(buffer, writeSerializerOptions, udp, payload); err != nil {
return -1, fmt.Errorf("failed serialize rcvdPacket: %s", err)
return -1, fmt.Errorf("failed serialize rcvdPacket: %w", err)
}
bufser := buffer.Bytes()