mirror of
https://github.com/netbirdio/netbird.git
synced 2025-01-19 12:29:27 +01:00
Merge branch 'policy-get-account-refactoring' into dns-get-account-refactoring
This commit is contained in:
commit
79822cdc15
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@ -9,7 +9,7 @@ on:
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.0.16"
|
||||
SIGN_PIPE_VER: "v0.0.17"
|
||||
GORELEASER_VER: "v2.3.2"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
|
||||
|
@ -17,8 +17,12 @@
|
||||
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
|
||||
</a>
|
||||
<br>
|
||||
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ">
|
||||
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">
|
||||
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
||||
</a>
|
||||
<br>
|
||||
<a href="https://gurubase.io/g/netbird">
|
||||
<img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF"/>
|
||||
</a>
|
||||
</p>
|
||||
</div>
|
||||
@ -30,7 +34,7 @@
|
||||
<br/>
|
||||
See <a href="https://netbird.io/docs/">Documentation</a>
|
||||
<br/>
|
||||
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ">Slack channel</a>
|
||||
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">Slack channel</a>
|
||||
<br/>
|
||||
|
||||
</strong>
|
||||
|
@ -83,9 +83,11 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
}
|
||||
|
||||
// persist early to ensure cleanup of chains
|
||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
go func() {
|
||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -99,9 +99,11 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
}
|
||||
|
||||
// persist early
|
||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
go func() {
|
||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -197,7 +199,7 @@ func (m *Manager) AllowNetbird() error {
|
||||
|
||||
var chain *nftables.Chain
|
||||
for _, c := range chains {
|
||||
if c.Table.Name == tableNameFilter && c.Name == chainNameForward {
|
||||
if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
|
||||
chain = c
|
||||
break
|
||||
}
|
||||
@ -274,7 +276,7 @@ func (m *Manager) resetNetbirdInputRules() error {
|
||||
|
||||
func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) {
|
||||
for _, c := range chains {
|
||||
if c.Table.Name == "filter" && c.Name == "INPUT" {
|
||||
if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
|
||||
rules, err := m.rConn.GetRules(c.Table, c)
|
||||
if err != nil {
|
||||
log.Errorf("get rules for chain %q: %v", c.Name, err)
|
||||
@ -349,7 +351,9 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
&expr.Verdict{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
UserData: []byte(allowNetbirdInputRuleID),
|
||||
}
|
||||
|
@ -1,9 +1,11 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -225,3 +227,105 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func runIptablesSave(t *testing.T) (string, string) {
|
||||
t.Helper()
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd := exec.Command("iptables-save")
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
err := cmd.Run()
|
||||
require.NoError(t, err, "iptables-save failed to run")
|
||||
|
||||
return stdout.String(), stderr.String()
|
||||
}
|
||||
|
||||
func verifyIptablesOutput(t *testing.T, stdout, stderr string) {
|
||||
t.Helper()
|
||||
// Check for any incompatibility warnings
|
||||
require.NotContains(t,
|
||||
stderr,
|
||||
"incompatible",
|
||||
"iptables-save produced compatibility warning. Full stderr: %s",
|
||||
stderr,
|
||||
)
|
||||
|
||||
// Verify standard tables are present
|
||||
expectedTables := []string{
|
||||
"*filter",
|
||||
"*nat",
|
||||
"*mangle",
|
||||
}
|
||||
|
||||
for _, table := range expectedTables {
|
||||
require.Contains(t,
|
||||
stdout,
|
||||
table,
|
||||
"iptables-save output missing expected table: %s\nFull stdout: %s",
|
||||
table,
|
||||
stdout,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
}
|
||||
|
||||
if _, err := exec.LookPath("iptables-save"); err != nil {
|
||||
t.Skipf("iptables-save not available on this system: %v", err)
|
||||
}
|
||||
|
||||
// First ensure iptables-nft tables exist by running iptables-save
|
||||
stdout, stderr := runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
|
||||
manager, err := Create(ifaceMock)
|
||||
require.NoError(t, err, "failed to create manager")
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := manager.Reset(nil)
|
||||
require.NoError(t, err, "failed to reset manager state")
|
||||
|
||||
// Verify iptables output after reset
|
||||
stdout, stderr := runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
})
|
||||
|
||||
ip := net.ParseIP("100.96.0.1")
|
||||
_, err = manager.AddPeerFiltering(
|
||||
ip,
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []int{80}},
|
||||
fw.RuleDirectionIN,
|
||||
fw.ActionAccept,
|
||||
"",
|
||||
"test rule",
|
||||
)
|
||||
require.NoError(t, err, "failed to add peer filtering rule")
|
||||
|
||||
_, err = manager.AddRouteFiltering(
|
||||
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
|
||||
netip.MustParsePrefix("10.1.0.0/24"),
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
&fw.Port{Values: []int{443}},
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err, "failed to add route filtering rule")
|
||||
|
||||
pair := fw.RouterPair{
|
||||
Source: netip.MustParsePrefix("192.168.1.0/24"),
|
||||
Destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||
Masquerade: true,
|
||||
}
|
||||
err = manager.AddNatRule(pair)
|
||||
require.NoError(t, err, "failed to add NAT rule")
|
||||
|
||||
stdout, stderr = runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
}
|
||||
|
@ -239,7 +239,7 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
// SetLegacyManagement doesn't need to be implemented for this manager
|
||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errRouteNotSupported
|
||||
return nil
|
||||
}
|
||||
return m.nativeFirewall.SetLegacyManagement(isLegacy)
|
||||
}
|
||||
|
@ -164,7 +164,7 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = util.WriteJsonWithRestrictedPermission(input.ConfigPath, cfg)
|
||||
err = util.WriteJsonWithRestrictedPermission(context.Background(), input.ConfigPath, cfg)
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
@ -185,7 +185,7 @@ func CreateInMemoryConfig(input ConfigInput) (*Config, error) {
|
||||
|
||||
// WriteOutConfig write put the prepared config to the given path
|
||||
func WriteOutConfig(path string, config *Config) error {
|
||||
return util.WriteJson(path, config)
|
||||
return util.WriteJson(context.Background(), path, config)
|
||||
}
|
||||
|
||||
// createNewConfig creates a new config generating a new Wireguard key and saving to file
|
||||
@ -215,7 +215,7 @@ func update(input ConfigInput) (*Config, error) {
|
||||
}
|
||||
|
||||
if updated {
|
||||
if err := util.WriteJson(input.ConfigPath, config); err != nil {
|
||||
if err := util.WriteJson(context.Background(), input.ConfigPath, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
@ -157,7 +157,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
||||
|
||||
engineCtx, cancel := context.WithCancel(c.ctx)
|
||||
defer func() {
|
||||
c.statusRecorder.MarkManagementDisconnected(state.err)
|
||||
_, err := state.Status()
|
||||
c.statusRecorder.MarkManagementDisconnected(err)
|
||||
c.statusRecorder.CleanLocalPeerState()
|
||||
cancel()
|
||||
}()
|
||||
@ -231,6 +232,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
||||
|
||||
relayURLs, token := parseRelayInfo(loginResp)
|
||||
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String())
|
||||
c.statusRecorder.SetRelayMgr(relayManager)
|
||||
if len(relayURLs) > 0 {
|
||||
if token != nil {
|
||||
if err := relayManager.UpdateToken(token); err != nil {
|
||||
@ -241,9 +243,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
|
||||
log.Infof("connecting to the Relay service(s): %s", strings.Join(relayURLs, ", "))
|
||||
if err = relayManager.Serve(); err != nil {
|
||||
log.Error(err)
|
||||
return wrapErr(err)
|
||||
}
|
||||
c.statusRecorder.SetRelayMgr(relayManager)
|
||||
}
|
||||
|
||||
peerConfig := loginResp.GetPeerConfig()
|
||||
|
@ -7,7 +7,6 @@ import (
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/mitchellh/hashstructure/v2"
|
||||
@ -323,12 +322,12 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
log.Error(err)
|
||||
}
|
||||
|
||||
// persist dns state right away
|
||||
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
if err := s.stateManager.PersistState(ctx); err != nil {
|
||||
log.Errorf("Failed to persist dns state: %v", err)
|
||||
}
|
||||
go func() {
|
||||
// persist dns state right away
|
||||
if err := s.stateManager.PersistState(s.ctx); err != nil {
|
||||
log.Errorf("Failed to persist dns state: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if s.searchDomainNotifier != nil {
|
||||
s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains())
|
||||
@ -533,12 +532,11 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
|
||||
}
|
||||
|
||||
// persist dns state right away
|
||||
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
if err := s.stateManager.PersistState(ctx); err != nil {
|
||||
l.Errorf("Failed to persist dns state: %v", err)
|
||||
}
|
||||
go func() {
|
||||
if err := s.stateManager.PersistState(s.ctx); err != nil {
|
||||
l.Errorf("Failed to persist dns state: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 {
|
||||
s.addHostRootZone()
|
||||
|
@ -782,7 +782,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
|
||||
Port: 53,
|
||||
},
|
||||
},
|
||||
Domains: []string{"customdomain.com"},
|
||||
Domains: []string{"google.com"},
|
||||
Primary: false,
|
||||
},
|
||||
},
|
||||
@ -804,7 +804,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
|
||||
if ips[0] != zoneRecords[0].RData {
|
||||
t.Fatalf("invalid zone record: %v", err)
|
||||
}
|
||||
_, err = resolver.LookupHost(context.Background(), "customdomain.com")
|
||||
_, err = resolver.LookupHost(context.Background(), "google.com")
|
||||
if err != nil {
|
||||
t.Errorf("failed to resolve: %s", err)
|
||||
}
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
"reflect"
|
||||
"runtime"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@ -38,7 +39,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
|
||||
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
@ -171,7 +171,7 @@ type Engine struct {
|
||||
|
||||
relayManager *relayClient.Manager
|
||||
stateManager *statemanager.Manager
|
||||
srWatcher *guard.SRWatcher
|
||||
srWatcher *guard.SRWatcher
|
||||
}
|
||||
|
||||
// Peer is an instance of the Connection Peer
|
||||
@ -297,7 +297,7 @@ func (e *Engine) Stop() error {
|
||||
if err := e.stateManager.Stop(ctx); err != nil {
|
||||
return fmt.Errorf("failed to stop state manager: %w", err)
|
||||
}
|
||||
if err := e.stateManager.PersistState(ctx); err != nil {
|
||||
if err := e.stateManager.PersistState(context.Background()); err != nil {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
|
||||
@ -538,6 +538,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
|
||||
relayMsg := wCfg.GetRelay()
|
||||
if relayMsg != nil {
|
||||
// when we receive token we expect valid address list too
|
||||
c := &auth.Token{
|
||||
Payload: relayMsg.GetTokenPayload(),
|
||||
Signature: relayMsg.GetTokenSignature(),
|
||||
@ -546,9 +547,16 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
log.Errorf("failed to update relay token: %v", err)
|
||||
return fmt.Errorf("update relay token: %w", err)
|
||||
}
|
||||
|
||||
e.relayManager.UpdateServerURLs(relayMsg.Urls)
|
||||
|
||||
// Just in case the agent started with an MGM server where the relay was disabled but was later enabled.
|
||||
// We can ignore all errors because the guard will manage the reconnection retries.
|
||||
_ = e.relayManager.Serve()
|
||||
} else {
|
||||
e.relayManager.UpdateServerURLs(nil)
|
||||
}
|
||||
|
||||
// todo update relay address in the relay manager
|
||||
// todo update signal
|
||||
}
|
||||
|
||||
@ -641,6 +649,10 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
||||
}
|
||||
|
||||
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
if e.wgInterface == nil {
|
||||
return errors.New("wireguard interface is not initialized")
|
||||
}
|
||||
|
||||
if e.wgInterface.Address().String() != conf.Address {
|
||||
oldAddr := e.wgInterface.Address().String()
|
||||
log.Debugf("updating peer address from %s to %s", oldAddr, conf.Address)
|
||||
@ -1481,6 +1493,17 @@ func (e *Engine) stopDNSServer() {
|
||||
|
||||
// isChecksEqual checks if two slices of checks are equal.
|
||||
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
|
||||
for _, check := range checks {
|
||||
sort.Slice(check.Files, func(i, j int) bool {
|
||||
return check.Files[i] < check.Files[j]
|
||||
})
|
||||
}
|
||||
for _, oCheck := range oChecks {
|
||||
sort.Slice(oCheck.Files, func(i, j int) bool {
|
||||
return oCheck.Files[i] < oCheck.Files[j]
|
||||
})
|
||||
}
|
||||
|
||||
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
|
||||
return slices.Equal(checks.Files, oChecks.Files)
|
||||
})
|
||||
|
@ -1006,6 +1006,99 @@ func Test_ParseNATExternalIPMappings(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CheckFilesEqual(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
inputChecks1 []*mgmtProto.Checks
|
||||
inputChecks2 []*mgmtProto.Checks
|
||||
expectedBool bool
|
||||
}{
|
||||
{
|
||||
name: "Equal Files In Equal Order Should Return True",
|
||||
inputChecks1: []*mgmtProto.Checks{
|
||||
{
|
||||
Files: []string{
|
||||
"testfile1",
|
||||
"testfile2",
|
||||
},
|
||||
},
|
||||
},
|
||||
inputChecks2: []*mgmtProto.Checks{
|
||||
{
|
||||
Files: []string{
|
||||
"testfile1",
|
||||
"testfile2",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedBool: true,
|
||||
},
|
||||
{
|
||||
name: "Equal Files In Reverse Order Should Return True",
|
||||
inputChecks1: []*mgmtProto.Checks{
|
||||
{
|
||||
Files: []string{
|
||||
"testfile1",
|
||||
"testfile2",
|
||||
},
|
||||
},
|
||||
},
|
||||
inputChecks2: []*mgmtProto.Checks{
|
||||
{
|
||||
Files: []string{
|
||||
"testfile2",
|
||||
"testfile1",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedBool: true,
|
||||
},
|
||||
{
|
||||
name: "Unequal Files Should Return False",
|
||||
inputChecks1: []*mgmtProto.Checks{
|
||||
{
|
||||
Files: []string{
|
||||
"testfile1",
|
||||
"testfile2",
|
||||
},
|
||||
},
|
||||
},
|
||||
inputChecks2: []*mgmtProto.Checks{
|
||||
{
|
||||
Files: []string{
|
||||
"testfile1",
|
||||
"testfile3",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedBool: false,
|
||||
},
|
||||
{
|
||||
name: "Compared With Empty Should Return False",
|
||||
inputChecks1: []*mgmtProto.Checks{
|
||||
{
|
||||
Files: []string{
|
||||
"testfile1",
|
||||
"testfile2",
|
||||
},
|
||||
},
|
||||
},
|
||||
inputChecks2: []*mgmtProto.Checks{
|
||||
{
|
||||
Files: []string{},
|
||||
},
|
||||
},
|
||||
expectedBool: false,
|
||||
},
|
||||
}
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
result := isChecksEqual(testCase.inputChecks1, testCase.inputChecks2)
|
||||
assert.Equal(t, testCase.expectedBool, result, "result should match expected bool")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
|
@ -676,25 +676,23 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
|
||||
// extend the list of stun, turn servers with relay address
|
||||
relayStates := slices.Clone(d.relayStates)
|
||||
|
||||
var relayState relay.ProbeResult
|
||||
|
||||
// if the server connection is not established then we will use the general address
|
||||
// in case of connection we will use the instance specific address
|
||||
instanceAddr, err := d.relayMgr.RelayInstanceAddress()
|
||||
if err != nil {
|
||||
// TODO add their status
|
||||
if errors.Is(err, relayClient.ErrRelayClientNotConnected) {
|
||||
for _, r := range d.relayMgr.ServerURLs() {
|
||||
relayStates = append(relayStates, relay.ProbeResult{
|
||||
URI: r,
|
||||
})
|
||||
}
|
||||
return relayStates
|
||||
for _, r := range d.relayMgr.ServerURLs() {
|
||||
relayStates = append(relayStates, relay.ProbeResult{
|
||||
URI: r,
|
||||
Err: err,
|
||||
})
|
||||
}
|
||||
relayState.Err = err
|
||||
return relayStates
|
||||
}
|
||||
|
||||
relayState.URI = instanceAddr
|
||||
relayState := relay.ProbeResult{
|
||||
URI: instanceAddr,
|
||||
}
|
||||
return append(relayStates, relayState)
|
||||
}
|
||||
|
||||
|
@ -46,8 +46,6 @@ type WorkerICE struct {
|
||||
hasRelayOnLocally bool
|
||||
conn WorkerICECallbacks
|
||||
|
||||
selectedPriority ConnPriority
|
||||
|
||||
agent *ice.Agent
|
||||
muxAgent sync.Mutex
|
||||
|
||||
@ -95,10 +93,8 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
|
||||
var preferredCandidateTypes []ice.CandidateType
|
||||
if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" {
|
||||
w.selectedPriority = connPriorityICEP2P
|
||||
preferredCandidateTypes = icemaker.CandidateTypesP2P()
|
||||
} else {
|
||||
w.selectedPriority = connPriorityICETurn
|
||||
preferredCandidateTypes = icemaker.CandidateTypes()
|
||||
}
|
||||
|
||||
@ -159,7 +155,7 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
RelayedOnLocal: isRelayCandidate(pair.Local),
|
||||
}
|
||||
w.log.Debugf("on ICE conn read to use ready")
|
||||
go w.conn.OnConnReady(w.selectedPriority, ci)
|
||||
go w.conn.OnConnReady(selectedPriority(pair), ci)
|
||||
}
|
||||
|
||||
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
|
||||
@ -394,3 +390,11 @@ func isRelayed(pair *ice.CandidatePair) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func selectedPriority(pair *ice.CandidatePair) ConnPriority {
|
||||
if isRelayed(pair) {
|
||||
return connPriorityICETurn
|
||||
} else {
|
||||
return connPriorityICEP2P
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
@ -227,6 +228,64 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
||||
currentRoute: "route1",
|
||||
expectedRouteID: "route1",
|
||||
},
|
||||
{
|
||||
name: "relayed routes with latency 0 should maintain previous choice",
|
||||
statuses: map[route.ID]routerPeerStatus{
|
||||
"route1": {
|
||||
connected: true,
|
||||
relayed: true,
|
||||
latency: 0 * time.Millisecond,
|
||||
},
|
||||
"route2": {
|
||||
connected: true,
|
||||
relayed: true,
|
||||
latency: 0 * time.Millisecond,
|
||||
},
|
||||
},
|
||||
existingRoutes: map[route.ID]*route.Route{
|
||||
"route1": {
|
||||
ID: "route1",
|
||||
Metric: route.MaxMetric,
|
||||
Peer: "peer1",
|
||||
},
|
||||
"route2": {
|
||||
ID: "route2",
|
||||
Metric: route.MaxMetric,
|
||||
Peer: "peer2",
|
||||
},
|
||||
},
|
||||
currentRoute: "route1",
|
||||
expectedRouteID: "route1",
|
||||
},
|
||||
{
|
||||
name: "p2p routes with latency 0 should maintain previous choice",
|
||||
statuses: map[route.ID]routerPeerStatus{
|
||||
"route1": {
|
||||
connected: true,
|
||||
relayed: false,
|
||||
latency: 0 * time.Millisecond,
|
||||
},
|
||||
"route2": {
|
||||
connected: true,
|
||||
relayed: false,
|
||||
latency: 0 * time.Millisecond,
|
||||
},
|
||||
},
|
||||
existingRoutes: map[route.ID]*route.Route{
|
||||
"route1": {
|
||||
ID: "route1",
|
||||
Metric: route.MaxMetric,
|
||||
Peer: "peer1",
|
||||
},
|
||||
"route2": {
|
||||
ID: "route2",
|
||||
Metric: route.MaxMetric,
|
||||
Peer: "peer2",
|
||||
},
|
||||
},
|
||||
currentRoute: "route1",
|
||||
expectedRouteID: "route1",
|
||||
},
|
||||
{
|
||||
name: "current route with bad score should be changed to route with better score",
|
||||
statuses: map[route.ID]routerPeerStatus{
|
||||
@ -287,6 +346,45 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
// fill the test data with random routes
|
||||
for _, tc := range testCases {
|
||||
for i := 0; i < 50; i++ {
|
||||
dummyRoute := &route.Route{
|
||||
ID: route.ID(fmt.Sprintf("dummy_p1_%d", i)),
|
||||
Metric: route.MinMetric,
|
||||
Peer: fmt.Sprintf("dummy_p1_%d", i),
|
||||
}
|
||||
tc.existingRoutes[dummyRoute.ID] = dummyRoute
|
||||
}
|
||||
for i := 0; i < 50; i++ {
|
||||
dummyRoute := &route.Route{
|
||||
ID: route.ID(fmt.Sprintf("dummy_p2_%d", i)),
|
||||
Metric: route.MinMetric,
|
||||
Peer: fmt.Sprintf("dummy_p1_%d", i),
|
||||
}
|
||||
tc.existingRoutes[dummyRoute.ID] = dummyRoute
|
||||
}
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
id := route.ID(fmt.Sprintf("dummy_p1_%d", i))
|
||||
dummyStatus := routerPeerStatus{
|
||||
connected: false,
|
||||
relayed: true,
|
||||
latency: 0,
|
||||
}
|
||||
tc.statuses[id] = dummyStatus
|
||||
}
|
||||
for i := 0; i < 50; i++ {
|
||||
id := route.ID(fmt.Sprintf("dummy_p2_%d", i))
|
||||
dummyStatus := routerPeerStatus{
|
||||
connected: false,
|
||||
relayed: true,
|
||||
latency: 0,
|
||||
}
|
||||
tc.statuses[id] = dummyStatus
|
||||
}
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
currentRoute := &route.Route{
|
||||
|
@ -47,10 +47,9 @@ type RemoveFunc[Key, O any] func(key Key, out O) error
|
||||
type Counter[Key comparable, I, O any] struct {
|
||||
// refCountMap keeps track of the reference Ref for keys
|
||||
refCountMap map[Key]Ref[O]
|
||||
refCountMu sync.Mutex
|
||||
mu sync.Mutex
|
||||
// idMap keeps track of the keys associated with an ID for removal
|
||||
idMap map[string][]Key
|
||||
idMu sync.Mutex
|
||||
add AddFunc[Key, I, O]
|
||||
remove RemoveFunc[Key, O]
|
||||
}
|
||||
@ -75,10 +74,8 @@ func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key
|
||||
func (rm *Counter[Key, I, O]) LoadData(
|
||||
existingCounter *Counter[Key, I, O],
|
||||
) {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
|
||||
rm.refCountMap = existingCounter.refCountMap
|
||||
rm.idMap = existingCounter.idMap
|
||||
@ -87,8 +84,8 @@ func (rm *Counter[Key, I, O]) LoadData(
|
||||
// Get retrieves the current reference count and associated data for a key.
|
||||
// If the key doesn't exist, it returns a zero value Ref and false.
|
||||
func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
|
||||
ref, ok := rm.refCountMap[key]
|
||||
return ref, ok
|
||||
@ -97,9 +94,13 @@ func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
|
||||
// Increment increments the reference count for the given key.
|
||||
// If this is the first reference to the key, the AddFunc is called.
|
||||
func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
|
||||
return rm.increment(key, in)
|
||||
}
|
||||
|
||||
func (rm *Counter[Key, I, O]) increment(key Key, in I) (Ref[O], error) {
|
||||
ref := rm.refCountMap[key]
|
||||
logCallerF("Increasing ref count [%d -> %d] for key %v with In [%v] Out [%v]", ref.Count, ref.Count+1, key, in, ref.Out)
|
||||
|
||||
@ -126,10 +127,10 @@ func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) {
|
||||
// IncrementWithID increments the reference count for the given key and groups it under the given ID.
|
||||
// If this is the first reference to the key, the AddFunc is called.
|
||||
func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O], error) {
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
|
||||
ref, err := rm.Increment(key, in)
|
||||
ref, err := rm.increment(key, in)
|
||||
if err != nil {
|
||||
return ref, fmt.Errorf("with ID: %w", err)
|
||||
}
|
||||
@ -141,9 +142,12 @@ func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O],
|
||||
// Decrement decrements the reference count for the given key.
|
||||
// If the reference count reaches 0, the RemoveFunc is called.
|
||||
func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
return rm.decrement(key)
|
||||
}
|
||||
|
||||
func (rm *Counter[Key, I, O]) decrement(key Key) (Ref[O], error) {
|
||||
ref, ok := rm.refCountMap[key]
|
||||
if !ok {
|
||||
logCallerF("No reference found for key %v", key)
|
||||
@ -168,12 +172,12 @@ func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) {
|
||||
// DecrementWithID decrements the reference count for all keys associated with the given ID.
|
||||
// If the reference count reaches 0, the RemoveFunc is called.
|
||||
func (rm *Counter[Key, I, O]) DecrementWithID(id string) error {
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
|
||||
var merr *multierror.Error
|
||||
for _, key := range rm.idMap[id] {
|
||||
if _, err := rm.Decrement(key); err != nil {
|
||||
if _, err := rm.decrement(key); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
}
|
||||
@ -184,10 +188,8 @@ func (rm *Counter[Key, I, O]) DecrementWithID(id string) error {
|
||||
|
||||
// Flush removes all references and calls RemoveFunc for each key.
|
||||
func (rm *Counter[Key, I, O]) Flush() error {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
|
||||
var merr *multierror.Error
|
||||
for key := range rm.refCountMap {
|
||||
@ -206,10 +208,8 @@ func (rm *Counter[Key, I, O]) Flush() error {
|
||||
|
||||
// Clear removes all references without calling RemoveFunc.
|
||||
func (rm *Counter[Key, I, O]) Clear() {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
|
||||
clear(rm.refCountMap)
|
||||
clear(rm.idMap)
|
||||
@ -217,10 +217,8 @@ func (rm *Counter[Key, I, O]) Clear() {
|
||||
|
||||
// MarshalJSON implements the json.Marshaler interface for Counter.
|
||||
func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) {
|
||||
rm.refCountMu.Lock()
|
||||
defer rm.refCountMu.Unlock()
|
||||
rm.idMu.Lock()
|
||||
defer rm.idMu.Unlock()
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
|
||||
return json.Marshal(struct {
|
||||
RefCountMap map[Key]Ref[O] `json:"refCountMap"`
|
||||
|
@ -2,31 +2,28 @@ package systemops
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
)
|
||||
|
||||
type ShutdownState struct {
|
||||
Counter *ExclusionCounter `json:"counter,omitempty"`
|
||||
mu sync.RWMutex
|
||||
}
|
||||
type ShutdownState ExclusionCounter
|
||||
|
||||
func (s *ShutdownState) Name() string {
|
||||
return "route_state"
|
||||
}
|
||||
|
||||
func (s *ShutdownState) Cleanup() error {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if s.Counter == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sysops := NewSysOps(nil, nil)
|
||||
sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable)
|
||||
sysops.refCounter.LoadData(s.Counter)
|
||||
sysops.refCounter.LoadData((*ExclusionCounter)(s))
|
||||
|
||||
return sysops.refCounter.Flush()
|
||||
}
|
||||
|
||||
func (s *ShutdownState) MarshalJSON() ([]byte, error) {
|
||||
return (*ExclusionCounter)(s).MarshalJSON()
|
||||
}
|
||||
|
||||
func (s *ShutdownState) UnmarshalJSON(data []byte) error {
|
||||
return (*ExclusionCounter)(s).UnmarshalJSON(data)
|
||||
}
|
||||
|
@ -57,30 +57,19 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
|
||||
return nexthop, refcounter.ErrIgnore
|
||||
}
|
||||
|
||||
r.updateState(stateManager)
|
||||
|
||||
return nexthop, err
|
||||
},
|
||||
func(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
// remove from state even if we have trouble removing it from the route table
|
||||
// it could be already gone
|
||||
r.updateState(stateManager)
|
||||
|
||||
return r.removeFromRouteTable(prefix, nexthop)
|
||||
},
|
||||
r.removeFromRouteTable,
|
||||
)
|
||||
|
||||
r.refCounter = refCounter
|
||||
|
||||
return r.setupHooks(initAddresses)
|
||||
return r.setupHooks(initAddresses, stateManager)
|
||||
}
|
||||
|
||||
// updateState updates state on every change so it will be persisted regularly
|
||||
func (r *SysOps) updateState(stateManager *statemanager.Manager) {
|
||||
state := getState(stateManager)
|
||||
|
||||
state.Counter = r.refCounter
|
||||
|
||||
if err := stateManager.UpdateState(state); err != nil {
|
||||
if err := stateManager.UpdateState((*ShutdownState)(r.refCounter)); err != nil {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
}
|
||||
}
|
||||
@ -336,7 +325,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
|
||||
return r.removeFromRouteTable(prefix, nextHop)
|
||||
}
|
||||
|
||||
func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
|
||||
prefix, err := util.GetPrefixFromIP(ip)
|
||||
if err != nil {
|
||||
@ -347,6 +336,8 @@ func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.Re
|
||||
return fmt.Errorf("adding route reference: %v", err)
|
||||
}
|
||||
|
||||
r.updateState(stateManager)
|
||||
|
||||
return nil
|
||||
}
|
||||
afterHook := func(connID nbnet.ConnectionID) error {
|
||||
@ -354,6 +345,8 @@ func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.Re
|
||||
return fmt.Errorf("remove route reference: %w", err)
|
||||
}
|
||||
|
||||
r.updateState(stateManager)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -532,14 +525,3 @@ func isVpnRoute(addr netip.Addr, vpnRoutes []netip.Prefix, localRoutes []netip.P
|
||||
// Return true if the longest matching prefix is from vpnRoutes
|
||||
return isVpn, longestPrefix
|
||||
}
|
||||
|
||||
func getState(stateManager *statemanager.Manager) *ShutdownState {
|
||||
var shutdownState *ShutdownState
|
||||
if state := stateManager.GetState(shutdownState); state != nil {
|
||||
shutdownState = state.(*ShutdownState)
|
||||
} else {
|
||||
shutdownState = &ShutdownState{}
|
||||
}
|
||||
|
||||
return shutdownState
|
||||
}
|
||||
|
@ -55,7 +55,7 @@ type ruleParams struct {
|
||||
|
||||
// isLegacy determines whether to use the legacy routing setup
|
||||
func isLegacy() bool {
|
||||
return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled()
|
||||
return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || os.Getenv(nbnet.EnvSkipSocketMark) == "true"
|
||||
}
|
||||
|
||||
// setIsLegacy sets the legacy routing setup
|
||||
@ -92,17 +92,6 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
|
||||
return r.setupRefCounter(initAddresses, stateManager)
|
||||
}
|
||||
|
||||
if err = addRoutingTableName(); err != nil {
|
||||
log.Errorf("Error adding routing table name: %v", err)
|
||||
}
|
||||
|
||||
originalValues, err := sysctl.Setup(r.wgInterface)
|
||||
if err != nil {
|
||||
log.Errorf("Error setting up sysctl: %v", err)
|
||||
sysctlFailed = true
|
||||
}
|
||||
originalSysctl = originalValues
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil {
|
||||
@ -123,6 +112,17 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
|
||||
}
|
||||
}
|
||||
|
||||
if err = addRoutingTableName(); err != nil {
|
||||
log.Errorf("Error adding routing table name: %v", err)
|
||||
}
|
||||
|
||||
originalValues, err := sysctl.Setup(r.wgInterface)
|
||||
if err != nil {
|
||||
log.Errorf("Error setting up sysctl: %v", err)
|
||||
sysctlFailed = true
|
||||
}
|
||||
originalSysctl = originalValues
|
||||
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
@ -450,7 +450,7 @@ func addRule(params ruleParams) error {
|
||||
rule.Invert = params.invert
|
||||
rule.SuppressPrefixlen = params.suppressPrefix
|
||||
|
||||
if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||
if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) {
|
||||
return fmt.Errorf("add routing rule: %w", err)
|
||||
}
|
||||
|
||||
@ -467,7 +467,7 @@ func removeRule(params ruleParams) error {
|
||||
rule.Priority = params.priority
|
||||
rule.SuppressPrefixlen = params.suppressPrefix
|
||||
|
||||
if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||
if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) {
|
||||
return fmt.Errorf("remove routing rule: %w", err)
|
||||
}
|
||||
|
||||
|
@ -16,6 +16,7 @@ import (
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
// State interface defines the methods that all state types must implement
|
||||
@ -73,15 +74,15 @@ func (m *Manager) Stop(ctx context.Context) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.cancel != nil {
|
||||
m.cancel()
|
||||
if m.cancel == nil {
|
||||
return nil
|
||||
}
|
||||
m.cancel()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-m.done:
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-m.done:
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -178,25 +179,18 @@ func (m *Manager) PersistState(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
bs, err := marshalWithPanicRecovery(m.states)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal states: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan error, 1)
|
||||
|
||||
start := time.Now()
|
||||
go func() {
|
||||
data, err := json.MarshalIndent(m.states, "", " ")
|
||||
if err != nil {
|
||||
done <- fmt.Errorf("marshal states: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
// nolint:gosec
|
||||
if err := os.WriteFile(m.filePath, data, 0640); err != nil {
|
||||
done <- fmt.Errorf("write state file: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
done <- nil
|
||||
done <- util.WriteBytesWithRestrictedPermission(ctx, m.filePath, bs)
|
||||
}()
|
||||
|
||||
select {
|
||||
@ -208,7 +202,7 @@ func (m *Manager) PersistState(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("persisted shutdown states: %v", maps.Keys(m.dirty))
|
||||
log.Debugf("persisted shutdown states: %v, took %v", maps.Keys(m.dirty), time.Since(start))
|
||||
|
||||
clear(m.dirty)
|
||||
|
||||
@ -296,3 +290,19 @@ func (m *Manager) PerformCleanup() error {
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func marshalWithPanicRecovery(v any) ([]byte, error) {
|
||||
var bs []byte
|
||||
var err error
|
||||
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic during marshal: %v", r)
|
||||
}
|
||||
}()
|
||||
bs, err = json.Marshal(v)
|
||||
}()
|
||||
|
||||
return bs, err
|
||||
}
|
||||
|
@ -4,32 +4,20 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// GetDefaultStatePath returns the path to the state file based on the operating system
|
||||
// It returns an empty string if the path cannot be determined. It also creates the directory if it does not exist.
|
||||
// It returns an empty string if the path cannot be determined.
|
||||
func GetDefaultStatePath() string {
|
||||
var path string
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
path = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json")
|
||||
return filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json")
|
||||
case "darwin", "linux":
|
||||
path = "/var/lib/netbird/state.json"
|
||||
return "/var/lib/netbird/state.json"
|
||||
case "freebsd", "openbsd", "netbsd", "dragonfly":
|
||||
path = "/var/db/netbird/state.json"
|
||||
// ios/android don't need state
|
||||
default:
|
||||
return ""
|
||||
return "/var/db/netbird/state.json"
|
||||
}
|
||||
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
log.Errorf("Error creating directory %s: %v. Continuing without state support.", dir, err)
|
||||
return ""
|
||||
}
|
||||
return ""
|
||||
|
||||
return path
|
||||
}
|
||||
|
@ -965,7 +965,9 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(user *User, groups []*nbgro
|
||||
}
|
||||
|
||||
// UserGroupsAddToPeers adds groups to all peers of user
|
||||
func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) {
|
||||
func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) map[string][]string {
|
||||
groupUpdates := make(map[string][]string)
|
||||
|
||||
userPeers := make(map[string]struct{})
|
||||
for pid, peer := range a.Peers {
|
||||
if peer.UserID == userID {
|
||||
@ -979,6 +981,8 @@ func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) {
|
||||
continue
|
||||
}
|
||||
|
||||
oldPeers := group.Peers
|
||||
|
||||
groupPeers := make(map[string]struct{})
|
||||
for _, pid := range group.Peers {
|
||||
groupPeers[pid] = struct{}{}
|
||||
@ -992,16 +996,25 @@ func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) {
|
||||
for pid := range groupPeers {
|
||||
group.Peers = append(group.Peers, pid)
|
||||
}
|
||||
|
||||
groupUpdates[gid] = difference(group.Peers, oldPeers)
|
||||
}
|
||||
|
||||
return groupUpdates
|
||||
}
|
||||
|
||||
// UserGroupsRemoveFromPeers removes groups from all peers of user
|
||||
func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) {
|
||||
func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map[string][]string {
|
||||
groupUpdates := make(map[string][]string)
|
||||
|
||||
for _, gid := range groups {
|
||||
group, ok := a.Groups[gid]
|
||||
if !ok || group.Name == "All" {
|
||||
continue
|
||||
}
|
||||
|
||||
oldPeers := group.Peers
|
||||
|
||||
update := make([]string, 0, len(group.Peers))
|
||||
for _, pid := range group.Peers {
|
||||
peer, ok := a.Peers[pid]
|
||||
@ -1013,7 +1026,10 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) {
|
||||
}
|
||||
}
|
||||
group.Peers = update
|
||||
groupUpdates[gid] = difference(oldPeers, group.Peers)
|
||||
}
|
||||
|
||||
return groupUpdates
|
||||
}
|
||||
|
||||
// BuildManager creates a new DefaultAccountManager with a provided Store
|
||||
@ -1175,6 +1191,11 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("groups propagation failed: %w", err)
|
||||
}
|
||||
|
||||
updatedAccount := account.UpdateSettings(newSettings)
|
||||
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
@ -1185,21 +1206,39 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
return updatedAccount, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error {
|
||||
if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled {
|
||||
event := activity.AccountPeerInactivityExpirationEnabled
|
||||
if !newSettings.PeerInactivityExpirationEnabled {
|
||||
event = activity.AccountPeerInactivityExpirationDisabled
|
||||
am.peerInactivityExpiry.Cancel(ctx, []string{accountID})
|
||||
func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Context, oldSettings, newSettings *Settings, userID, accountID string) error {
|
||||
if oldSettings.GroupsPropagationEnabled != newSettings.GroupsPropagationEnabled {
|
||||
if newSettings.GroupsPropagationEnabled {
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationEnabled, nil)
|
||||
// Todo: retroactively add user groups to all peers
|
||||
} else {
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationDisabled, nil)
|
||||
}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
|
||||
}
|
||||
|
||||
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil)
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error {
|
||||
|
||||
if newSettings.PeerInactivityExpirationEnabled {
|
||||
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
|
||||
oldSettings.PeerInactivityExpiration = newSettings.PeerInactivityExpiration
|
||||
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil)
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
||||
}
|
||||
} else {
|
||||
if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled {
|
||||
event := activity.AccountPeerInactivityExpirationEnabled
|
||||
if !newSettings.PeerInactivityExpirationEnabled {
|
||||
event = activity.AccountPeerInactivityExpirationDisabled
|
||||
am.peerInactivityExpiry.Cancel(ctx, []string{accountID})
|
||||
} else {
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
||||
}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -2323,7 +2362,7 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account
|
||||
|
||||
err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, account)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
|
||||
log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -2339,6 +2378,9 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st
|
||||
unlock := am.Store.AcquireReadLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
unlockPeer := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
|
||||
defer unlockPeer()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -148,6 +148,9 @@ const (
|
||||
AccountPeerInactivityExpirationDurationUpdated Activity = 67
|
||||
|
||||
SetupKeyDeleted Activity = 68
|
||||
|
||||
UserGroupPropagationEnabled Activity = 69
|
||||
UserGroupPropagationDisabled Activity = 70
|
||||
)
|
||||
|
||||
var activityMap = map[Activity]Code{
|
||||
@ -222,6 +225,9 @@ var activityMap = map[Activity]Code{
|
||||
AccountPeerInactivityExpirationDisabled: {"Account peer inactivity expiration disabled", "account.peer.inactivity.expiration.disable"},
|
||||
AccountPeerInactivityExpirationDurationUpdated: {"Account peer inactivity expiration duration updated", "account.peer.inactivity.expiration.update"},
|
||||
SetupKeyDeleted: {"Setup key deleted", "setupkey.delete"},
|
||||
|
||||
UserGroupPropagationEnabled: {"User group propagation enabled", "account.setting.group.propagation.enable"},
|
||||
UserGroupPropagationDisabled: {"User group propagation disabled", "account.setting.group.propagation.disable"},
|
||||
}
|
||||
|
||||
// StringCode returns a string code of the activity
|
||||
|
@ -223,7 +223,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
|
||||
// It is recommended to call it with locking FileStore.mux
|
||||
func (s *FileStore) persist(ctx context.Context, file string) error {
|
||||
start := time.Now()
|
||||
err := util.WriteJson(file, s)
|
||||
err := util.WriteJson(context.Background(), file, s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -6,11 +6,12 @@ import (
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
@ -27,11 +28,6 @@ func (e *GroupLinkError) Error() string {
|
||||
|
||||
// CheckGroupPermissions validates if a user has the necessary permissions to view groups
|
||||
func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error {
|
||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -41,7 +37,7 @@ func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, acco
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsRegularUser() && settings.RegularUsersViewBlocked {
|
||||
if user.IsRegularUser() {
|
||||
return status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
@ -215,48 +211,9 @@ func difference(a, b []string) []string {
|
||||
|
||||
// DeleteGroup object of the peers.
|
||||
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
if user.IsRegularUser() {
|
||||
return status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
var group *nbgroup.Group
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
group, err = transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if group.IsGroupAll() {
|
||||
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
|
||||
}
|
||||
|
||||
if err = validateDeleteGroup(ctx, transaction, group, userID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.DeleteGroup(ctx, LockingStrengthUpdate, accountID, groupID)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, groupID, accountID, activity.GroupDeleted, group.EventMeta())
|
||||
|
||||
return nil
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
return am.DeleteGroups(ctx, accountID, userID, []string{groupID})
|
||||
}
|
||||
|
||||
// DeleteGroups deletes groups from an account.
|
||||
@ -285,13 +242,14 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
for _, groupID := range groupIDs {
|
||||
group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
|
||||
group, err := transaction.GetGroupByID(ctx, LockingStrengthUpdate, accountID, groupID)
|
||||
if err != nil {
|
||||
allErrors = errors.Join(allErrors, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := validateDeleteGroup(ctx, transaction, group, userID); err != nil {
|
||||
allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err))
|
||||
allErrors = errors.Join(allErrors, err)
|
||||
continue
|
||||
}
|
||||
|
||||
@ -318,12 +276,15 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
|
||||
|
||||
// GroupAddPeer appends peer to the group
|
||||
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
var group *nbgroup.Group
|
||||
var updateAccountPeers bool
|
||||
var err error
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
|
||||
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -356,12 +317,15 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
|
||||
|
||||
// GroupDeletePeer removes peer from the group
|
||||
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
var group *nbgroup.Group
|
||||
var updateAccountPeers bool
|
||||
var err error
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
|
||||
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -430,13 +394,17 @@ func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup.
|
||||
if group.Issued == nbgroup.GroupIssuedIntegration {
|
||||
executingUser, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.NotFound, "user not found")
|
||||
return err
|
||||
}
|
||||
if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser {
|
||||
return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group")
|
||||
}
|
||||
}
|
||||
|
||||
if group.IsGroupAll() {
|
||||
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
|
||||
}
|
||||
|
||||
if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked {
|
||||
return &GroupLinkError{"route", string(linkedRoute.NetID)}
|
||||
}
|
||||
|
@ -208,7 +208,7 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) {
|
||||
{
|
||||
name: "delete non-existent group",
|
||||
groupIDs: []string{"non-existent-group"},
|
||||
expectedDeleted: []string{"non-existent-group"},
|
||||
expectedReasons: []string{"group: non-existent-group not found"},
|
||||
},
|
||||
{
|
||||
name: "delete multiple groups with mixed results",
|
||||
|
@ -439,17 +439,13 @@ components:
|
||||
example: 5
|
||||
required:
|
||||
- accessible_peers_count
|
||||
SetupKey:
|
||||
SetupKeyBase:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
description: Setup Key ID
|
||||
type: string
|
||||
example: 2531583362
|
||||
key:
|
||||
description: Setup Key value
|
||||
type: string
|
||||
example: A616097E-FCF0-48FA-9354-CA4A61142761
|
||||
name:
|
||||
description: Setup key name identifier
|
||||
type: string
|
||||
@ -518,22 +514,31 @@ components:
|
||||
- updated_at
|
||||
- usage_limit
|
||||
- ephemeral
|
||||
SetupKeyClear:
|
||||
allOf:
|
||||
- $ref: '#/components/schemas/SetupKeyBase'
|
||||
- type: object
|
||||
properties:
|
||||
key:
|
||||
description: Setup Key as plain text
|
||||
type: string
|
||||
example: A616097E-FCF0-48FA-9354-CA4A61142761
|
||||
required:
|
||||
- key
|
||||
SetupKey:
|
||||
allOf:
|
||||
- $ref: '#/components/schemas/SetupKeyBase'
|
||||
- type: object
|
||||
properties:
|
||||
key:
|
||||
description: Setup Key as secret
|
||||
type: string
|
||||
example: A6160****
|
||||
required:
|
||||
- key
|
||||
SetupKeyRequest:
|
||||
type: object
|
||||
properties:
|
||||
name:
|
||||
description: Setup Key name
|
||||
type: string
|
||||
example: Default key
|
||||
type:
|
||||
description: Setup key type, one-off for single time usage and reusable
|
||||
type: string
|
||||
example: reusable
|
||||
expires_in:
|
||||
description: Expiration time in seconds, 0 will mean the key never expires
|
||||
type: integer
|
||||
minimum: 0
|
||||
example: 86400
|
||||
revoked:
|
||||
description: Setup key revocation status
|
||||
type: boolean
|
||||
@ -544,21 +549,9 @@ components:
|
||||
items:
|
||||
type: string
|
||||
example: "ch8i4ug6lnn4g9hqv7m0"
|
||||
usage_limit:
|
||||
description: A number of times this key can be used. The value of 0 indicates the unlimited usage.
|
||||
type: integer
|
||||
example: 0
|
||||
ephemeral:
|
||||
description: Indicate that the peer will be ephemeral or not
|
||||
type: boolean
|
||||
example: true
|
||||
required:
|
||||
- name
|
||||
- type
|
||||
- expires_in
|
||||
- revoked
|
||||
- auto_groups
|
||||
- usage_limit
|
||||
CreateSetupKeyRequest:
|
||||
type: object
|
||||
properties:
|
||||
@ -1943,7 +1936,7 @@ paths:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/SetupKey'
|
||||
$ref: '#/components/schemas/SetupKeyClear'
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
|
@ -1062,7 +1062,94 @@ type SetupKey struct {
|
||||
// Id Setup Key ID
|
||||
Id string `json:"id"`
|
||||
|
||||
// Key Setup Key value
|
||||
// Key Setup Key as secret
|
||||
Key string `json:"key"`
|
||||
|
||||
// LastUsed Setup key last usage date
|
||||
LastUsed time.Time `json:"last_used"`
|
||||
|
||||
// Name Setup key name identifier
|
||||
Name string `json:"name"`
|
||||
|
||||
// Revoked Setup key revocation status
|
||||
Revoked bool `json:"revoked"`
|
||||
|
||||
// State Setup key status, "valid", "overused","expired" or "revoked"
|
||||
State string `json:"state"`
|
||||
|
||||
// Type Setup key type, one-off for single time usage and reusable
|
||||
Type string `json:"type"`
|
||||
|
||||
// UpdatedAt Setup key last update date
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
// UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage.
|
||||
UsageLimit int `json:"usage_limit"`
|
||||
|
||||
// UsedTimes Usage count of setup key
|
||||
UsedTimes int `json:"used_times"`
|
||||
|
||||
// Valid Setup key validity status
|
||||
Valid bool `json:"valid"`
|
||||
}
|
||||
|
||||
// SetupKeyBase defines model for SetupKeyBase.
|
||||
type SetupKeyBase struct {
|
||||
// AutoGroups List of group IDs to auto-assign to peers registered with this key
|
||||
AutoGroups []string `json:"auto_groups"`
|
||||
|
||||
// Ephemeral Indicate that the peer will be ephemeral or not
|
||||
Ephemeral bool `json:"ephemeral"`
|
||||
|
||||
// Expires Setup Key expiration date
|
||||
Expires time.Time `json:"expires"`
|
||||
|
||||
// Id Setup Key ID
|
||||
Id string `json:"id"`
|
||||
|
||||
// LastUsed Setup key last usage date
|
||||
LastUsed time.Time `json:"last_used"`
|
||||
|
||||
// Name Setup key name identifier
|
||||
Name string `json:"name"`
|
||||
|
||||
// Revoked Setup key revocation status
|
||||
Revoked bool `json:"revoked"`
|
||||
|
||||
// State Setup key status, "valid", "overused","expired" or "revoked"
|
||||
State string `json:"state"`
|
||||
|
||||
// Type Setup key type, one-off for single time usage and reusable
|
||||
Type string `json:"type"`
|
||||
|
||||
// UpdatedAt Setup key last update date
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
// UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage.
|
||||
UsageLimit int `json:"usage_limit"`
|
||||
|
||||
// UsedTimes Usage count of setup key
|
||||
UsedTimes int `json:"used_times"`
|
||||
|
||||
// Valid Setup key validity status
|
||||
Valid bool `json:"valid"`
|
||||
}
|
||||
|
||||
// SetupKeyClear defines model for SetupKeyClear.
|
||||
type SetupKeyClear struct {
|
||||
// AutoGroups List of group IDs to auto-assign to peers registered with this key
|
||||
AutoGroups []string `json:"auto_groups"`
|
||||
|
||||
// Ephemeral Indicate that the peer will be ephemeral or not
|
||||
Ephemeral bool `json:"ephemeral"`
|
||||
|
||||
// Expires Setup Key expiration date
|
||||
Expires time.Time `json:"expires"`
|
||||
|
||||
// Id Setup Key ID
|
||||
Id string `json:"id"`
|
||||
|
||||
// Key Setup Key as plain text
|
||||
Key string `json:"key"`
|
||||
|
||||
// LastUsed Setup key last usage date
|
||||
@ -1098,23 +1185,8 @@ type SetupKeyRequest struct {
|
||||
// AutoGroups List of group IDs to auto-assign to peers registered with this key
|
||||
AutoGroups []string `json:"auto_groups"`
|
||||
|
||||
// Ephemeral Indicate that the peer will be ephemeral or not
|
||||
Ephemeral *bool `json:"ephemeral,omitempty"`
|
||||
|
||||
// ExpiresIn Expiration time in seconds, 0 will mean the key never expires
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
|
||||
// Name Setup Key name
|
||||
Name string `json:"name"`
|
||||
|
||||
// Revoked Setup key revocation status
|
||||
Revoked bool `json:"revoked"`
|
||||
|
||||
// Type Setup key type, one-off for single time usage and reusable
|
||||
Type string `json:"type"`
|
||||
|
||||
// UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage.
|
||||
UsageLimit int `json:"usage_limit"`
|
||||
}
|
||||
|
||||
// User defines model for User.
|
||||
|
@ -184,14 +184,26 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
dnsDomain := h.accountManager.GetDNSDomain()
|
||||
|
||||
respBody := make([]*api.PeerBatch, 0, len(account.Peers))
|
||||
for _, peer := range account.Peers {
|
||||
peers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
groupsMap := map[string]*nbgroup.Group{}
|
||||
groups, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
|
||||
for _, group := range groups {
|
||||
groupsMap[group.ID] = group
|
||||
}
|
||||
|
||||
respBody := make([]*api.PeerBatch, 0, len(peers))
|
||||
for _, peer := range peers {
|
||||
peerToReturn, err := h.checkPeerStatus(peer)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
|
||||
groupMinimumInfo := toGroupsInfo(groupsMap, peer.ID)
|
||||
|
||||
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0))
|
||||
}
|
||||
@ -304,7 +316,7 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee
|
||||
}
|
||||
|
||||
func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum {
|
||||
var groupsInfo []api.GroupMinimum
|
||||
groupsInfo := []api.GroupMinimum{}
|
||||
groupsChecked := make(map[string]struct{})
|
||||
for _, group := range groups {
|
||||
_, ok := groupsChecked[group.ID]
|
||||
|
@ -128,8 +128,13 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
|
||||
Description: req.Description,
|
||||
}
|
||||
for _, rule := range req.Rules {
|
||||
var ruleID string
|
||||
if rule.Id != nil {
|
||||
ruleID = *rule.Id
|
||||
}
|
||||
|
||||
pr := server.PolicyRule{
|
||||
ID: policyID, // TODO: when policy can contain multiple rules, need refactor
|
||||
ID: ruleID,
|
||||
PolicyID: policyID,
|
||||
Name: rule.Name,
|
||||
Destinations: rule.Destinations,
|
||||
|
@ -137,11 +137,6 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key name field is invalid: %s", req.Name), w)
|
||||
return
|
||||
}
|
||||
|
||||
if req.AutoGroups == nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key AutoGroups field is invalid"), w)
|
||||
return
|
||||
@ -150,7 +145,6 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
|
||||
newKey := &server.SetupKey{}
|
||||
newKey.AutoGroups = req.AutoGroups
|
||||
newKey.Revoked = req.Revoked
|
||||
newKey.Name = req.Name
|
||||
newKey.Id = keyID
|
||||
|
||||
newKey, err = h.accountManager.SaveSetupKey(r.Context(), accountID, newKey, userID)
|
||||
|
@ -168,6 +168,8 @@ func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context
|
||||
|
||||
account.UpdatePeer(peer)
|
||||
|
||||
log.WithContext(ctx).Tracef("saving peer status for peer %s is connected: %t", peer.ID, connected)
|
||||
|
||||
err := am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to save peer status: %w", err)
|
||||
@ -669,6 +671,9 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
|
||||
|
||||
updated := peer.UpdateMetaIfNew(sync.Meta)
|
||||
if updated {
|
||||
am.metrics.AccountManagerMetrics().CountPeerMetUpdate()
|
||||
account.Peers[peer.ID] = peer
|
||||
log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID)
|
||||
err = am.Store.SavePeer(ctx, account.Id, peer)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("failed to save peer: %w", err)
|
||||
@ -805,6 +810,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
|
||||
|
||||
updated := peer.UpdateMetaIfNew(login.Meta)
|
||||
if updated {
|
||||
am.metrics.AccountManagerMetrics().CountPeerMetUpdate()
|
||||
shouldStorePeer = true
|
||||
}
|
||||
|
||||
@ -997,6 +1003,12 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
|
||||
// updateAccountPeers updates all peers that belong to an account.
|
||||
// Should be called when changes have to be synced to peers.
|
||||
func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, accountID string) {
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to send out updates to peers: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
if am.metrics != nil {
|
||||
@ -1004,11 +1016,6 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
|
||||
}
|
||||
}()
|
||||
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to send out updates to peers: %v", err)
|
||||
return
|
||||
}
|
||||
peers := account.GetPeers()
|
||||
|
||||
approvedPeersMap, err := am.GetValidatedPeers(account)
|
||||
|
@ -435,7 +435,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
|
||||
var updateAccountPeers bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
policy, err = transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID)
|
||||
policy, err = transaction.GetPolicyByID(ctx, LockingStrengthUpdate, accountID, policyID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -502,8 +502,6 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction Store, account
|
||||
if hasPeers {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return anyGroupHasPeers(ctx, transaction, policy.AccountID, policy.ruleGroups())
|
||||
}
|
||||
|
||||
return anyGroupHasPeers(ctx, transaction, policy.AccountID, policy.ruleGroups())
|
||||
@ -534,7 +532,7 @@ func validatePolicy(ctx context.Context, transaction Store, accountID string, po
|
||||
for i, rule := range policy.Rules {
|
||||
ruleCopy := rule.Copy()
|
||||
if ruleCopy.ID == "" {
|
||||
ruleCopy.ID = xid.New().String()
|
||||
ruleCopy.ID = policy.ID // TODO: when policy can contain multiple rules, need refactor
|
||||
ruleCopy.PolicyID = policy.ID
|
||||
}
|
||||
|
||||
|
@ -9,7 +9,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
@ -32,6 +31,9 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID
|
||||
|
||||
// SavePostureChecks saves a posture check.
|
||||
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -85,6 +87,9 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
||||
|
||||
// DeletePostureChecks deletes a posture check by ID.
|
||||
func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -267,7 +272,6 @@ func isPeerInPolicySourceGroups(ctx context.Context, transaction Store, accountI
|
||||
for _, sourceGroup := range rule.Sources {
|
||||
group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, sourceGroup)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to check peer in policy source group: %v", err)
|
||||
return false, fmt.Errorf("failed to check peer in policy source group: %w", err)
|
||||
}
|
||||
|
||||
|
@ -12,9 +12,10 @@ import (
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -276,7 +277,7 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
|
||||
// SaveSetupKey saves the provided SetupKey to the database overriding the existing one.
|
||||
// Due to the unique nature of a SetupKey certain properties must not be overwritten
|
||||
// (e.g. the key itself, creation date, ID, etc).
|
||||
// These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key.
|
||||
// These properties are overwritten: AutoGroups, Revoked (only from false to true), and the UpdatedAt. The rest is copied from the existing key.
|
||||
func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) {
|
||||
if keyToSave == nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil")
|
||||
@ -312,9 +313,12 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
|
||||
return err
|
||||
}
|
||||
|
||||
// only auto groups, revoked status, and name can be updated for now
|
||||
if oldKey.Revoked && !keyToSave.Revoked {
|
||||
return status.Errorf(status.InvalidArgument, "can't un-revoke a revoked setup key")
|
||||
}
|
||||
|
||||
// only auto groups, revoked status (from false to true) can be updated
|
||||
newKey = oldKey.Copy()
|
||||
newKey.Name = keyToSave.Name
|
||||
newKey.AutoGroups = keyToSave.AutoGroups
|
||||
newKey.Revoked = keyToSave.Revoked
|
||||
newKey.UpdatedAt = time.Now().UTC()
|
||||
@ -375,7 +379,7 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
|
||||
return nil, status.NewAdminPermissionError()
|
||||
}
|
||||
|
||||
setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID)
|
||||
setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -56,11 +56,9 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
||||
}
|
||||
|
||||
autoGroups := []string{"group_1", "group_2"}
|
||||
newKeyName := "my-new-test-key"
|
||||
revoked := true
|
||||
newKey, err := manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{
|
||||
Id: key.Id,
|
||||
Name: newKeyName,
|
||||
Revoked: revoked,
|
||||
AutoGroups: autoGroups,
|
||||
}, userID)
|
||||
@ -68,7 +66,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
assertKey(t, newKey, newKeyName, revoked, "reusable", 0, key.CreatedAt, key.ExpiresAt,
|
||||
assertKey(t, newKey, keyName, revoked, "reusable", 0, key.CreatedAt, key.ExpiresAt,
|
||||
key.Id, time.Now().UTC(), autoGroups, true)
|
||||
|
||||
// check the corresponding events that should have been generated
|
||||
@ -76,7 +74,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
||||
|
||||
assert.NotNil(t, ev)
|
||||
assert.Equal(t, account.Id, ev.AccountID)
|
||||
assert.Equal(t, newKeyName, ev.Meta["name"])
|
||||
assert.Equal(t, keyName, ev.Meta["name"])
|
||||
assert.Equal(t, fmt.Sprint(key.Type), fmt.Sprint(ev.Meta["type"]))
|
||||
assert.NotEmpty(t, ev.Meta["key"])
|
||||
assert.Equal(t, userID, ev.InitiatorID)
|
||||
@ -89,7 +87,6 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
||||
autoGroups = append(autoGroups, groupAll.ID)
|
||||
_, err = manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{
|
||||
Id: key.Id,
|
||||
Name: newKeyName,
|
||||
Revoked: revoked,
|
||||
AutoGroups: autoGroups,
|
||||
}, userID)
|
||||
@ -213,22 +210,41 @@ func TestGetSetupKeys(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
ID: "group_1",
|
||||
Name: "group_name_1",
|
||||
Peers: []string{},
|
||||
})
|
||||
plainKey, err := manager.CreateSetupKey(context.Background(), account.Id, "key1", SetupKeyReusable, time.Hour, nil, SetupKeyUnlimitedUsage, userID, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
ID: "group_2",
|
||||
Name: "group_name_2",
|
||||
Peers: []string{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
type testCase struct {
|
||||
name string
|
||||
keyId string
|
||||
expectedFailure bool
|
||||
}
|
||||
|
||||
testCase1 := testCase{
|
||||
name: "Should get existing Setup Key",
|
||||
keyId: plainKey.Id,
|
||||
expectedFailure: false,
|
||||
}
|
||||
testCase2 := testCase{
|
||||
name: "Should fail to get non-existent Setup Key",
|
||||
keyId: "some key",
|
||||
expectedFailure: true,
|
||||
}
|
||||
|
||||
for _, tCase := range []testCase{testCase1, testCase2} {
|
||||
t.Run(tCase.name, func(t *testing.T) {
|
||||
key, err := manager.GetSetupKey(context.Background(), account.Id, userID, tCase.keyId)
|
||||
|
||||
if tCase.expectedFailure {
|
||||
if err == nil {
|
||||
t.Fatal("expected to fail")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
assert.NotEqual(t, plainKey.Key, key.Key)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -448,3 +464,31 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_CreateSetupKey_ShouldNotAllowToUpdateRevokedKey(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
userID := "testingUser"
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
key, err := manager.CreateSetupKey(context.Background(), account.Id, "testName", SetupKeyReusable, time.Hour, nil, SetupKeyUnlimitedUsage, userID, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// revoke the key
|
||||
updateKey := key.Copy()
|
||||
updateKey.Revoked = true
|
||||
_, err = manager.SaveSetupKey(context.Background(), account.Id, updateKey, userID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// re-activate revoked key
|
||||
updateKey.Revoked = false
|
||||
_, err = manager.SaveSetupKey(context.Background(), account.Id, updateKey, userID)
|
||||
assert.Error(t, err, "should not allow to update revoked key")
|
||||
|
||||
}
|
||||
|
@ -1123,6 +1123,7 @@ func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength Lock
|
||||
}
|
||||
|
||||
func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error {
|
||||
startTime := time.Now()
|
||||
tx := s.db.Begin()
|
||||
if tx.Error != nil {
|
||||
return tx.Error
|
||||
@ -1133,7 +1134,15 @@ func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(stor
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
return tx.Commit().Error
|
||||
|
||||
err = tx.Commit().Error
|
||||
|
||||
log.WithContext(ctx).Tracef("transaction took %v", time.Since(startTime))
|
||||
if s.metrics != nil {
|
||||
s.metrics.StoreMetrics().CountTransactionDuration(time.Since(startTime))
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SqlStore) withTx(tx *gorm.DB) Store {
|
||||
@ -1279,7 +1288,7 @@ func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, a
|
||||
Delete(&nbgroup.Group{}, accountAndIDsQueryCondition, accountID, groupIDs)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to delete groups from store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to delete groups from store")
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -1832,6 +1832,8 @@ func TestSqlStore_SavePolicy(t *testing.T) {
|
||||
|
||||
policy.Enabled = false
|
||||
policy.Description = "policy"
|
||||
policy.Rules[0].Sources = []string{"group"}
|
||||
policy.Rules[0].Ports = []string{"80", "443"}
|
||||
err = store.SavePolicy(context.Background(), LockingStrengthUpdate, policy)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
@ -13,6 +13,7 @@ type AccountManagerMetrics struct {
|
||||
updateAccountPeersDurationMs metric.Float64Histogram
|
||||
getPeerNetworkMapDurationMs metric.Float64Histogram
|
||||
networkMapObjectCount metric.Int64Histogram
|
||||
peerMetaUpdateCount metric.Int64Counter
|
||||
}
|
||||
|
||||
// NewAccountManagerMetrics creates an instance of AccountManagerMetrics
|
||||
@ -44,11 +45,17 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peerMetaUpdateCount, err := meter.Int64Counter("management.account.peer.meta.update.counter", metric.WithUnit("1"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &AccountManagerMetrics{
|
||||
ctx: ctx,
|
||||
getPeerNetworkMapDurationMs: getPeerNetworkMapDurationMs,
|
||||
updateAccountPeersDurationMs: updateAccountPeersDurationMs,
|
||||
networkMapObjectCount: networkMapObjectCount,
|
||||
peerMetaUpdateCount: peerMetaUpdateCount,
|
||||
}, nil
|
||||
|
||||
}
|
||||
@ -67,3 +74,8 @@ func (metrics *AccountManagerMetrics) CountGetPeerNetworkMapDuration(duration ti
|
||||
func (metrics *AccountManagerMetrics) CountNetworkMapObjects(count int64) {
|
||||
metrics.networkMapObjectCount.Record(metrics.ctx, count)
|
||||
}
|
||||
|
||||
// CountPeerMetUpdate counts the number of peer meta updates
|
||||
func (metrics *AccountManagerMetrics) CountPeerMetUpdate() {
|
||||
metrics.peerMetaUpdateCount.Add(metrics.ctx, 1)
|
||||
}
|
||||
|
@ -13,6 +13,7 @@ type StoreMetrics struct {
|
||||
globalLockAcquisitionDurationMs metric.Int64Histogram
|
||||
persistenceDurationMicro metric.Int64Histogram
|
||||
persistenceDurationMs metric.Int64Histogram
|
||||
transactionDurationMs metric.Int64Histogram
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
@ -40,11 +41,17 @@ func NewStoreMetrics(ctx context.Context, meter metric.Meter) (*StoreMetrics, er
|
||||
return nil, err
|
||||
}
|
||||
|
||||
transactionDurationMs, err := meter.Int64Histogram("management.store.transaction.duration.ms")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &StoreMetrics{
|
||||
globalLockAcquisitionDurationMicro: globalLockAcquisitionDurationMicro,
|
||||
globalLockAcquisitionDurationMs: globalLockAcquisitionDurationMs,
|
||||
persistenceDurationMicro: persistenceDurationMicro,
|
||||
persistenceDurationMs: persistenceDurationMs,
|
||||
transactionDurationMs: transactionDurationMs,
|
||||
ctx: ctx,
|
||||
}, nil
|
||||
}
|
||||
@ -60,3 +67,8 @@ func (metrics *StoreMetrics) CountPersistenceDuration(duration time.Duration) {
|
||||
metrics.persistenceDurationMicro.Record(metrics.ctx, duration.Microseconds())
|
||||
metrics.persistenceDurationMs.Record(metrics.ctx, duration.Milliseconds())
|
||||
}
|
||||
|
||||
// CountTransactionDuration counts the duration of a store persistence operation
|
||||
func (metrics *StoreMetrics) CountTransactionDuration(duration time.Duration) {
|
||||
metrics.transactionDurationMs.Record(metrics.ctx, duration.Milliseconds())
|
||||
}
|
||||
|
@ -805,15 +805,20 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
||||
expiredPeers = append(expiredPeers, blockedPeers...)
|
||||
}
|
||||
|
||||
peerGroupsAdded := make(map[string][]string)
|
||||
peerGroupsRemoved := make(map[string][]string)
|
||||
if update.AutoGroups != nil && account.Settings.GroupsPropagationEnabled {
|
||||
removedGroups := difference(oldUser.AutoGroups, update.AutoGroups)
|
||||
// need force update all auto groups in any case they will not be duplicated
|
||||
account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...)
|
||||
account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...)
|
||||
peerGroupsAdded = account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...)
|
||||
peerGroupsRemoved = account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...)
|
||||
}
|
||||
|
||||
events := am.prepareUserUpdateEvents(ctx, initiatorUser.Id, oldUser, newUser, account, transferredOwnerRole)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
userUpdateEvents := am.prepareUserUpdateEvents(ctx, initiatorUser.Id, oldUser, newUser, account, transferredOwnerRole)
|
||||
eventsToStore = append(eventsToStore, userUpdateEvents...)
|
||||
|
||||
userGroupsEvents := am.prepareUserGroupsEvents(ctx, initiatorUser.Id, oldUser, newUser, account, peerGroupsAdded, peerGroupsRemoved)
|
||||
eventsToStore = append(eventsToStore, userGroupsEvents...)
|
||||
|
||||
updatedUserInfo, err := getUserInfo(ctx, am, newUser, account)
|
||||
if err != nil {
|
||||
@ -872,32 +877,78 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, in
|
||||
})
|
||||
}
|
||||
|
||||
return eventsToStore
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) prepareUserGroupsEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, peerGroupsAdded, peerGroupsRemoved map[string][]string) []func() {
|
||||
var eventsToStore []func()
|
||||
if newUser.AutoGroups != nil {
|
||||
removedGroups := difference(oldUser.AutoGroups, newUser.AutoGroups)
|
||||
addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups)
|
||||
for _, g := range removedGroups {
|
||||
group := account.GetGroup(g)
|
||||
if group != nil {
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupRemovedFromUser,
|
||||
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
|
||||
})
|
||||
|
||||
} else {
|
||||
log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
|
||||
}
|
||||
}
|
||||
for _, g := range addedGroups {
|
||||
group := account.GetGroup(g)
|
||||
if group != nil {
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupAddedToUser,
|
||||
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
|
||||
})
|
||||
}
|
||||
removedEvents := am.handleGroupRemovedFromUser(ctx, initiatorUserID, oldUser, newUser, account, removedGroups, peerGroupsRemoved)
|
||||
eventsToStore = append(eventsToStore, removedEvents...)
|
||||
|
||||
addedEvents := am.handleGroupAddedToUser(ctx, initiatorUserID, oldUser, newUser, account, addedGroups, peerGroupsAdded)
|
||||
eventsToStore = append(eventsToStore, addedEvents...)
|
||||
}
|
||||
return eventsToStore
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleGroupAddedToUser(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, addedGroups []string, peerGroupsAdded map[string][]string) []func() {
|
||||
var eventsToStore []func()
|
||||
for _, g := range addedGroups {
|
||||
group := account.GetGroup(g)
|
||||
if group != nil {
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupAddedToUser,
|
||||
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
|
||||
})
|
||||
}
|
||||
}
|
||||
for groupID, peerIDs := range peerGroupsAdded {
|
||||
group := account.GetGroup(groupID)
|
||||
for _, peerID := range peerIDs {
|
||||
peer := account.GetPeer(peerID)
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
meta := map[string]any{
|
||||
"group": group.Name, "group_id": group.ID,
|
||||
"peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
|
||||
}
|
||||
am.StoreEvent(ctx, activity.SystemInitiator, peer.ID, account.Id, activity.GroupAddedToPeer, meta)
|
||||
})
|
||||
}
|
||||
}
|
||||
return eventsToStore
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleGroupRemovedFromUser(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, removedGroups []string, peerGroupsRemoved map[string][]string) []func() {
|
||||
var eventsToStore []func()
|
||||
for _, g := range removedGroups {
|
||||
group := account.GetGroup(g)
|
||||
if group != nil {
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupRemovedFromUser,
|
||||
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
|
||||
})
|
||||
|
||||
} else {
|
||||
log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
|
||||
}
|
||||
}
|
||||
for groupID, peerIDs := range peerGroupsRemoved {
|
||||
group := account.GetGroup(groupID)
|
||||
for _, peerID := range peerIDs {
|
||||
peer := account.GetPeer(peerID)
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
meta := map[string]any{
|
||||
"group": group.Name, "group_id": group.ID,
|
||||
"peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
|
||||
}
|
||||
am.StoreEvent(ctx, activity.SystemInitiator, peer.ID, account.Id, activity.GroupRemovedFromPeer, meta)
|
||||
})
|
||||
}
|
||||
}
|
||||
return eventsToStore
|
||||
}
|
||||
|
||||
|
@ -140,7 +140,7 @@ type Client struct {
|
||||
instanceURL *RelayAddr
|
||||
muInstanceURL sync.Mutex
|
||||
|
||||
onDisconnectListener func()
|
||||
onDisconnectListener func(string)
|
||||
onConnectedListener func()
|
||||
listenerMutex sync.Mutex
|
||||
}
|
||||
@ -233,7 +233,7 @@ func (c *Client) ServerInstanceURL() (string, error) {
|
||||
}
|
||||
|
||||
// SetOnDisconnectListener sets a function that will be called when the connection to the relay server is closed.
|
||||
func (c *Client) SetOnDisconnectListener(fn func()) {
|
||||
func (c *Client) SetOnDisconnectListener(fn func(string)) {
|
||||
c.listenerMutex.Lock()
|
||||
defer c.listenerMutex.Unlock()
|
||||
c.onDisconnectListener = fn
|
||||
@ -554,7 +554,7 @@ func (c *Client) notifyDisconnected() {
|
||||
if c.onDisconnectListener == nil {
|
||||
return
|
||||
}
|
||||
go c.onDisconnectListener()
|
||||
go c.onDisconnectListener(c.connectionURL)
|
||||
}
|
||||
|
||||
func (c *Client) notifyConnected() {
|
||||
|
@ -551,7 +551,7 @@ func TestCloseByServer(t *testing.T) {
|
||||
}
|
||||
|
||||
disconnected := make(chan struct{})
|
||||
relayClient.SetOnDisconnectListener(func() {
|
||||
relayClient.SetOnDisconnectListener(func(_ string) {
|
||||
log.Infof("client disconnected")
|
||||
close(disconnected)
|
||||
})
|
||||
|
@ -4,65 +4,120 @@ import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
reconnectingTimeout = 5 * time.Second
|
||||
reconnectingTimeout = 60 * time.Second
|
||||
)
|
||||
|
||||
// Guard manage the reconnection tries to the Relay server in case of disconnection event.
|
||||
type Guard struct {
|
||||
ctx context.Context
|
||||
relayClient *Client
|
||||
// OnNewRelayClient is a channel that is used to notify the relay client about a new relay client instance.
|
||||
OnNewRelayClient chan *Client
|
||||
serverPicker *ServerPicker
|
||||
}
|
||||
|
||||
// NewGuard creates a new guard for the relay client.
|
||||
func NewGuard(context context.Context, relayClient *Client) *Guard {
|
||||
func NewGuard(sp *ServerPicker) *Guard {
|
||||
g := &Guard{
|
||||
ctx: context,
|
||||
relayClient: relayClient,
|
||||
OnNewRelayClient: make(chan *Client, 1),
|
||||
serverPicker: sp,
|
||||
}
|
||||
return g
|
||||
}
|
||||
|
||||
// OnDisconnected is called when the relay client is disconnected from the relay server. It will trigger the reconnection
|
||||
// StartReconnectTrys is called when the relay client is disconnected from the relay server.
|
||||
// It attempts to reconnect to the relay server. The function first tries a quick reconnect
|
||||
// to the same server that was used before, if the server URL is still valid. If the quick
|
||||
// reconnect fails, it starts a ticker to periodically attempt server picking until it
|
||||
// succeeds or the context is done.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: The context to control the lifecycle of the reconnection attempts.
|
||||
// - relayClient: The relay client instance that was disconnected.
|
||||
// todo prevent multiple reconnection instances. In the current usage it should not happen, but it is better to prevent
|
||||
func (g *Guard) OnDisconnected() {
|
||||
if g.quickReconnect() {
|
||||
func (g *Guard) StartReconnectTrys(ctx context.Context, relayClient *Client) {
|
||||
if relayClient == nil {
|
||||
goto RETRY
|
||||
}
|
||||
if g.isServerURLStillValid(relayClient) && g.quickReconnect(ctx, relayClient) {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(reconnectingTimeout)
|
||||
RETRY:
|
||||
ticker := exponentTicker(ctx)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
err := g.relayClient.Connect()
|
||||
if err != nil {
|
||||
log.Errorf("failed to reconnect to relay server: %s", err)
|
||||
if err := g.retry(ctx); err != nil {
|
||||
log.Errorf("failed to pick new Relay server: %s", err)
|
||||
continue
|
||||
}
|
||||
return
|
||||
case <-g.ctx.Done():
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Guard) quickReconnect() bool {
|
||||
ctx, cancel := context.WithTimeout(g.ctx, 1500*time.Millisecond)
|
||||
func (g *Guard) retry(ctx context.Context) error {
|
||||
log.Infof("try to pick up a new Relay server")
|
||||
relayClient, err := g.serverPicker.PickServer(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// prevent to work with a deprecated Relay client instance
|
||||
g.drainRelayClientChan()
|
||||
|
||||
g.OnNewRelayClient <- relayClient
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *Guard) quickReconnect(parentCtx context.Context, rc *Client) bool {
|
||||
ctx, cancel := context.WithTimeout(parentCtx, 1500*time.Millisecond)
|
||||
defer cancel()
|
||||
<-ctx.Done()
|
||||
|
||||
if g.ctx.Err() != nil {
|
||||
if parentCtx.Err() != nil {
|
||||
return false
|
||||
}
|
||||
log.Infof("try to reconnect to Relay server: %s", rc.connectionURL)
|
||||
|
||||
if err := g.relayClient.Connect(); err != nil {
|
||||
if err := rc.Connect(); err != nil {
|
||||
log.Errorf("failed to reconnect to relay server: %s", err)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (g *Guard) drainRelayClientChan() {
|
||||
select {
|
||||
case <-g.OnNewRelayClient:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Guard) isServerURLStillValid(rc *Client) bool {
|
||||
for _, url := range g.serverPicker.ServerURLs.Load().([]string) {
|
||||
if url == rc.connectionURL {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func exponentTicker(ctx context.Context) *backoff.Ticker {
|
||||
bo := backoff.WithContext(&backoff.ExponentialBackOff{
|
||||
InitialInterval: 2 * time.Second,
|
||||
Multiplier: 2,
|
||||
MaxInterval: reconnectingTimeout,
|
||||
Clock: backoff.SystemClock,
|
||||
}, ctx)
|
||||
|
||||
return backoff.NewTicker(bo)
|
||||
}
|
||||
|
@ -57,12 +57,15 @@ type ManagerService interface {
|
||||
// relay servers will be closed if there is no active connection. Periodically the manager will check if there is any
|
||||
// unused relay connection and close it.
|
||||
type Manager struct {
|
||||
ctx context.Context
|
||||
serverURLs []string
|
||||
peerID string
|
||||
tokenStore *relayAuth.TokenStore
|
||||
ctx context.Context
|
||||
peerID string
|
||||
running bool
|
||||
tokenStore *relayAuth.TokenStore
|
||||
serverPicker *ServerPicker
|
||||
|
||||
relayClient *Client
|
||||
relayClient *Client
|
||||
// the guard logic can overwrite the relayClient variable, this mutex protect the usage of the variable
|
||||
relayClientMu sync.Mutex
|
||||
reconnectGuard *Guard
|
||||
|
||||
relayClients map[string]*RelayTrack
|
||||
@ -76,48 +79,54 @@ type Manager struct {
|
||||
// NewManager creates a new manager instance.
|
||||
// The serverURL address can be empty. In this case, the manager will not serve.
|
||||
func NewManager(ctx context.Context, serverURLs []string, peerID string) *Manager {
|
||||
return &Manager{
|
||||
ctx: ctx,
|
||||
serverURLs: serverURLs,
|
||||
peerID: peerID,
|
||||
tokenStore: &relayAuth.TokenStore{},
|
||||
tokenStore := &relayAuth.TokenStore{}
|
||||
|
||||
m := &Manager{
|
||||
ctx: ctx,
|
||||
peerID: peerID,
|
||||
tokenStore: tokenStore,
|
||||
serverPicker: &ServerPicker{
|
||||
TokenStore: tokenStore,
|
||||
PeerID: peerID,
|
||||
},
|
||||
relayClients: make(map[string]*RelayTrack),
|
||||
onDisconnectedListeners: make(map[string]*list.List),
|
||||
}
|
||||
m.serverPicker.ServerURLs.Store(serverURLs)
|
||||
m.reconnectGuard = NewGuard(m.serverPicker)
|
||||
return m
|
||||
}
|
||||
|
||||
// Serve starts the manager. It will establish a connection to the relay server and start the relay cleanup loop for
|
||||
// the unused relay connections. The manager will automatically reconnect to the relay server in case of disconnection.
|
||||
// Serve starts the manager, attempting to establish a connection with the relay server.
|
||||
// If the connection fails, it will keep trying to reconnect in the background.
|
||||
// Additionally, it starts a cleanup loop to remove unused relay connections.
|
||||
// The manager will automatically reconnect to the relay server in case of disconnection.
|
||||
func (m *Manager) Serve() error {
|
||||
if m.relayClient != nil {
|
||||
if m.running {
|
||||
return fmt.Errorf("manager already serving")
|
||||
}
|
||||
log.Debugf("starting relay client manager with %v relay servers", m.serverURLs)
|
||||
m.running = true
|
||||
log.Debugf("starting relay client manager with %v relay servers", m.serverPicker.ServerURLs.Load())
|
||||
|
||||
sp := ServerPicker{
|
||||
TokenStore: m.tokenStore,
|
||||
PeerID: m.peerID,
|
||||
}
|
||||
|
||||
client, err := sp.PickServer(m.ctx, m.serverURLs)
|
||||
client, err := m.serverPicker.PickServer(m.ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
go m.reconnectGuard.StartReconnectTrys(m.ctx, nil)
|
||||
} else {
|
||||
m.storeClient(client)
|
||||
}
|
||||
m.relayClient = client
|
||||
|
||||
m.reconnectGuard = NewGuard(m.ctx, m.relayClient)
|
||||
m.relayClient.SetOnConnectedListener(m.onServerConnected)
|
||||
m.relayClient.SetOnDisconnectListener(func() {
|
||||
m.onServerDisconnected(client.connectionURL)
|
||||
})
|
||||
m.startCleanupLoop()
|
||||
return nil
|
||||
go m.listenGuardEvent(m.ctx)
|
||||
go m.startCleanupLoop()
|
||||
return err
|
||||
}
|
||||
|
||||
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be
|
||||
// established via the relay server. If the peer is on a different relay server, the manager will establish a new
|
||||
// connection to the relay server. It returns back with a net.Conn what represent the remote peer connection.
|
||||
func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
|
||||
m.relayClientMu.Lock()
|
||||
defer m.relayClientMu.Unlock()
|
||||
|
||||
if m.relayClient == nil {
|
||||
return nil, ErrRelayClientNotConnected
|
||||
}
|
||||
@ -146,6 +155,9 @@ func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
|
||||
|
||||
// Ready returns true if the home Relay client is connected to the relay server.
|
||||
func (m *Manager) Ready() bool {
|
||||
m.relayClientMu.Lock()
|
||||
defer m.relayClientMu.Unlock()
|
||||
|
||||
if m.relayClient == nil {
|
||||
return false
|
||||
}
|
||||
@ -159,6 +171,13 @@ func (m *Manager) SetOnReconnectedListener(f func()) {
|
||||
// AddCloseListener adds a listener to the given server instance address. The listener will be called if the connection
|
||||
// closed.
|
||||
func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error {
|
||||
m.relayClientMu.Lock()
|
||||
defer m.relayClientMu.Unlock()
|
||||
|
||||
if m.relayClient == nil {
|
||||
return ErrRelayClientNotConnected
|
||||
}
|
||||
|
||||
foreign, err := m.isForeignServer(serverAddress)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -177,6 +196,9 @@ func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServ
|
||||
// RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is
|
||||
// lost. This address will be sent to the target peer to choose the common relay server for the communication.
|
||||
func (m *Manager) RelayInstanceAddress() (string, error) {
|
||||
m.relayClientMu.Lock()
|
||||
defer m.relayClientMu.Unlock()
|
||||
|
||||
if m.relayClient == nil {
|
||||
return "", ErrRelayClientNotConnected
|
||||
}
|
||||
@ -185,13 +207,18 @@ func (m *Manager) RelayInstanceAddress() (string, error) {
|
||||
|
||||
// ServerURLs returns the addresses of the relay servers.
|
||||
func (m *Manager) ServerURLs() []string {
|
||||
return m.serverURLs
|
||||
return m.serverPicker.ServerURLs.Load().([]string)
|
||||
}
|
||||
|
||||
// HasRelayAddress returns true if the manager is serving. With this method can check if the peer can communicate with
|
||||
// Relay service.
|
||||
func (m *Manager) HasRelayAddress() bool {
|
||||
return len(m.serverURLs) > 0
|
||||
return len(m.serverPicker.ServerURLs.Load().([]string)) > 0
|
||||
}
|
||||
|
||||
func (m *Manager) UpdateServerURLs(serverURLs []string) {
|
||||
log.Infof("update relay server URLs: %v", serverURLs)
|
||||
m.serverPicker.ServerURLs.Store(serverURLs)
|
||||
}
|
||||
|
||||
// UpdateToken updates the token in the token store.
|
||||
@ -245,9 +272,7 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
|
||||
return nil, err
|
||||
}
|
||||
// if connection closed then delete the relay client from the list
|
||||
relayClient.SetOnDisconnectListener(func() {
|
||||
m.onServerDisconnected(serverAddress)
|
||||
})
|
||||
relayClient.SetOnDisconnectListener(m.onServerDisconnected)
|
||||
rt.relayClient = relayClient
|
||||
rt.Unlock()
|
||||
|
||||
@ -265,14 +290,37 @@ func (m *Manager) onServerConnected() {
|
||||
go m.onReconnectedListenerFn()
|
||||
}
|
||||
|
||||
// onServerDisconnected start to reconnection for home server only
|
||||
func (m *Manager) onServerDisconnected(serverAddress string) {
|
||||
m.relayClientMu.Lock()
|
||||
if serverAddress == m.relayClient.connectionURL {
|
||||
go m.reconnectGuard.OnDisconnected()
|
||||
go m.reconnectGuard.StartReconnectTrys(m.ctx, m.relayClient)
|
||||
}
|
||||
m.relayClientMu.Unlock()
|
||||
|
||||
m.notifyOnDisconnectListeners(serverAddress)
|
||||
}
|
||||
|
||||
func (m *Manager) listenGuardEvent(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case rc := <-m.reconnectGuard.OnNewRelayClient:
|
||||
m.storeClient(rc)
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) storeClient(client *Client) {
|
||||
m.relayClientMu.Lock()
|
||||
defer m.relayClientMu.Unlock()
|
||||
|
||||
m.relayClient = client
|
||||
m.relayClient.SetOnConnectedListener(m.onServerConnected)
|
||||
m.relayClient.SetOnDisconnectListener(m.onServerDisconnected)
|
||||
}
|
||||
|
||||
func (m *Manager) isForeignServer(address string) (bool, error) {
|
||||
rAddr, err := m.relayClient.ServerInstanceURL()
|
||||
if err != nil {
|
||||
@ -282,22 +330,16 @@ func (m *Manager) isForeignServer(address string) (bool, error) {
|
||||
}
|
||||
|
||||
func (m *Manager) startCleanupLoop() {
|
||||
if m.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(relayCleanupInterval)
|
||||
go func() {
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
m.cleanUpUnusedRelays()
|
||||
}
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
m.cleanUpUnusedRelays()
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) cleanUpUnusedRelays() {
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@ -12,10 +13,13 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
connectionTimeout = 30 * time.Second
|
||||
maxConcurrentServers = 7
|
||||
)
|
||||
|
||||
var (
|
||||
connectionTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
type connResult struct {
|
||||
RelayClient *Client
|
||||
Url string
|
||||
@ -24,20 +28,22 @@ type connResult struct {
|
||||
|
||||
type ServerPicker struct {
|
||||
TokenStore *auth.TokenStore
|
||||
ServerURLs atomic.Value
|
||||
PeerID string
|
||||
}
|
||||
|
||||
func (sp *ServerPicker) PickServer(parentCtx context.Context, urls []string) (*Client, error) {
|
||||
func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
|
||||
ctx, cancel := context.WithTimeout(parentCtx, connectionTimeout)
|
||||
defer cancel()
|
||||
|
||||
totalServers := len(urls)
|
||||
totalServers := len(sp.ServerURLs.Load().([]string))
|
||||
|
||||
connResultChan := make(chan connResult, totalServers)
|
||||
successChan := make(chan connResult, 1)
|
||||
concurrentLimiter := make(chan struct{}, maxConcurrentServers)
|
||||
|
||||
for _, url := range urls {
|
||||
log.Debugf("pick server from list: %v", sp.ServerURLs.Load().([]string))
|
||||
for _, url := range sp.ServerURLs.Load().([]string) {
|
||||
// todo check if we have a successful connection so we do not need to connect to other servers
|
||||
concurrentLimiter <- struct{}{}
|
||||
go func(url string) {
|
||||
@ -78,7 +84,7 @@ func (sp *ServerPicker) processConnResults(resultChan chan connResult, successCh
|
||||
for numOfResults := 0; numOfResults < cap(resultChan); numOfResults++ {
|
||||
cr := <-resultChan
|
||||
if cr.Err != nil {
|
||||
log.Debugf("failed to connect to Relay server: %s: %v", cr.Url, cr.Err)
|
||||
log.Tracef("failed to connect to Relay server: %s: %v", cr.Url, cr.Err)
|
||||
continue
|
||||
}
|
||||
log.Infof("connected to Relay server: %s", cr.Url)
|
||||
|
@ -4,19 +4,23 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestServerPicker_UnavailableServers(t *testing.T) {
|
||||
connectionTimeout = 5 * time.Second
|
||||
|
||||
sp := ServerPicker{
|
||||
TokenStore: nil,
|
||||
PeerID: "test",
|
||||
}
|
||||
sp.ServerURLs.Store([]string{"rel://dummy1", "rel://dummy2"})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), connectionTimeout+1)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
_, err := sp.PickServer(ctx, []string{"rel://dummy1", "rel://dummy2"})
|
||||
_, err := sp.PickServer(ctx)
|
||||
if err == nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
@ -16,6 +16,8 @@ import (
|
||||
|
||||
const (
|
||||
bufferSize = 8820
|
||||
|
||||
errCloseConn = "failed to close connection to peer: %s"
|
||||
)
|
||||
|
||||
// Peer represents a peer connection
|
||||
@ -46,6 +48,12 @@ func NewPeer(metrics *metrics.Metrics, id []byte, conn net.Conn, store *Store) *
|
||||
// It manages the protocol (healthcheck, transport, close). Read the message and determine the message type and handle
|
||||
// the message accordingly.
|
||||
func (p *Peer) Work() {
|
||||
defer func() {
|
||||
if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
p.log.Errorf(errCloseConn, err)
|
||||
}
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
@ -97,7 +105,7 @@ func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *
|
||||
case messages.MsgTypeClose:
|
||||
p.log.Infof("peer exited gracefully")
|
||||
if err := p.conn.Close(); err != nil {
|
||||
log.Errorf("failed to close connection to peer: %s", err)
|
||||
log.Errorf(errCloseConn, err)
|
||||
}
|
||||
default:
|
||||
p.log.Warnf("received unexpected message type: %s", msgType)
|
||||
@ -121,9 +129,8 @@ func (p *Peer) CloseGracefully(ctx context.Context) {
|
||||
p.log.Errorf("failed to send close message to peer: %s", p.String())
|
||||
}
|
||||
|
||||
err = p.conn.Close()
|
||||
if err != nil {
|
||||
p.log.Errorf("failed to close connection to peer: %s", err)
|
||||
if err := p.conn.Close(); err != nil {
|
||||
p.log.Errorf(errCloseConn, err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -132,7 +139,7 @@ func (p *Peer) Close() {
|
||||
defer p.connMu.Unlock()
|
||||
|
||||
if err := p.conn.Close(); err != nil {
|
||||
p.log.Errorf("failed to close connection to peer: %s", err)
|
||||
p.log.Errorf(errCloseConn, err)
|
||||
}
|
||||
}
|
||||
|
||||
|
68
util/file.go
68
util/file.go
@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
@ -14,8 +15,21 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func WriteBytesWithRestrictedPermission(ctx context.Context, file string, bs []byte) error {
|
||||
configDir, configFileName, err := prepareConfigFileDir(file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("prepare config file dir: %w", err)
|
||||
}
|
||||
|
||||
if err = EnforcePermission(file); err != nil {
|
||||
return fmt.Errorf("enforce permission: %w", err)
|
||||
}
|
||||
|
||||
return writeBytes(ctx, file, err, configDir, configFileName, bs)
|
||||
}
|
||||
|
||||
// WriteJsonWithRestrictedPermission writes JSON config object to a file. Enforces permission on the parent directory
|
||||
func WriteJsonWithRestrictedPermission(file string, obj interface{}) error {
|
||||
func WriteJsonWithRestrictedPermission(ctx context.Context, file string, obj interface{}) error {
|
||||
configDir, configFileName, err := prepareConfigFileDir(file)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -26,18 +40,18 @@ func WriteJsonWithRestrictedPermission(file string, obj interface{}) error {
|
||||
return err
|
||||
}
|
||||
|
||||
return writeJson(file, obj, configDir, configFileName)
|
||||
return writeJson(ctx, file, obj, configDir, configFileName)
|
||||
}
|
||||
|
||||
// WriteJson writes JSON config object to a file creating parent directories if required
|
||||
// The output JSON is pretty-formatted
|
||||
func WriteJson(file string, obj interface{}) error {
|
||||
func WriteJson(ctx context.Context, file string, obj interface{}) error {
|
||||
configDir, configFileName, err := prepareConfigFileDir(file)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return writeJson(file, obj, configDir, configFileName)
|
||||
return writeJson(ctx, file, obj, configDir, configFileName)
|
||||
}
|
||||
|
||||
// DirectWriteJson writes JSON config object to a file creating parent directories if required without creating a temporary file
|
||||
@ -79,24 +93,47 @@ func DirectWriteJson(ctx context.Context, file string, obj interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeJson(file string, obj interface{}, configDir string, configFileName string) error {
|
||||
func writeJson(ctx context.Context, file string, obj interface{}, configDir string, configFileName string) error {
|
||||
// Check context before expensive operations
|
||||
if ctx.Err() != nil {
|
||||
return fmt.Errorf("write json start: %w", ctx.Err())
|
||||
}
|
||||
|
||||
// make it pretty
|
||||
bs, err := json.MarshalIndent(obj, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("marshal: %w", err)
|
||||
}
|
||||
|
||||
return writeBytes(ctx, file, err, configDir, configFileName, bs)
|
||||
}
|
||||
|
||||
func writeBytes(ctx context.Context, file string, err error, configDir string, configFileName string, bs []byte) error {
|
||||
if ctx.Err() != nil {
|
||||
return fmt.Errorf("write bytes start: %w", ctx.Err())
|
||||
}
|
||||
|
||||
tempFile, err := os.CreateTemp(configDir, ".*"+configFileName)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("create temp: %w", err)
|
||||
}
|
||||
|
||||
tempFileName := tempFile.Name()
|
||||
// closing file ops as windows doesn't allow to move it
|
||||
err = tempFile.Close()
|
||||
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
if err := tempFile.SetDeadline(deadline); err != nil && !errors.Is(err, os.ErrNoDeadline) {
|
||||
log.Warnf("failed to set deadline: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
_, err = tempFile.Write(bs)
|
||||
if err != nil {
|
||||
return err
|
||||
_ = tempFile.Close()
|
||||
return fmt.Errorf("write: %w", err)
|
||||
}
|
||||
|
||||
if err = tempFile.Close(); err != nil {
|
||||
return fmt.Errorf("close %s: %w", tempFileName, err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
@ -106,14 +143,13 @@ func writeJson(file string, obj interface{}, configDir string, configFileName st
|
||||
}
|
||||
}()
|
||||
|
||||
err = os.WriteFile(tempFileName, bs, 0600)
|
||||
if err != nil {
|
||||
return err
|
||||
// Check context again
|
||||
if ctx.Err() != nil {
|
||||
return fmt.Errorf("after temp file: %w", ctx.Err())
|
||||
}
|
||||
|
||||
err = os.Rename(tempFileName, file)
|
||||
if err != nil {
|
||||
return err
|
||||
if err = os.Rename(tempFileName, file); err != nil {
|
||||
return fmt.Errorf("move %s to %s: %w", tempFileName, file, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -1,6 +1,7 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
@ -39,7 +40,7 @@ func TestConfigJSON(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
err := WriteJson(tmpDir+"/testconfig.json", tt.config)
|
||||
err := WriteJson(context.Background(), tmpDir+"/testconfig.json", tt.config)
|
||||
require.NoError(t, err)
|
||||
|
||||
read, err := ReadJson(tmpDir+"/testconfig.json", &TestConfig{})
|
||||
@ -73,7 +74,7 @@ func TestCopyFileContents(t *testing.T) {
|
||||
src := tmpDir + "/copytest_src"
|
||||
dst := tmpDir + "/copytest_dst"
|
||||
|
||||
err := WriteJson(src, tt.srcContent)
|
||||
err := WriteJson(context.Background(), src, tt.srcContent)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = CopyFileContents(src, dst)
|
||||
@ -127,7 +128,7 @@ func TestHandleConfigFileWithoutFullPath(t *testing.T) {
|
||||
_ = os.Remove(cfgFile)
|
||||
}()
|
||||
|
||||
err := WriteJson(cfgFile, tt.config)
|
||||
err := WriteJson(context.Background(), cfgFile, tt.config)
|
||||
require.NoError(t, err)
|
||||
|
||||
read, err := ReadJson(cfgFile, &TestConfig{})
|
||||
|
@ -3,6 +3,9 @@ package grpc
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"net"
|
||||
"os/user"
|
||||
"runtime"
|
||||
@ -23,20 +26,22 @@ func WithCustomDialer() grpc.DialOption {
|
||||
if runtime.GOOS == "linux" {
|
||||
currentUser, err := user.Current()
|
||||
if err != nil {
|
||||
log.Fatalf("failed to get current user: %v", err)
|
||||
return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err)
|
||||
}
|
||||
|
||||
// the custom dialer requires root permissions which are not required for use cases run as non-root
|
||||
if currentUser.Uid != "0" {
|
||||
log.Debug("Not running as root, using standard dialer")
|
||||
dialer := &net.Dialer{}
|
||||
return dialer.DialContext(ctx, "tcp", addr)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("Using nbnet.NewDialer()")
|
||||
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to dial: %s", err)
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
|
||||
}
|
||||
return conn, nil
|
||||
})
|
||||
|
@ -69,7 +69,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
|
||||
|
||||
conn, err := d.Dialer.DialContext(ctx, network, address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dial: %w", err)
|
||||
return nil, fmt.Errorf("d.Dialer.DialContext: %w", err)
|
||||
}
|
||||
|
||||
// Wrap the connection in Conn to handle Close with hooks
|
||||
|
@ -4,9 +4,14 @@ package net
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const EnvSkipSocketMark = "NB_SKIP_SOCKET_MARK"
|
||||
|
||||
// SetSocketMark sets the SO_MARK option on the given socket connection
|
||||
func SetSocketMark(conn syscall.Conn) error {
|
||||
sysconn, err := conn.SyscallConn()
|
||||
@ -36,6 +41,13 @@ func SetRawSocketMark(conn syscall.RawConn) error {
|
||||
|
||||
func SetSocketOpt(fd int) error {
|
||||
if CustomRoutingDisabled() {
|
||||
log.Infof("Custom routing is disabled, skipping SO_MARK")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check for the new environment variable
|
||||
if skipSocketMark := os.Getenv(EnvSkipSocketMark); skipSocketMark == "true" {
|
||||
log.Info("NB_SKIP_SOCKET_MARK is set to true, skipping SO_MARK")
|
||||
return nil
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user