mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-18 19:09:09 +02:00
Merge branch 'main' into peers-get-account-refactoring
# Conflicts: # management/server/sql_store_test.go
This commit is contained in:
12
client/iface/bind/control_android.go
Normal file
12
client/iface/bind/control_android.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
wireguard "golang.zx2c4.com/wireguard/conn"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// ControlFns is not thread safe and should only be modified during init.
|
||||
*wireguard.ControlFns = append(*wireguard.ControlFns, nbnet.ControlProtectSocket)
|
||||
}
|
@@ -55,7 +55,7 @@ type ruleParams struct {
|
||||
|
||||
// isLegacy determines whether to use the legacy routing setup
|
||||
func isLegacy() bool {
|
||||
return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || os.Getenv(nbnet.EnvSkipSocketMark) == "true"
|
||||
return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || nbnet.SkipSocketMark()
|
||||
}
|
||||
|
||||
// setIsLegacy sets the legacy routing setup
|
||||
|
2
go.mod
2
go.mod
@@ -236,7 +236,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
|
||||
|
||||
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
|
||||
|
||||
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73
|
||||
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9
|
||||
|
||||
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
|
||||
|
||||
|
4
go.sum
4
go.sum
@@ -527,8 +527,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73 h1:jayg97LH/jJlvpIHVxueTfa+tfQ+FY8fy2sIhCwkz0g=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9 h1:Pu/7EukijT09ynHUOzQYW7cC3M/BKU8O4qyN/TvTGoY=
|
||||
github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
||||
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
|
||||
github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4=
|
||||
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
|
||||
|
@@ -2045,139 +2045,3 @@ func TestSqlStore_DeleteNameServerGroup(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
require.Nil(t, nsGroup)
|
||||
}
|
||||
|
||||
func TestSqlStore_GetAccountPeers(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_with_expired_peers.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
accountID string
|
||||
expectedCount int
|
||||
}{
|
||||
{
|
||||
name: "retrieve peers by existing account ID",
|
||||
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
expectedCount: 4,
|
||||
},
|
||||
{
|
||||
name: "non-existing account ID",
|
||||
accountID: "nonexistent",
|
||||
expectedCount: 0,
|
||||
},
|
||||
{
|
||||
name: "empty account ID",
|
||||
accountID: "",
|
||||
expectedCount: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
peers, err := store.GetAccountPeers(context.Background(), LockingStrengthShare, tt.accountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, peers, tt.expectedCount)
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestSqlStore_GetAccountPeersWithExpiration(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_with_expired_peers.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
accountID string
|
||||
expectedCount int
|
||||
}{
|
||||
{
|
||||
name: "retrieve peers with expiration by existing account ID",
|
||||
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
expectedCount: 1,
|
||||
},
|
||||
{
|
||||
name: "non-existing account ID",
|
||||
accountID: "nonexistent",
|
||||
expectedCount: 0,
|
||||
},
|
||||
{
|
||||
name: "empty account ID",
|
||||
accountID: "",
|
||||
expectedCount: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
peers, err := store.GetAccountPeersWithExpiration(context.Background(), LockingStrengthShare, tt.accountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, peers, tt.expectedCount)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_GetAccountPeersWithInactivity(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_with_expired_peers.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
accountID string
|
||||
expectedCount int
|
||||
}{
|
||||
{
|
||||
name: "retrieve peers with inactivity by existing account ID",
|
||||
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
expectedCount: 1,
|
||||
},
|
||||
{
|
||||
name: "non-existing account ID",
|
||||
accountID: "nonexistent",
|
||||
expectedCount: 0,
|
||||
},
|
||||
{
|
||||
name: "empty account ID",
|
||||
accountID: "",
|
||||
expectedCount: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
peers, err := store.GetAccountPeersWithInactivity(context.Background(), LockingStrengthShare, tt.accountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, peers, tt.expectedCount)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_GetAllEphemeralPeers(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/storev1.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
peers, err := store.GetAllEphemeralPeers(context.Background(), LockingStrengthShare)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, peers, 1)
|
||||
require.True(t, peers[0].Ephemeral)
|
||||
}
|
||||
|
||||
func TestSqlStore_DeletePeer(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_with_expired_peers.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
peerID := "csrnkiq7qv9d8aitqd50"
|
||||
|
||||
err = store.DeletePeer(context.Background(), LockingStrengthUpdate, accountID, peerID)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer, err := store.GetPeerByID(context.Background(), LockingStrengthShare, accountID, peerID)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, peer)
|
||||
}
|
||||
|
31
util/net/conn.go
Normal file
31
util/net/conn.go
Normal file
@@ -0,0 +1,31 @@
|
||||
//go:build !ios
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Conn wraps a net.Conn to override the Close method
|
||||
type Conn struct {
|
||||
net.Conn
|
||||
ID ConnectionID
|
||||
}
|
||||
|
||||
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection
|
||||
func (c *Conn) Close() error {
|
||||
err := c.Conn.Close()
|
||||
|
||||
dialerCloseHooksMutex.RLock()
|
||||
defer dialerCloseHooksMutex.RUnlock()
|
||||
|
||||
for _, hook := range dialerCloseHooks {
|
||||
if err := hook(c.ID, &c.Conn); err != nil {
|
||||
log.Errorf("Error executing dialer close hook: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
58
util/net/dial.go
Normal file
58
util/net/dial.go
Normal file
@@ -0,0 +1,58 @@
|
||||
//go:build !ios
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
|
||||
if CustomRoutingDisabled() {
|
||||
return net.DialUDP(network, laddr, raddr)
|
||||
}
|
||||
|
||||
dialer := NewDialer()
|
||||
dialer.LocalAddr = laddr
|
||||
|
||||
conn, err := dialer.Dial(network, raddr.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err)
|
||||
}
|
||||
|
||||
udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn)
|
||||
if !ok {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf("Failed to close connection: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn)
|
||||
}
|
||||
|
||||
return udpConn, nil
|
||||
}
|
||||
|
||||
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
|
||||
if CustomRoutingDisabled() {
|
||||
return net.DialTCP(network, laddr, raddr)
|
||||
}
|
||||
|
||||
dialer := NewDialer()
|
||||
dialer.LocalAddr = laddr
|
||||
|
||||
conn, err := dialer.Dial(network, raddr.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err)
|
||||
}
|
||||
|
||||
tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn)
|
||||
if !ok {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf("Failed to close connection: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn)
|
||||
}
|
||||
|
||||
return tcpConn, nil
|
||||
}
|
@@ -1,25 +0,0 @@
|
||||
package net
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func (d *Dialer) init() {
|
||||
d.Dialer.Control = func(_, _ string, c syscall.RawConn) error {
|
||||
err := c.Control(func(fd uintptr) {
|
||||
androidProtectSocketLock.Lock()
|
||||
f := androidProtectSocket
|
||||
androidProtectSocketLock.Unlock()
|
||||
if f == nil {
|
||||
return
|
||||
}
|
||||
ok := f(int32(fd))
|
||||
if !ok {
|
||||
log.Errorf("failed to protect socket: %d", fd)
|
||||
}
|
||||
})
|
||||
return err
|
||||
}
|
||||
}
|
@@ -81,28 +81,6 @@ func (d *Dialer) Dial(network, address string) (net.Conn, error) {
|
||||
return d.DialContext(context.Background(), network, address)
|
||||
}
|
||||
|
||||
// Conn wraps a net.Conn to override the Close method
|
||||
type Conn struct {
|
||||
net.Conn
|
||||
ID ConnectionID
|
||||
}
|
||||
|
||||
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection
|
||||
func (c *Conn) Close() error {
|
||||
err := c.Conn.Close()
|
||||
|
||||
dialerCloseHooksMutex.RLock()
|
||||
defer dialerCloseHooksMutex.RUnlock()
|
||||
|
||||
for _, hook := range dialerCloseHooks {
|
||||
if err := hook(c.ID, &c.Conn); err != nil {
|
||||
log.Errorf("Error executing dialer close hook: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func callDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error {
|
||||
host, _, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
@@ -127,51 +105,3 @@ func callDialerHooks(ctx context.Context, connID ConnectionID, address string, r
|
||||
|
||||
return result.ErrorOrNil()
|
||||
}
|
||||
|
||||
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
|
||||
if CustomRoutingDisabled() {
|
||||
return net.DialUDP(network, laddr, raddr)
|
||||
}
|
||||
|
||||
dialer := NewDialer()
|
||||
dialer.LocalAddr = laddr
|
||||
|
||||
conn, err := dialer.Dial(network, raddr.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err)
|
||||
}
|
||||
|
||||
udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn)
|
||||
if !ok {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf("Failed to close connection: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn)
|
||||
}
|
||||
|
||||
return udpConn, nil
|
||||
}
|
||||
|
||||
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
|
||||
if CustomRoutingDisabled() {
|
||||
return net.DialTCP(network, laddr, raddr)
|
||||
}
|
||||
|
||||
dialer := NewDialer()
|
||||
dialer.LocalAddr = laddr
|
||||
|
||||
conn, err := dialer.Dial(network, raddr.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err)
|
||||
}
|
||||
|
||||
tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn)
|
||||
if !ok {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf("Failed to close connection: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn)
|
||||
}
|
||||
|
||||
return tcpConn, nil
|
||||
}
|
5
util/net/dialer_init_android.go
Normal file
5
util/net/dialer_init_android.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package net
|
||||
|
||||
func (d *Dialer) init() {
|
||||
d.Dialer.Control = ControlProtectSocket
|
||||
}
|
@@ -7,6 +7,6 @@ import "syscall"
|
||||
// init configures the net.Dialer Control function to set the fwmark on the socket
|
||||
func (d *Dialer) init() {
|
||||
d.Dialer.Control = func(_, _ string, c syscall.RawConn) error {
|
||||
return SetRawSocketMark(c)
|
||||
return setRawSocketMark(c)
|
||||
}
|
||||
}
|
@@ -3,4 +3,5 @@
|
||||
package net
|
||||
|
||||
func (d *Dialer) init() {
|
||||
// implemented on Linux and Android only
|
||||
}
|
29
util/net/env.go
Normal file
29
util/net/env.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package net
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
)
|
||||
|
||||
const (
|
||||
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
|
||||
envSkipSocketMark = "NB_SKIP_SOCKET_MARK"
|
||||
)
|
||||
|
||||
func CustomRoutingDisabled() bool {
|
||||
if netstack.IsEnabled() {
|
||||
return true
|
||||
}
|
||||
return os.Getenv(envDisableCustomRouting) == "true"
|
||||
}
|
||||
|
||||
func SkipSocketMark() bool {
|
||||
if skipSocketMark := os.Getenv(envSkipSocketMark); skipSocketMark == "true" {
|
||||
log.Infof("%s is set to true, skipping SO_MARK", envSkipSocketMark)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
37
util/net/listen.go
Normal file
37
util/net/listen.go
Normal file
@@ -0,0 +1,37 @@
|
||||
//go:build !ios
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ListenUDP listens on the network address and returns a transport.UDPConn
|
||||
// which includes support for write and close hooks.
|
||||
func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) {
|
||||
if CustomRoutingDisabled() {
|
||||
return net.ListenUDP(network, laddr)
|
||||
}
|
||||
|
||||
conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listen UDP: %w", err)
|
||||
}
|
||||
|
||||
packetConn := conn.(*PacketConn)
|
||||
udpConn, ok := packetConn.PacketConn.(*net.UDPConn)
|
||||
if !ok {
|
||||
if err := packetConn.Close(); err != nil {
|
||||
log.Errorf("Failed to close connection: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn)
|
||||
}
|
||||
|
||||
return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil
|
||||
}
|
@@ -1,26 +0,0 @@
|
||||
package net
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// init configures the net.ListenerConfig Control function to set the fwmark on the socket
|
||||
func (l *ListenerConfig) init() {
|
||||
l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error {
|
||||
err := c.Control(func(fd uintptr) {
|
||||
androidProtectSocketLock.Lock()
|
||||
f := androidProtectSocket
|
||||
androidProtectSocketLock.Unlock()
|
||||
if f == nil {
|
||||
return
|
||||
}
|
||||
ok := f(int32(fd))
|
||||
if !ok {
|
||||
log.Errorf("failed to protect listener socket: %d", fd)
|
||||
}
|
||||
})
|
||||
return err
|
||||
}
|
||||
}
|
6
util/net/listener_init_android.go
Normal file
6
util/net/listener_init_android.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package net
|
||||
|
||||
// init configures the net.ListenerConfig Control function to set the fwmark on the socket
|
||||
func (l *ListenerConfig) init() {
|
||||
l.ListenConfig.Control = ControlProtectSocket
|
||||
}
|
@@ -9,6 +9,6 @@ import (
|
||||
// init configures the net.ListenerConfig Control function to set the fwmark on the socket
|
||||
func (l *ListenerConfig) init() {
|
||||
l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error {
|
||||
return SetRawSocketMark(c)
|
||||
return setRawSocketMark(c)
|
||||
}
|
||||
}
|
@@ -3,4 +3,5 @@
|
||||
package net
|
||||
|
||||
func (l *ListenerConfig) init() {
|
||||
// implemented on Linux and Android only
|
||||
}
|
@@ -8,7 +8,6 @@ import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@@ -146,27 +145,3 @@ func closeConn(id ConnectionID, conn net.PacketConn) error {
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// ListenUDP listens on the network address and returns a transport.UDPConn
|
||||
// which includes support for write and close hooks.
|
||||
func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) {
|
||||
if CustomRoutingDisabled() {
|
||||
return net.ListenUDP(network, laddr)
|
||||
}
|
||||
|
||||
conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listen UDP: %w", err)
|
||||
}
|
||||
|
||||
packetConn := conn.(*PacketConn)
|
||||
udpConn, ok := packetConn.PacketConn.(*net.UDPConn)
|
||||
if !ok {
|
||||
if err := packetConn.Close(); err != nil {
|
||||
log.Errorf("Failed to close connection: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn)
|
||||
}
|
||||
|
||||
return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil
|
||||
}
|
@@ -2,9 +2,6 @@ package net
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -16,8 +13,6 @@ const (
|
||||
PreroutingFwmarkRedirected = 0x1BD01
|
||||
PreroutingFwmarkMasquerade = 0x1BD11
|
||||
PreroutingFwmarkMasqueradeReturn = 0x1BD12
|
||||
|
||||
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
|
||||
)
|
||||
|
||||
// ConnectionID provides a globally unique identifier for network connections.
|
||||
@@ -31,10 +26,3 @@ type RemoveHookFunc func(connID ConnectionID) error
|
||||
func GenerateConnID() ConnectionID {
|
||||
return ConnectionID(uuid.NewString())
|
||||
}
|
||||
|
||||
func CustomRoutingDisabled() bool {
|
||||
if netstack.IsEnabled() {
|
||||
return true
|
||||
}
|
||||
return os.Getenv(envDisableCustomRouting) == "true"
|
||||
}
|
||||
|
@@ -4,29 +4,42 @@ package net
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const EnvSkipSocketMark = "NB_SKIP_SOCKET_MARK"
|
||||
|
||||
// SetSocketMark sets the SO_MARK option on the given socket connection
|
||||
func SetSocketMark(conn syscall.Conn) error {
|
||||
if isSocketMarkDisabled() {
|
||||
return nil
|
||||
}
|
||||
|
||||
sysconn, err := conn.SyscallConn()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get raw conn: %w", err)
|
||||
}
|
||||
|
||||
return SetRawSocketMark(sysconn)
|
||||
return setRawSocketMark(sysconn)
|
||||
}
|
||||
|
||||
func SetRawSocketMark(conn syscall.RawConn) error {
|
||||
// SetSocketOpt sets the SO_MARK option on the given file descriptor
|
||||
func SetSocketOpt(fd int) error {
|
||||
if isSocketMarkDisabled() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return setSocketOptInt(fd)
|
||||
}
|
||||
|
||||
func setRawSocketMark(conn syscall.RawConn) error {
|
||||
var setErr error
|
||||
|
||||
err := conn.Control(func(fd uintptr) {
|
||||
setErr = SetSocketOpt(int(fd))
|
||||
if isSocketMarkDisabled() {
|
||||
return
|
||||
}
|
||||
setErr = setSocketOptInt(int(fd))
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("control: %w", err)
|
||||
@@ -39,17 +52,18 @@ func SetRawSocketMark(conn syscall.RawConn) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func SetSocketOpt(fd int) error {
|
||||
if CustomRoutingDisabled() {
|
||||
log.Infof("Custom routing is disabled, skipping SO_MARK")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check for the new environment variable
|
||||
if skipSocketMark := os.Getenv(EnvSkipSocketMark); skipSocketMark == "true" {
|
||||
log.Info("NB_SKIP_SOCKET_MARK is set to true, skipping SO_MARK")
|
||||
return nil
|
||||
}
|
||||
|
||||
func setSocketOptInt(fd int) error {
|
||||
return syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark)
|
||||
}
|
||||
|
||||
func isSocketMarkDisabled() bool {
|
||||
if CustomRoutingDisabled() {
|
||||
log.Infof("Custom routing is disabled, skipping SO_MARK")
|
||||
return true
|
||||
}
|
||||
|
||||
if SkipSocketMark() {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
@@ -1,14 +1,42 @@
|
||||
package net
|
||||
|
||||
import "sync"
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
var (
|
||||
androidProtectSocketLock sync.Mutex
|
||||
androidProtectSocket func(fd int32) bool
|
||||
)
|
||||
|
||||
func SetAndroidProtectSocketFn(f func(fd int32) bool) {
|
||||
func SetAndroidProtectSocketFn(fn func(fd int32) bool) {
|
||||
androidProtectSocketLock.Lock()
|
||||
androidProtectSocket = f
|
||||
androidProtectSocket = fn
|
||||
androidProtectSocketLock.Unlock()
|
||||
}
|
||||
|
||||
// ControlProtectSocket is a Control function that sets the fwmark on the socket
|
||||
func ControlProtectSocket(_, _ string, c syscall.RawConn) error {
|
||||
var aErr error
|
||||
err := c.Control(func(fd uintptr) {
|
||||
androidProtectSocketLock.Lock()
|
||||
defer androidProtectSocketLock.Unlock()
|
||||
|
||||
if androidProtectSocket == nil {
|
||||
aErr = fmt.Errorf("socket protection function not set")
|
||||
return
|
||||
}
|
||||
|
||||
if !androidProtectSocket(int32(fd)) {
|
||||
aErr = fmt.Errorf("failed to protect socket via Android")
|
||||
}
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return aErr
|
||||
}
|
||||
|
Reference in New Issue
Block a user