mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-21 11:41:02 +02:00
Merge branch 'main' into peers-get-account-refactoring
# Conflicts: # management/server/sql_store_test.go
This commit is contained in:
12
client/iface/bind/control_android.go
Normal file
12
client/iface/bind/control_android.go
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
wireguard "golang.zx2c4.com/wireguard/conn"
|
||||||
|
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// ControlFns is not thread safe and should only be modified during init.
|
||||||
|
*wireguard.ControlFns = append(*wireguard.ControlFns, nbnet.ControlProtectSocket)
|
||||||
|
}
|
@@ -55,7 +55,7 @@ type ruleParams struct {
|
|||||||
|
|
||||||
// isLegacy determines whether to use the legacy routing setup
|
// isLegacy determines whether to use the legacy routing setup
|
||||||
func isLegacy() bool {
|
func isLegacy() bool {
|
||||||
return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || os.Getenv(nbnet.EnvSkipSocketMark) == "true"
|
return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || nbnet.SkipSocketMark()
|
||||||
}
|
}
|
||||||
|
|
||||||
// setIsLegacy sets the legacy routing setup
|
// setIsLegacy sets the legacy routing setup
|
||||||
|
2
go.mod
2
go.mod
@@ -236,7 +236,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
|
|||||||
|
|
||||||
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
|
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
|
||||||
|
|
||||||
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73
|
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9
|
||||||
|
|
||||||
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
|
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
|
||||||
|
|
||||||
|
4
go.sum
4
go.sum
@@ -527,8 +527,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax
|
|||||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
||||||
github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73 h1:jayg97LH/jJlvpIHVxueTfa+tfQ+FY8fy2sIhCwkz0g=
|
github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9 h1:Pu/7EukijT09ynHUOzQYW7cC3M/BKU8O4qyN/TvTGoY=
|
||||||
github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
||||||
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
|
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
|
||||||
github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4=
|
github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4=
|
||||||
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
|
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
|
||||||
|
@@ -2045,139 +2045,3 @@ func TestSqlStore_DeleteNameServerGroup(t *testing.T) {
|
|||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Nil(t, nsGroup)
|
require.Nil(t, nsGroup)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSqlStore_GetAccountPeers(t *testing.T) {
|
|
||||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_with_expired_peers.sql", t.TempDir())
|
|
||||||
t.Cleanup(cleanup)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
accountID string
|
|
||||||
expectedCount int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "retrieve peers by existing account ID",
|
|
||||||
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
|
||||||
expectedCount: 4,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "non-existing account ID",
|
|
||||||
accountID: "nonexistent",
|
|
||||||
expectedCount: 0,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty account ID",
|
|
||||||
accountID: "",
|
|
||||||
expectedCount: 0,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
peers, err := store.GetAccountPeers(context.Background(), LockingStrengthShare, tt.accountID)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Len(t, peers, tt.expectedCount)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSqlStore_GetAccountPeersWithExpiration(t *testing.T) {
|
|
||||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_with_expired_peers.sql", t.TempDir())
|
|
||||||
t.Cleanup(cleanup)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
accountID string
|
|
||||||
expectedCount int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "retrieve peers with expiration by existing account ID",
|
|
||||||
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
|
||||||
expectedCount: 1,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "non-existing account ID",
|
|
||||||
accountID: "nonexistent",
|
|
||||||
expectedCount: 0,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty account ID",
|
|
||||||
accountID: "",
|
|
||||||
expectedCount: 0,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
peers, err := store.GetAccountPeersWithExpiration(context.Background(), LockingStrengthShare, tt.accountID)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Len(t, peers, tt.expectedCount)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSqlStore_GetAccountPeersWithInactivity(t *testing.T) {
|
|
||||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_with_expired_peers.sql", t.TempDir())
|
|
||||||
t.Cleanup(cleanup)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
accountID string
|
|
||||||
expectedCount int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "retrieve peers with inactivity by existing account ID",
|
|
||||||
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
|
||||||
expectedCount: 1,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "non-existing account ID",
|
|
||||||
accountID: "nonexistent",
|
|
||||||
expectedCount: 0,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty account ID",
|
|
||||||
accountID: "",
|
|
||||||
expectedCount: 0,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
peers, err := store.GetAccountPeersWithInactivity(context.Background(), LockingStrengthShare, tt.accountID)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Len(t, peers, tt.expectedCount)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSqlStore_GetAllEphemeralPeers(t *testing.T) {
|
|
||||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/storev1.sql", t.TempDir())
|
|
||||||
t.Cleanup(cleanup)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
peers, err := store.GetAllEphemeralPeers(context.Background(), LockingStrengthShare)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Len(t, peers, 1)
|
|
||||||
require.True(t, peers[0].Ephemeral)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSqlStore_DeletePeer(t *testing.T) {
|
|
||||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_with_expired_peers.sql", t.TempDir())
|
|
||||||
t.Cleanup(cleanup)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
|
||||||
peerID := "csrnkiq7qv9d8aitqd50"
|
|
||||||
|
|
||||||
err = store.DeletePeer(context.Background(), LockingStrengthUpdate, accountID, peerID)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
peer, err := store.GetPeerByID(context.Background(), LockingStrengthShare, accountID, peerID)
|
|
||||||
require.Error(t, err)
|
|
||||||
require.Nil(t, peer)
|
|
||||||
}
|
|
||||||
|
31
util/net/conn.go
Normal file
31
util/net/conn.go
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
//go:build !ios
|
||||||
|
|
||||||
|
package net
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Conn wraps a net.Conn to override the Close method
|
||||||
|
type Conn struct {
|
||||||
|
net.Conn
|
||||||
|
ID ConnectionID
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection
|
||||||
|
func (c *Conn) Close() error {
|
||||||
|
err := c.Conn.Close()
|
||||||
|
|
||||||
|
dialerCloseHooksMutex.RLock()
|
||||||
|
defer dialerCloseHooksMutex.RUnlock()
|
||||||
|
|
||||||
|
for _, hook := range dialerCloseHooks {
|
||||||
|
if err := hook(c.ID, &c.Conn); err != nil {
|
||||||
|
log.Errorf("Error executing dialer close hook: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
58
util/net/dial.go
Normal file
58
util/net/dial.go
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
//go:build !ios
|
||||||
|
|
||||||
|
package net
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
|
||||||
|
if CustomRoutingDisabled() {
|
||||||
|
return net.DialUDP(network, laddr, raddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
dialer := NewDialer()
|
||||||
|
dialer.LocalAddr = laddr
|
||||||
|
|
||||||
|
conn, err := dialer.Dial(network, raddr.String())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn)
|
||||||
|
if !ok {
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
log.Errorf("Failed to close connection: %v", err)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
return udpConn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
|
||||||
|
if CustomRoutingDisabled() {
|
||||||
|
return net.DialTCP(network, laddr, raddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
dialer := NewDialer()
|
||||||
|
dialer.LocalAddr = laddr
|
||||||
|
|
||||||
|
conn, err := dialer.Dial(network, raddr.String())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn)
|
||||||
|
if !ok {
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
log.Errorf("Failed to close connection: %v", err)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tcpConn, nil
|
||||||
|
}
|
@@ -1,25 +0,0 @@
|
|||||||
package net
|
|
||||||
|
|
||||||
import (
|
|
||||||
"syscall"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (d *Dialer) init() {
|
|
||||||
d.Dialer.Control = func(_, _ string, c syscall.RawConn) error {
|
|
||||||
err := c.Control(func(fd uintptr) {
|
|
||||||
androidProtectSocketLock.Lock()
|
|
||||||
f := androidProtectSocket
|
|
||||||
androidProtectSocketLock.Unlock()
|
|
||||||
if f == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ok := f(int32(fd))
|
|
||||||
if !ok {
|
|
||||||
log.Errorf("failed to protect socket: %d", fd)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
@@ -81,28 +81,6 @@ func (d *Dialer) Dial(network, address string) (net.Conn, error) {
|
|||||||
return d.DialContext(context.Background(), network, address)
|
return d.DialContext(context.Background(), network, address)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Conn wraps a net.Conn to override the Close method
|
|
||||||
type Conn struct {
|
|
||||||
net.Conn
|
|
||||||
ID ConnectionID
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection
|
|
||||||
func (c *Conn) Close() error {
|
|
||||||
err := c.Conn.Close()
|
|
||||||
|
|
||||||
dialerCloseHooksMutex.RLock()
|
|
||||||
defer dialerCloseHooksMutex.RUnlock()
|
|
||||||
|
|
||||||
for _, hook := range dialerCloseHooks {
|
|
||||||
if err := hook(c.ID, &c.Conn); err != nil {
|
|
||||||
log.Errorf("Error executing dialer close hook: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func callDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error {
|
func callDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error {
|
||||||
host, _, err := net.SplitHostPort(address)
|
host, _, err := net.SplitHostPort(address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -127,51 +105,3 @@ func callDialerHooks(ctx context.Context, connID ConnectionID, address string, r
|
|||||||
|
|
||||||
return result.ErrorOrNil()
|
return result.ErrorOrNil()
|
||||||
}
|
}
|
||||||
|
|
||||||
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
|
|
||||||
if CustomRoutingDisabled() {
|
|
||||||
return net.DialUDP(network, laddr, raddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
dialer := NewDialer()
|
|
||||||
dialer.LocalAddr = laddr
|
|
||||||
|
|
||||||
conn, err := dialer.Dial(network, raddr.String())
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn)
|
|
||||||
if !ok {
|
|
||||||
if err := conn.Close(); err != nil {
|
|
||||||
log.Errorf("Failed to close connection: %v", err)
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
return udpConn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
|
|
||||||
if CustomRoutingDisabled() {
|
|
||||||
return net.DialTCP(network, laddr, raddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
dialer := NewDialer()
|
|
||||||
dialer.LocalAddr = laddr
|
|
||||||
|
|
||||||
conn, err := dialer.Dial(network, raddr.String())
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn)
|
|
||||||
if !ok {
|
|
||||||
if err := conn.Close(); err != nil {
|
|
||||||
log.Errorf("Failed to close connection: %v", err)
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
return tcpConn, nil
|
|
||||||
}
|
|
5
util/net/dialer_init_android.go
Normal file
5
util/net/dialer_init_android.go
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
package net
|
||||||
|
|
||||||
|
func (d *Dialer) init() {
|
||||||
|
d.Dialer.Control = ControlProtectSocket
|
||||||
|
}
|
@@ -7,6 +7,6 @@ import "syscall"
|
|||||||
// init configures the net.Dialer Control function to set the fwmark on the socket
|
// init configures the net.Dialer Control function to set the fwmark on the socket
|
||||||
func (d *Dialer) init() {
|
func (d *Dialer) init() {
|
||||||
d.Dialer.Control = func(_, _ string, c syscall.RawConn) error {
|
d.Dialer.Control = func(_, _ string, c syscall.RawConn) error {
|
||||||
return SetRawSocketMark(c)
|
return setRawSocketMark(c)
|
||||||
}
|
}
|
||||||
}
|
}
|
@@ -3,4 +3,5 @@
|
|||||||
package net
|
package net
|
||||||
|
|
||||||
func (d *Dialer) init() {
|
func (d *Dialer) init() {
|
||||||
|
// implemented on Linux and Android only
|
||||||
}
|
}
|
29
util/net/env.go
Normal file
29
util/net/env.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
package net
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
|
||||||
|
envSkipSocketMark = "NB_SKIP_SOCKET_MARK"
|
||||||
|
)
|
||||||
|
|
||||||
|
func CustomRoutingDisabled() bool {
|
||||||
|
if netstack.IsEnabled() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return os.Getenv(envDisableCustomRouting) == "true"
|
||||||
|
}
|
||||||
|
|
||||||
|
func SkipSocketMark() bool {
|
||||||
|
if skipSocketMark := os.Getenv(envSkipSocketMark); skipSocketMark == "true" {
|
||||||
|
log.Infof("%s is set to true, skipping SO_MARK", envSkipSocketMark)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
37
util/net/listen.go
Normal file
37
util/net/listen.go
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
//go:build !ios
|
||||||
|
|
||||||
|
package net
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/pion/transport/v3"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ListenUDP listens on the network address and returns a transport.UDPConn
|
||||||
|
// which includes support for write and close hooks.
|
||||||
|
func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) {
|
||||||
|
if CustomRoutingDisabled() {
|
||||||
|
return net.ListenUDP(network, laddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("listen UDP: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
packetConn := conn.(*PacketConn)
|
||||||
|
udpConn, ok := packetConn.PacketConn.(*net.UDPConn)
|
||||||
|
if !ok {
|
||||||
|
if err := packetConn.Close(); err != nil {
|
||||||
|
log.Errorf("Failed to close connection: %v", err)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil
|
||||||
|
}
|
@@ -1,26 +0,0 @@
|
|||||||
package net
|
|
||||||
|
|
||||||
import (
|
|
||||||
"syscall"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
// init configures the net.ListenerConfig Control function to set the fwmark on the socket
|
|
||||||
func (l *ListenerConfig) init() {
|
|
||||||
l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error {
|
|
||||||
err := c.Control(func(fd uintptr) {
|
|
||||||
androidProtectSocketLock.Lock()
|
|
||||||
f := androidProtectSocket
|
|
||||||
androidProtectSocketLock.Unlock()
|
|
||||||
if f == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ok := f(int32(fd))
|
|
||||||
if !ok {
|
|
||||||
log.Errorf("failed to protect listener socket: %d", fd)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
6
util/net/listener_init_android.go
Normal file
6
util/net/listener_init_android.go
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
package net
|
||||||
|
|
||||||
|
// init configures the net.ListenerConfig Control function to set the fwmark on the socket
|
||||||
|
func (l *ListenerConfig) init() {
|
||||||
|
l.ListenConfig.Control = ControlProtectSocket
|
||||||
|
}
|
@@ -9,6 +9,6 @@ import (
|
|||||||
// init configures the net.ListenerConfig Control function to set the fwmark on the socket
|
// init configures the net.ListenerConfig Control function to set the fwmark on the socket
|
||||||
func (l *ListenerConfig) init() {
|
func (l *ListenerConfig) init() {
|
||||||
l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error {
|
l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error {
|
||||||
return SetRawSocketMark(c)
|
return setRawSocketMark(c)
|
||||||
}
|
}
|
||||||
}
|
}
|
@@ -3,4 +3,5 @@
|
|||||||
package net
|
package net
|
||||||
|
|
||||||
func (l *ListenerConfig) init() {
|
func (l *ListenerConfig) init() {
|
||||||
|
// implemented on Linux and Android only
|
||||||
}
|
}
|
@@ -8,7 +8,6 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/pion/transport/v3"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -146,27 +145,3 @@ func closeConn(id ConnectionID, conn net.PacketConn) error {
|
|||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListenUDP listens on the network address and returns a transport.UDPConn
|
|
||||||
// which includes support for write and close hooks.
|
|
||||||
func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) {
|
|
||||||
if CustomRoutingDisabled() {
|
|
||||||
return net.ListenUDP(network, laddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String())
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("listen UDP: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
packetConn := conn.(*PacketConn)
|
|
||||||
udpConn, ok := packetConn.PacketConn.(*net.UDPConn)
|
|
||||||
if !ok {
|
|
||||||
if err := packetConn.Close(); err != nil {
|
|
||||||
log.Errorf("Failed to close connection: %v", err)
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil
|
|
||||||
}
|
|
@@ -2,9 +2,6 @@ package net
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
"os"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
@@ -16,8 +13,6 @@ const (
|
|||||||
PreroutingFwmarkRedirected = 0x1BD01
|
PreroutingFwmarkRedirected = 0x1BD01
|
||||||
PreroutingFwmarkMasquerade = 0x1BD11
|
PreroutingFwmarkMasquerade = 0x1BD11
|
||||||
PreroutingFwmarkMasqueradeReturn = 0x1BD12
|
PreroutingFwmarkMasqueradeReturn = 0x1BD12
|
||||||
|
|
||||||
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConnectionID provides a globally unique identifier for network connections.
|
// ConnectionID provides a globally unique identifier for network connections.
|
||||||
@@ -31,10 +26,3 @@ type RemoveHookFunc func(connID ConnectionID) error
|
|||||||
func GenerateConnID() ConnectionID {
|
func GenerateConnID() ConnectionID {
|
||||||
return ConnectionID(uuid.NewString())
|
return ConnectionID(uuid.NewString())
|
||||||
}
|
}
|
||||||
|
|
||||||
func CustomRoutingDisabled() bool {
|
|
||||||
if netstack.IsEnabled() {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return os.Getenv(envDisableCustomRouting) == "true"
|
|
||||||
}
|
|
||||||
|
@@ -4,29 +4,42 @@ package net
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
const EnvSkipSocketMark = "NB_SKIP_SOCKET_MARK"
|
|
||||||
|
|
||||||
// SetSocketMark sets the SO_MARK option on the given socket connection
|
// SetSocketMark sets the SO_MARK option on the given socket connection
|
||||||
func SetSocketMark(conn syscall.Conn) error {
|
func SetSocketMark(conn syscall.Conn) error {
|
||||||
|
if isSocketMarkDisabled() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
sysconn, err := conn.SyscallConn()
|
sysconn, err := conn.SyscallConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("get raw conn: %w", err)
|
return fmt.Errorf("get raw conn: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return SetRawSocketMark(sysconn)
|
return setRawSocketMark(sysconn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func SetRawSocketMark(conn syscall.RawConn) error {
|
// SetSocketOpt sets the SO_MARK option on the given file descriptor
|
||||||
|
func SetSocketOpt(fd int) error {
|
||||||
|
if isSocketMarkDisabled() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return setSocketOptInt(fd)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setRawSocketMark(conn syscall.RawConn) error {
|
||||||
var setErr error
|
var setErr error
|
||||||
|
|
||||||
err := conn.Control(func(fd uintptr) {
|
err := conn.Control(func(fd uintptr) {
|
||||||
setErr = SetSocketOpt(int(fd))
|
if isSocketMarkDisabled() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
setErr = setSocketOptInt(int(fd))
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("control: %w", err)
|
return fmt.Errorf("control: %w", err)
|
||||||
@@ -39,17 +52,18 @@ func SetRawSocketMark(conn syscall.RawConn) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func SetSocketOpt(fd int) error {
|
func setSocketOptInt(fd int) error {
|
||||||
if CustomRoutingDisabled() {
|
|
||||||
log.Infof("Custom routing is disabled, skipping SO_MARK")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for the new environment variable
|
|
||||||
if skipSocketMark := os.Getenv(EnvSkipSocketMark); skipSocketMark == "true" {
|
|
||||||
log.Info("NB_SKIP_SOCKET_MARK is set to true, skipping SO_MARK")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark)
|
return syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isSocketMarkDisabled() bool {
|
||||||
|
if CustomRoutingDisabled() {
|
||||||
|
log.Infof("Custom routing is disabled, skipping SO_MARK")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if SkipSocketMark() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
@@ -1,14 +1,42 @@
|
|||||||
package net
|
package net
|
||||||
|
|
||||||
import "sync"
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"syscall"
|
||||||
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
androidProtectSocketLock sync.Mutex
|
androidProtectSocketLock sync.Mutex
|
||||||
androidProtectSocket func(fd int32) bool
|
androidProtectSocket func(fd int32) bool
|
||||||
)
|
)
|
||||||
|
|
||||||
func SetAndroidProtectSocketFn(f func(fd int32) bool) {
|
func SetAndroidProtectSocketFn(fn func(fd int32) bool) {
|
||||||
androidProtectSocketLock.Lock()
|
androidProtectSocketLock.Lock()
|
||||||
androidProtectSocket = f
|
androidProtectSocket = fn
|
||||||
androidProtectSocketLock.Unlock()
|
androidProtectSocketLock.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ControlProtectSocket is a Control function that sets the fwmark on the socket
|
||||||
|
func ControlProtectSocket(_, _ string, c syscall.RawConn) error {
|
||||||
|
var aErr error
|
||||||
|
err := c.Control(func(fd uintptr) {
|
||||||
|
androidProtectSocketLock.Lock()
|
||||||
|
defer androidProtectSocketLock.Unlock()
|
||||||
|
|
||||||
|
if androidProtectSocket == nil {
|
||||||
|
aErr = fmt.Errorf("socket protection function not set")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !androidProtectSocket(int32(fd)) {
|
||||||
|
aErr = fmt.Errorf("failed to protect socket via Android")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return aErr
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user