mirror of
https://github.com/netbirdio/netbird.git
synced 2025-03-12 05:38:17 +01:00
[client] Code cleaning in net pkg and fix exit node feature on Android(#2932)
Code cleaning around the util/net package. The goal was to write a more understandable source code but modify nothing on the logic. Protect the WireGuard UDP listeners with marks. The implementation can support the VPN permission revocation events in thread safe way. It will be important if we start to support the running time route and DNS update features. - uniformize the file name convention: [struct_name] _ [functions] _ [os].go - code cleaning in net_linux.go - move env variables to env.go file
This commit is contained in:
parent
9683da54b0
commit
9203690033
12
client/iface/bind/control_android.go
Normal file
12
client/iface/bind/control_android.go
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
wireguard "golang.zx2c4.com/wireguard/conn"
|
||||||
|
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// ControlFns is not thread safe and should only be modified during init.
|
||||||
|
*wireguard.ControlFns = append(*wireguard.ControlFns, nbnet.ControlProtectSocket)
|
||||||
|
}
|
@ -55,7 +55,7 @@ type ruleParams struct {
|
|||||||
|
|
||||||
// isLegacy determines whether to use the legacy routing setup
|
// isLegacy determines whether to use the legacy routing setup
|
||||||
func isLegacy() bool {
|
func isLegacy() bool {
|
||||||
return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || os.Getenv(nbnet.EnvSkipSocketMark) == "true"
|
return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || nbnet.SkipSocketMark()
|
||||||
}
|
}
|
||||||
|
|
||||||
// setIsLegacy sets the legacy routing setup
|
// setIsLegacy sets the legacy routing setup
|
||||||
|
2
go.mod
2
go.mod
@ -236,7 +236,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
|
|||||||
|
|
||||||
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
|
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
|
||||||
|
|
||||||
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73
|
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9
|
||||||
|
|
||||||
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
|
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
|
||||||
|
|
||||||
|
4
go.sum
4
go.sum
@ -527,8 +527,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax
|
|||||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
||||||
github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73 h1:jayg97LH/jJlvpIHVxueTfa+tfQ+FY8fy2sIhCwkz0g=
|
github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9 h1:Pu/7EukijT09ynHUOzQYW7cC3M/BKU8O4qyN/TvTGoY=
|
||||||
github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
||||||
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
|
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
|
||||||
github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4=
|
github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4=
|
||||||
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
|
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
|
||||||
|
31
util/net/conn.go
Normal file
31
util/net/conn.go
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
//go:build !ios
|
||||||
|
|
||||||
|
package net
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Conn wraps a net.Conn to override the Close method
|
||||||
|
type Conn struct {
|
||||||
|
net.Conn
|
||||||
|
ID ConnectionID
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection
|
||||||
|
func (c *Conn) Close() error {
|
||||||
|
err := c.Conn.Close()
|
||||||
|
|
||||||
|
dialerCloseHooksMutex.RLock()
|
||||||
|
defer dialerCloseHooksMutex.RUnlock()
|
||||||
|
|
||||||
|
for _, hook := range dialerCloseHooks {
|
||||||
|
if err := hook(c.ID, &c.Conn); err != nil {
|
||||||
|
log.Errorf("Error executing dialer close hook: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
58
util/net/dial.go
Normal file
58
util/net/dial.go
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
//go:build !ios
|
||||||
|
|
||||||
|
package net
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
|
||||||
|
if CustomRoutingDisabled() {
|
||||||
|
return net.DialUDP(network, laddr, raddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
dialer := NewDialer()
|
||||||
|
dialer.LocalAddr = laddr
|
||||||
|
|
||||||
|
conn, err := dialer.Dial(network, raddr.String())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn)
|
||||||
|
if !ok {
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
log.Errorf("Failed to close connection: %v", err)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
return udpConn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
|
||||||
|
if CustomRoutingDisabled() {
|
||||||
|
return net.DialTCP(network, laddr, raddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
dialer := NewDialer()
|
||||||
|
dialer.LocalAddr = laddr
|
||||||
|
|
||||||
|
conn, err := dialer.Dial(network, raddr.String())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn)
|
||||||
|
if !ok {
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
log.Errorf("Failed to close connection: %v", err)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tcpConn, nil
|
||||||
|
}
|
@ -1,25 +0,0 @@
|
|||||||
package net
|
|
||||||
|
|
||||||
import (
|
|
||||||
"syscall"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (d *Dialer) init() {
|
|
||||||
d.Dialer.Control = func(_, _ string, c syscall.RawConn) error {
|
|
||||||
err := c.Control(func(fd uintptr) {
|
|
||||||
androidProtectSocketLock.Lock()
|
|
||||||
f := androidProtectSocket
|
|
||||||
androidProtectSocketLock.Unlock()
|
|
||||||
if f == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ok := f(int32(fd))
|
|
||||||
if !ok {
|
|
||||||
log.Errorf("failed to protect socket: %d", fd)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
@ -81,28 +81,6 @@ func (d *Dialer) Dial(network, address string) (net.Conn, error) {
|
|||||||
return d.DialContext(context.Background(), network, address)
|
return d.DialContext(context.Background(), network, address)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Conn wraps a net.Conn to override the Close method
|
|
||||||
type Conn struct {
|
|
||||||
net.Conn
|
|
||||||
ID ConnectionID
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection
|
|
||||||
func (c *Conn) Close() error {
|
|
||||||
err := c.Conn.Close()
|
|
||||||
|
|
||||||
dialerCloseHooksMutex.RLock()
|
|
||||||
defer dialerCloseHooksMutex.RUnlock()
|
|
||||||
|
|
||||||
for _, hook := range dialerCloseHooks {
|
|
||||||
if err := hook(c.ID, &c.Conn); err != nil {
|
|
||||||
log.Errorf("Error executing dialer close hook: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func callDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error {
|
func callDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error {
|
||||||
host, _, err := net.SplitHostPort(address)
|
host, _, err := net.SplitHostPort(address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -127,51 +105,3 @@ func callDialerHooks(ctx context.Context, connID ConnectionID, address string, r
|
|||||||
|
|
||||||
return result.ErrorOrNil()
|
return result.ErrorOrNil()
|
||||||
}
|
}
|
||||||
|
|
||||||
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
|
|
||||||
if CustomRoutingDisabled() {
|
|
||||||
return net.DialUDP(network, laddr, raddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
dialer := NewDialer()
|
|
||||||
dialer.LocalAddr = laddr
|
|
||||||
|
|
||||||
conn, err := dialer.Dial(network, raddr.String())
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn)
|
|
||||||
if !ok {
|
|
||||||
if err := conn.Close(); err != nil {
|
|
||||||
log.Errorf("Failed to close connection: %v", err)
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
return udpConn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
|
|
||||||
if CustomRoutingDisabled() {
|
|
||||||
return net.DialTCP(network, laddr, raddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
dialer := NewDialer()
|
|
||||||
dialer.LocalAddr = laddr
|
|
||||||
|
|
||||||
conn, err := dialer.Dial(network, raddr.String())
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn)
|
|
||||||
if !ok {
|
|
||||||
if err := conn.Close(); err != nil {
|
|
||||||
log.Errorf("Failed to close connection: %v", err)
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
return tcpConn, nil
|
|
||||||
}
|
|
5
util/net/dialer_init_android.go
Normal file
5
util/net/dialer_init_android.go
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
package net
|
||||||
|
|
||||||
|
func (d *Dialer) init() {
|
||||||
|
d.Dialer.Control = ControlProtectSocket
|
||||||
|
}
|
@ -7,6 +7,6 @@ import "syscall"
|
|||||||
// init configures the net.Dialer Control function to set the fwmark on the socket
|
// init configures the net.Dialer Control function to set the fwmark on the socket
|
||||||
func (d *Dialer) init() {
|
func (d *Dialer) init() {
|
||||||
d.Dialer.Control = func(_, _ string, c syscall.RawConn) error {
|
d.Dialer.Control = func(_, _ string, c syscall.RawConn) error {
|
||||||
return SetRawSocketMark(c)
|
return setRawSocketMark(c)
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -3,4 +3,5 @@
|
|||||||
package net
|
package net
|
||||||
|
|
||||||
func (d *Dialer) init() {
|
func (d *Dialer) init() {
|
||||||
|
// implemented on Linux and Android only
|
||||||
}
|
}
|
29
util/net/env.go
Normal file
29
util/net/env.go
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
package net
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
|
||||||
|
envSkipSocketMark = "NB_SKIP_SOCKET_MARK"
|
||||||
|
)
|
||||||
|
|
||||||
|
func CustomRoutingDisabled() bool {
|
||||||
|
if netstack.IsEnabled() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return os.Getenv(envDisableCustomRouting) == "true"
|
||||||
|
}
|
||||||
|
|
||||||
|
func SkipSocketMark() bool {
|
||||||
|
if skipSocketMark := os.Getenv(envSkipSocketMark); skipSocketMark == "true" {
|
||||||
|
log.Infof("%s is set to true, skipping SO_MARK", envSkipSocketMark)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
37
util/net/listen.go
Normal file
37
util/net/listen.go
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
//go:build !ios
|
||||||
|
|
||||||
|
package net
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/pion/transport/v3"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ListenUDP listens on the network address and returns a transport.UDPConn
|
||||||
|
// which includes support for write and close hooks.
|
||||||
|
func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) {
|
||||||
|
if CustomRoutingDisabled() {
|
||||||
|
return net.ListenUDP(network, laddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("listen UDP: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
packetConn := conn.(*PacketConn)
|
||||||
|
udpConn, ok := packetConn.PacketConn.(*net.UDPConn)
|
||||||
|
if !ok {
|
||||||
|
if err := packetConn.Close(); err != nil {
|
||||||
|
log.Errorf("Failed to close connection: %v", err)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil
|
||||||
|
}
|
@ -1,26 +0,0 @@
|
|||||||
package net
|
|
||||||
|
|
||||||
import (
|
|
||||||
"syscall"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
// init configures the net.ListenerConfig Control function to set the fwmark on the socket
|
|
||||||
func (l *ListenerConfig) init() {
|
|
||||||
l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error {
|
|
||||||
err := c.Control(func(fd uintptr) {
|
|
||||||
androidProtectSocketLock.Lock()
|
|
||||||
f := androidProtectSocket
|
|
||||||
androidProtectSocketLock.Unlock()
|
|
||||||
if f == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ok := f(int32(fd))
|
|
||||||
if !ok {
|
|
||||||
log.Errorf("failed to protect listener socket: %d", fd)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
6
util/net/listener_init_android.go
Normal file
6
util/net/listener_init_android.go
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
package net
|
||||||
|
|
||||||
|
// init configures the net.ListenerConfig Control function to set the fwmark on the socket
|
||||||
|
func (l *ListenerConfig) init() {
|
||||||
|
l.ListenConfig.Control = ControlProtectSocket
|
||||||
|
}
|
@ -9,6 +9,6 @@ import (
|
|||||||
// init configures the net.ListenerConfig Control function to set the fwmark on the socket
|
// init configures the net.ListenerConfig Control function to set the fwmark on the socket
|
||||||
func (l *ListenerConfig) init() {
|
func (l *ListenerConfig) init() {
|
||||||
l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error {
|
l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error {
|
||||||
return SetRawSocketMark(c)
|
return setRawSocketMark(c)
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -3,4 +3,5 @@
|
|||||||
package net
|
package net
|
||||||
|
|
||||||
func (l *ListenerConfig) init() {
|
func (l *ListenerConfig) init() {
|
||||||
|
// implemented on Linux and Android only
|
||||||
}
|
}
|
@ -8,7 +8,6 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/pion/transport/v3"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -146,27 +145,3 @@ func closeConn(id ConnectionID, conn net.PacketConn) error {
|
|||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListenUDP listens on the network address and returns a transport.UDPConn
|
|
||||||
// which includes support for write and close hooks.
|
|
||||||
func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) {
|
|
||||||
if CustomRoutingDisabled() {
|
|
||||||
return net.ListenUDP(network, laddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String())
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("listen UDP: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
packetConn := conn.(*PacketConn)
|
|
||||||
udpConn, ok := packetConn.PacketConn.(*net.UDPConn)
|
|
||||||
if !ok {
|
|
||||||
if err := packetConn.Close(); err != nil {
|
|
||||||
log.Errorf("Failed to close connection: %v", err)
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil
|
|
||||||
}
|
|
@ -2,9 +2,6 @@ package net
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
"os"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
@ -16,8 +13,6 @@ const (
|
|||||||
PreroutingFwmarkRedirected = 0x1BD01
|
PreroutingFwmarkRedirected = 0x1BD01
|
||||||
PreroutingFwmarkMasquerade = 0x1BD11
|
PreroutingFwmarkMasquerade = 0x1BD11
|
||||||
PreroutingFwmarkMasqueradeReturn = 0x1BD12
|
PreroutingFwmarkMasqueradeReturn = 0x1BD12
|
||||||
|
|
||||||
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConnectionID provides a globally unique identifier for network connections.
|
// ConnectionID provides a globally unique identifier for network connections.
|
||||||
@ -31,10 +26,3 @@ type RemoveHookFunc func(connID ConnectionID) error
|
|||||||
func GenerateConnID() ConnectionID {
|
func GenerateConnID() ConnectionID {
|
||||||
return ConnectionID(uuid.NewString())
|
return ConnectionID(uuid.NewString())
|
||||||
}
|
}
|
||||||
|
|
||||||
func CustomRoutingDisabled() bool {
|
|
||||||
if netstack.IsEnabled() {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return os.Getenv(envDisableCustomRouting) == "true"
|
|
||||||
}
|
|
||||||
|
@ -4,29 +4,42 @@ package net
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
const EnvSkipSocketMark = "NB_SKIP_SOCKET_MARK"
|
|
||||||
|
|
||||||
// SetSocketMark sets the SO_MARK option on the given socket connection
|
// SetSocketMark sets the SO_MARK option on the given socket connection
|
||||||
func SetSocketMark(conn syscall.Conn) error {
|
func SetSocketMark(conn syscall.Conn) error {
|
||||||
|
if isSocketMarkDisabled() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
sysconn, err := conn.SyscallConn()
|
sysconn, err := conn.SyscallConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("get raw conn: %w", err)
|
return fmt.Errorf("get raw conn: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return SetRawSocketMark(sysconn)
|
return setRawSocketMark(sysconn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func SetRawSocketMark(conn syscall.RawConn) error {
|
// SetSocketOpt sets the SO_MARK option on the given file descriptor
|
||||||
|
func SetSocketOpt(fd int) error {
|
||||||
|
if isSocketMarkDisabled() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return setSocketOptInt(fd)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setRawSocketMark(conn syscall.RawConn) error {
|
||||||
var setErr error
|
var setErr error
|
||||||
|
|
||||||
err := conn.Control(func(fd uintptr) {
|
err := conn.Control(func(fd uintptr) {
|
||||||
setErr = SetSocketOpt(int(fd))
|
if isSocketMarkDisabled() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
setErr = setSocketOptInt(int(fd))
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("control: %w", err)
|
return fmt.Errorf("control: %w", err)
|
||||||
@ -39,17 +52,18 @@ func SetRawSocketMark(conn syscall.RawConn) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func SetSocketOpt(fd int) error {
|
func setSocketOptInt(fd int) error {
|
||||||
if CustomRoutingDisabled() {
|
|
||||||
log.Infof("Custom routing is disabled, skipping SO_MARK")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for the new environment variable
|
|
||||||
if skipSocketMark := os.Getenv(EnvSkipSocketMark); skipSocketMark == "true" {
|
|
||||||
log.Info("NB_SKIP_SOCKET_MARK is set to true, skipping SO_MARK")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark)
|
return syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isSocketMarkDisabled() bool {
|
||||||
|
if CustomRoutingDisabled() {
|
||||||
|
log.Infof("Custom routing is disabled, skipping SO_MARK")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if SkipSocketMark() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
@ -1,14 +1,42 @@
|
|||||||
package net
|
package net
|
||||||
|
|
||||||
import "sync"
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"syscall"
|
||||||
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
androidProtectSocketLock sync.Mutex
|
androidProtectSocketLock sync.Mutex
|
||||||
androidProtectSocket func(fd int32) bool
|
androidProtectSocket func(fd int32) bool
|
||||||
)
|
)
|
||||||
|
|
||||||
func SetAndroidProtectSocketFn(f func(fd int32) bool) {
|
func SetAndroidProtectSocketFn(fn func(fd int32) bool) {
|
||||||
androidProtectSocketLock.Lock()
|
androidProtectSocketLock.Lock()
|
||||||
androidProtectSocket = f
|
androidProtectSocket = fn
|
||||||
androidProtectSocketLock.Unlock()
|
androidProtectSocketLock.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ControlProtectSocket is a Control function that sets the fwmark on the socket
|
||||||
|
func ControlProtectSocket(_, _ string, c syscall.RawConn) error {
|
||||||
|
var aErr error
|
||||||
|
err := c.Control(func(fd uintptr) {
|
||||||
|
androidProtectSocketLock.Lock()
|
||||||
|
defer androidProtectSocketLock.Unlock()
|
||||||
|
|
||||||
|
if androidProtectSocket == nil {
|
||||||
|
aErr = fmt.Errorf("socket protection function not set")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !androidProtectSocket(int32(fd)) {
|
||||||
|
aErr = fmt.Errorf("failed to protect socket via Android")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return aErr
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user