diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 524f35f6f..d6adcb27a 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -16,7 +16,7 @@ jobs: matrix: arch: [ '386','amd64' ] store: [ 'sqlite', 'postgres'] - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 steps: - name: Install Go uses: actions/setup-go@v5 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 7af6d3e4d..b2e2437e6 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -20,7 +20,7 @@ concurrency: jobs: release: - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 env: flags: "" steps: diff --git a/README.md b/README.md index aa3ec41e5..270c9ad87 100644 --- a/README.md +++ b/README.md @@ -49,6 +49,8 @@ ![netbird_2](https://github.com/netbirdio/netbird/assets/700848/46bc3b73-508d-4a0e-bb9a-f465d68646ab) +### NetBird on Lawrence Systems (Video) +[![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw) ### Key features @@ -62,6 +64,7 @@ | | | | | | | | |
  • - \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)
  • | | | | | | | | | + ### Quickstart with NetBird Cloud - Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install) diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index f0dc8bf21..033d1bb6a 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -3,7 +3,6 @@ package cmd import ( "context" "net" - "path/filepath" "testing" "time" @@ -34,18 +33,12 @@ func startTestingServices(t *testing.T) string { if err != nil { t.Fatal(err) } - testDir := t.TempDir() - config.Datadir = testDir - err = util.CopyFileContents("../testdata/store.json", filepath.Join(testDir, "store.json")) - if err != nil { - t.Fatal(err) - } _, signalLis := startSignal(t) signalAddr := signalLis.Addr().String() config.Signal.URI = signalAddr - _, mgmLis := startManagement(t, config) + _, mgmLis := startManagement(t, config, "../testdata/store.sqlite") mgmAddr := mgmLis.Addr().String() return mgmAddr } @@ -70,7 +63,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) { return s, lis } -func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Listener) { +func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.Server, net.Listener) { t.Helper() lis, err := net.Listen("tcp", ":0") @@ -78,7 +71,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste t.Fatal(err) } s := grpc.NewServer() - store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir) + store, cleanUp, err := mgmt.NewTestStoreFromSqlite(context.Background(), testFile, t.TempDir()) if err != nil { t.Fatal(err) } diff --git a/client/internal/connect.go b/client/internal/connect.go index c77f95603..74dc1f1b5 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -269,12 +269,6 @@ func (c *ConnectClient) run( checks := loginResp.GetChecks() c.engineMutex.Lock() - if c.engine != nil && c.engine.ctx.Err() != nil { - log.Info("Stopping Netbird Engine") - if err := c.engine.Stop(); err != nil { - log.Errorf("Failed to stop engine: %v", err) - } - } c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks) c.engineMutex.Unlock() @@ -294,6 +288,15 @@ func (c *ConnectClient) run( } <-engineCtx.Done() + c.engineMutex.Lock() + if c.engine != nil && c.engine.wgInterface != nil { + log.Infof("ensuring %s is removed, Netbird engine context cancelled", c.engine.wgInterface.Name()) + if err := c.engine.Stop(); err != nil { + log.Errorf("Failed to stop engine: %v", err) + } + c.engine = nil + } + c.engineMutex.Unlock() c.statusRecorder.ClientTeardown() backOff.Reset() diff --git a/client/internal/engine.go b/client/internal/engine.go index c51901a22..eac8ec098 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -251,6 +251,13 @@ func (e *Engine) Stop() error { } log.Info("Network monitor: stopped") + // stop/restore DNS first so dbus and friends don't complain because of a missing interface + e.stopDNSServer() + + if e.routeManager != nil { + e.routeManager.Stop() + } + err := e.removeAllPeers() if err != nil { return fmt.Errorf("failed to remove all peers: %s", err) @@ -1116,18 +1123,12 @@ func (e *Engine) close() { } } - // stop/restore DNS first so dbus and friends don't complain because of a missing interface - e.stopDNSServer() - - if e.routeManager != nil { - e.routeManager.Stop() - } - log.Debugf("removing Netbird interface %s", e.config.WgIfaceName) if e.wgInterface != nil { if err := e.wgInterface.Close(); err != nil { log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err) } + e.wgInterface = nil } if !isNil(e.sshServer) { @@ -1395,7 +1396,7 @@ func (e *Engine) startNetworkMonitor() { } // Set a new timer to debounce rapid network changes - debounceTimer = time.AfterFunc(1*time.Second, func() { + debounceTimer = time.AfterFunc(2*time.Second, func() { // This function is called after the debounce period mu.Lock() defer mu.Unlock() @@ -1426,6 +1427,11 @@ func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) { } func (e *Engine) stopDNSServer() { + if e.dnsServer == nil { + return + } + e.dnsServer.Stop() + e.dnsServer = nil err := fmt.Errorf("DNS server stopped") nsGroupStates := e.statusRecorder.GetDNSStates() for i := range nsGroupStates { @@ -1433,10 +1439,6 @@ func (e *Engine) stopDNSServer() { nsGroupStates[i].Error = err } e.statusRecorder.UpdateDNSStates(nsGroupStates) - if e.dnsServer != nil { - e.dnsServer.Stop() - e.dnsServer = nil - } } // isChecksEqual checks if two slices of checks are equal. diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 29a8439a2..3d1983c6b 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -6,7 +6,6 @@ import ( "net" "net/netip" "os" - "path/filepath" "runtime" "strings" "sync" @@ -824,20 +823,6 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { func TestEngine_MultiplePeers(t *testing.T) { // log.SetLevel(log.DebugLevel) - dir := t.TempDir() - - err := util.CopyFileContents("../testdata/store.json", filepath.Join(dir, "store.json")) - if err != nil { - t.Fatal(err) - } - defer func() { - err = os.Remove(filepath.Join(dir, "store.json")) //nolint - if err != nil { - t.Fatal(err) - return - } - }() - ctx, cancel := context.WithCancel(CtxInitState(context.Background())) defer cancel() @@ -847,7 +832,7 @@ func TestEngine_MultiplePeers(t *testing.T) { return } defer sigServer.Stop() - mgmtServer, mgmtAddr, err := startManagement(t, dir) + mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sqlite") if err != nil { t.Fatal(err) return @@ -1070,7 +1055,7 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) { return s, lis.Addr().String(), nil } -func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error) { +func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) { t.Helper() config := &server.Config{ @@ -1095,7 +1080,7 @@ func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error) } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, cleanUp, err := server.NewTestStoreFromJson(context.Background(), config.Datadir) + store, cleanUp, err := server.NewTestStoreFromSqlite(context.Background(), testFile, config.Datadir) if err != nil { return nil, "", err } diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index ad84bd700..0d4ad2396 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -32,6 +32,8 @@ const ( connPriorityRelay ConnPriority = 1 connPriorityICETurn ConnPriority = 1 connPriorityICEP2P ConnPriority = 2 + + reconnectMaxElapsedTime = 30 * time.Minute ) type WgConfig struct { @@ -83,6 +85,7 @@ type Conn struct { wgProxyICE wgproxy.Proxy wgProxyRelay wgproxy.Proxy signaler *Signaler + iFaceDiscover stdnet.ExternalIFaceDiscover relayManager *relayClient.Manager allowedIPsIP string handshaker *Handshaker @@ -108,6 +111,8 @@ type Conn struct { // for reconnection operations iCEDisconnected chan bool relayDisconnected chan bool + connMonitor *ConnMonitor + reconnectCh <-chan struct{} } // NewConn creates a new not opened Conn to the remote peer. @@ -123,21 +128,31 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu connLog := log.WithField("peer", config.Key) var conn = &Conn{ - log: connLog, - ctx: ctx, - ctxCancel: ctxCancel, - config: config, - statusRecorder: statusRecorder, - wgProxyFactory: wgProxyFactory, - signaler: signaler, - relayManager: relayManager, - allowedIPsIP: allowedIPsIP.String(), - statusRelay: NewAtomicConnStatus(), - statusICE: NewAtomicConnStatus(), + log: connLog, + ctx: ctx, + ctxCancel: ctxCancel, + config: config, + statusRecorder: statusRecorder, + wgProxyFactory: wgProxyFactory, + signaler: signaler, + iFaceDiscover: iFaceDiscover, + relayManager: relayManager, + allowedIPsIP: allowedIPsIP.String(), + statusRelay: NewAtomicConnStatus(), + statusICE: NewAtomicConnStatus(), + iCEDisconnected: make(chan bool, 1), relayDisconnected: make(chan bool, 1), } + conn.connMonitor, conn.reconnectCh = NewConnMonitor( + signaler, + iFaceDiscover, + config, + conn.relayDisconnected, + conn.iCEDisconnected, + ) + rFns := WorkerRelayCallbacks{ OnConnReady: conn.relayConnectionIsReady, OnDisconnected: conn.onWorkerRelayStateDisconnected, @@ -200,6 +215,8 @@ func (conn *Conn) startHandshakeAndReconnect() { conn.log.Errorf("failed to send initial offer: %v", err) } + go conn.connMonitor.Start(conn.ctx) + if conn.workerRelay.IsController() { conn.reconnectLoopWithRetry() } else { @@ -309,12 +326,14 @@ func (conn *Conn) reconnectLoopWithRetry() { // With it, we can decrease to send necessary offer select { case <-conn.ctx.Done(): + return case <-time.After(3 * time.Second): } ticker := conn.prepareExponentTicker() defer ticker.Stop() time.Sleep(1 * time.Second) + for { select { case t := <-ticker.C: @@ -342,20 +361,11 @@ func (conn *Conn) reconnectLoopWithRetry() { if err != nil { conn.log.Errorf("failed to do handshake: %v", err) } - case changed := <-conn.relayDisconnected: - if !changed { - continue - } - conn.log.Debugf("Relay state changed, reset reconnect timer") - ticker.Stop() - ticker = conn.prepareExponentTicker() - case changed := <-conn.iCEDisconnected: - if !changed { - continue - } - conn.log.Debugf("ICE state changed, reset reconnect timer") + + case <-conn.reconnectCh: ticker.Stop() ticker = conn.prepareExponentTicker() + case <-conn.ctx.Done(): conn.log.Debugf("context is done, stop reconnect loop") return @@ -366,10 +376,10 @@ func (conn *Conn) reconnectLoopWithRetry() { func (conn *Conn) prepareExponentTicker() *backoff.Ticker { bo := backoff.WithContext(&backoff.ExponentialBackOff{ InitialInterval: 800 * time.Millisecond, - RandomizationFactor: 0.01, + RandomizationFactor: 0.1, Multiplier: 2, MaxInterval: conn.config.Timeout, - MaxElapsedTime: 0, + MaxElapsedTime: reconnectMaxElapsedTime, Stop: backoff.Stop, Clock: backoff.SystemClock, }, conn.ctx) diff --git a/client/internal/peer/conn_monitor.go b/client/internal/peer/conn_monitor.go new file mode 100644 index 000000000..75722c990 --- /dev/null +++ b/client/internal/peer/conn_monitor.go @@ -0,0 +1,212 @@ +package peer + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/pion/ice/v3" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/stdnet" +) + +const ( + signalerMonitorPeriod = 5 * time.Second + candidatesMonitorPeriod = 5 * time.Minute + candidateGatheringTimeout = 5 * time.Second +) + +type ConnMonitor struct { + signaler *Signaler + iFaceDiscover stdnet.ExternalIFaceDiscover + config ConnConfig + relayDisconnected chan bool + iCEDisconnected chan bool + reconnectCh chan struct{} + currentCandidates []ice.Candidate + candidatesMu sync.Mutex +} + +func NewConnMonitor(signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, config ConnConfig, relayDisconnected, iCEDisconnected chan bool) (*ConnMonitor, <-chan struct{}) { + reconnectCh := make(chan struct{}, 1) + cm := &ConnMonitor{ + signaler: signaler, + iFaceDiscover: iFaceDiscover, + config: config, + relayDisconnected: relayDisconnected, + iCEDisconnected: iCEDisconnected, + reconnectCh: reconnectCh, + } + return cm, reconnectCh +} + +func (cm *ConnMonitor) Start(ctx context.Context) { + signalerReady := make(chan struct{}, 1) + go cm.monitorSignalerReady(ctx, signalerReady) + + localCandidatesChanged := make(chan struct{}, 1) + go cm.monitorLocalCandidatesChanged(ctx, localCandidatesChanged) + + for { + select { + case changed := <-cm.relayDisconnected: + if !changed { + continue + } + log.Debugf("Relay state changed, triggering reconnect") + cm.triggerReconnect() + + case changed := <-cm.iCEDisconnected: + if !changed { + continue + } + log.Debugf("ICE state changed, triggering reconnect") + cm.triggerReconnect() + + case <-signalerReady: + log.Debugf("Signaler became ready, triggering reconnect") + cm.triggerReconnect() + + case <-localCandidatesChanged: + log.Debugf("Local candidates changed, triggering reconnect") + cm.triggerReconnect() + + case <-ctx.Done(): + return + } + } +} + +func (cm *ConnMonitor) monitorSignalerReady(ctx context.Context, signalerReady chan<- struct{}) { + if cm.signaler == nil { + return + } + + ticker := time.NewTicker(signalerMonitorPeriod) + defer ticker.Stop() + + lastReady := true + for { + select { + case <-ticker.C: + currentReady := cm.signaler.Ready() + if !lastReady && currentReady { + select { + case signalerReady <- struct{}{}: + default: + } + } + lastReady = currentReady + case <-ctx.Done(): + return + } + } +} + +func (cm *ConnMonitor) monitorLocalCandidatesChanged(ctx context.Context, localCandidatesChanged chan<- struct{}) { + ufrag, pwd, err := generateICECredentials() + if err != nil { + log.Warnf("Failed to generate ICE credentials: %v", err) + return + } + + ticker := time.NewTicker(candidatesMonitorPeriod) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if err := cm.handleCandidateTick(ctx, localCandidatesChanged, ufrag, pwd); err != nil { + log.Warnf("Failed to handle candidate tick: %v", err) + } + case <-ctx.Done(): + return + } + } +} + +func (cm *ConnMonitor) handleCandidateTick(ctx context.Context, localCandidatesChanged chan<- struct{}, ufrag string, pwd string) error { + log.Debugf("Gathering ICE candidates") + + transportNet, err := newStdNet(cm.iFaceDiscover, cm.config.ICEConfig.InterfaceBlackList) + if err != nil { + log.Errorf("failed to create pion's stdnet: %s", err) + } + + agent, err := newAgent(cm.config, transportNet, candidateTypesP2P(), ufrag, pwd) + if err != nil { + return fmt.Errorf("create ICE agent: %w", err) + } + defer func() { + if err := agent.Close(); err != nil { + log.Warnf("Failed to close ICE agent: %v", err) + } + }() + + gatherDone := make(chan struct{}) + err = agent.OnCandidate(func(c ice.Candidate) { + log.Tracef("Got candidate: %v", c) + if c == nil { + close(gatherDone) + } + }) + if err != nil { + return fmt.Errorf("set ICE candidate handler: %w", err) + } + + if err := agent.GatherCandidates(); err != nil { + return fmt.Errorf("gather ICE candidates: %w", err) + } + + ctx, cancel := context.WithTimeout(ctx, candidateGatheringTimeout) + defer cancel() + + select { + case <-ctx.Done(): + return fmt.Errorf("wait for gathering: %w", ctx.Err()) + case <-gatherDone: + } + + candidates, err := agent.GetLocalCandidates() + if err != nil { + return fmt.Errorf("get local candidates: %w", err) + } + log.Tracef("Got candidates: %v", candidates) + + if changed := cm.updateCandidates(candidates); changed { + select { + case localCandidatesChanged <- struct{}{}: + default: + } + } + + return nil +} + +func (cm *ConnMonitor) updateCandidates(newCandidates []ice.Candidate) bool { + cm.candidatesMu.Lock() + defer cm.candidatesMu.Unlock() + + if len(cm.currentCandidates) != len(newCandidates) { + cm.currentCandidates = newCandidates + return true + } + + for i, candidate := range cm.currentCandidates { + if candidate.Address() != newCandidates[i].Address() { + cm.currentCandidates = newCandidates + return true + } + } + + return false +} + +func (cm *ConnMonitor) triggerReconnect() { + select { + case cm.reconnectCh <- struct{}{}: + default: + } +} diff --git a/client/internal/peer/stdnet.go b/client/internal/peer/stdnet.go index ae31ebbf0..96d211dbc 100644 --- a/client/internal/peer/stdnet.go +++ b/client/internal/peer/stdnet.go @@ -6,6 +6,6 @@ import ( "github.com/netbirdio/netbird/client/internal/stdnet" ) -func (w *WorkerICE) newStdNet() (*stdnet.Net, error) { - return stdnet.NewNet(w.config.ICEConfig.InterfaceBlackList) +func newStdNet(_ stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) { + return stdnet.NewNet(ifaceBlacklist) } diff --git a/client/internal/peer/stdnet_android.go b/client/internal/peer/stdnet_android.go index b411405bb..a39a03b1c 100644 --- a/client/internal/peer/stdnet_android.go +++ b/client/internal/peer/stdnet_android.go @@ -2,6 +2,6 @@ package peer import "github.com/netbirdio/netbird/client/internal/stdnet" -func (w *WorkerICE) newStdNet() (*stdnet.Net, error) { - return stdnet.NewNetWithDiscover(w.iFaceDiscover, w.config.ICEConfig.InterfaceBlackList) +func newStdNet(iFaceDiscover stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) { + return stdnet.NewNetWithDiscover(iFaceDiscover, ifaceBlacklist) } diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index c4e9d1950..c86c1858f 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -233,41 +233,16 @@ func (w *WorkerICE) Close() { } func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, relaySupport []ice.CandidateType) (*ice.Agent, error) { - transportNet, err := w.newStdNet() + transportNet, err := newStdNet(w.iFaceDiscover, w.config.ICEConfig.InterfaceBlackList) if err != nil { w.log.Errorf("failed to create pion's stdnet: %s", err) } - iceKeepAlive := iceKeepAlive() - iceDisconnectedTimeout := iceDisconnectedTimeout() - iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait() - - agentConfig := &ice.AgentConfig{ - MulticastDNSMode: ice.MulticastDNSModeDisabled, - NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}, - Urls: w.config.ICEConfig.StunTurn.Load().([]*stun.URI), - CandidateTypes: relaySupport, - InterfaceFilter: stdnet.InterfaceFilter(w.config.ICEConfig.InterfaceBlackList), - UDPMux: w.config.ICEConfig.UDPMux, - UDPMuxSrflx: w.config.ICEConfig.UDPMuxSrflx, - NAT1To1IPs: w.config.ICEConfig.NATExternalIPs, - Net: transportNet, - FailedTimeout: &failedTimeout, - DisconnectedTimeout: &iceDisconnectedTimeout, - KeepaliveInterval: &iceKeepAlive, - RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait, - LocalUfrag: w.localUfrag, - LocalPwd: w.localPwd, - } - - if w.config.ICEConfig.DisableIPv6Discovery { - agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4} - } - w.sentExtraSrflx = false - agent, err := ice.NewAgent(agentConfig) + + agent, err := newAgent(w.config, transportNet, relaySupport, w.localUfrag, w.localPwd) if err != nil { - return nil, err + return nil, fmt.Errorf("create agent: %w", err) } err = agent.OnCandidate(w.onICECandidate) @@ -390,6 +365,36 @@ func (w *WorkerICE) turnAgentDial(ctx context.Context, remoteOfferAnswer *OfferA } } +func newAgent(config ConnConfig, transportNet *stdnet.Net, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ice.Agent, error) { + iceKeepAlive := iceKeepAlive() + iceDisconnectedTimeout := iceDisconnectedTimeout() + iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait() + + agentConfig := &ice.AgentConfig{ + MulticastDNSMode: ice.MulticastDNSModeDisabled, + NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}, + Urls: config.ICEConfig.StunTurn.Load().([]*stun.URI), + CandidateTypes: candidateTypes, + InterfaceFilter: stdnet.InterfaceFilter(config.ICEConfig.InterfaceBlackList), + UDPMux: config.ICEConfig.UDPMux, + UDPMuxSrflx: config.ICEConfig.UDPMuxSrflx, + NAT1To1IPs: config.ICEConfig.NATExternalIPs, + Net: transportNet, + FailedTimeout: &failedTimeout, + DisconnectedTimeout: &iceDisconnectedTimeout, + KeepaliveInterval: &iceKeepAlive, + RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait, + LocalUfrag: ufrag, + LocalPwd: pwd, + } + + if config.ICEConfig.DisableIPv6Discovery { + agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4} + } + + return ice.NewAgent(agentConfig) +} + func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) { relatedAdd := candidate.RelatedAddress() return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{ diff --git a/client/server/server_test.go b/client/server/server_test.go index 9b18df4d3..e534ad7e2 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -110,7 +110,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve return nil, "", err } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, cleanUp, err := server.NewTestStoreFromJson(context.Background(), config.Datadir) + store, cleanUp, err := server.NewTestStoreFromSqlite(context.Background(), "", config.Datadir) if err != nil { return nil, "", err } diff --git a/client/testdata/store.json b/client/testdata/store.json deleted file mode 100644 index 8236f2703..000000000 --- a/client/testdata/store.json +++ /dev/null @@ -1,38 +0,0 @@ -{ - "Accounts": { - "bf1c8084-ba50-4ce7-9439-34653001fc3b": { - "Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b", - "SetupKeys": { - "A2C8E62B-38F5-4553-B31E-DD66C696CEBB": { - "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", - "Name": "Default key", - "Type": "reusable", - "CreatedAt": "2021-08-19T20:46:20.005936822+02:00", - "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00", - "Revoked": false, - "UsedTimes": 0 - - } - }, - "Network": { - "Id": "af1c8024-ha40-4ce2-9418-34653101fc3c", - "Net": { - "IP": "100.64.0.0", - "Mask": "//8AAA==" - }, - "Dns": null - }, - "Peers": {}, - "Users": { - "edafee4e-63fb-11ec-90d6-0242ac120003": { - "Id": "edafee4e-63fb-11ec-90d6-0242ac120003", - "Role": "admin" - }, - "f4f6d672-63fb-11ec-90d6-0242ac120003": { - "Id": "f4f6d672-63fb-11ec-90d6-0242ac120003", - "Role": "user" - } - } - } - } -} \ No newline at end of file diff --git a/client/testdata/store.sqlite b/client/testdata/store.sqlite new file mode 100644 index 000000000..118c2bebc Binary files /dev/null and b/client/testdata/store.sqlite differ diff --git a/management/client/client_test.go b/management/client/client_test.go index a082e354b..313a67617 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -47,25 +47,18 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { level, _ := log.ParseLevel("debug") log.SetLevel(level) - testDir := t.TempDir() - config := &mgmt.Config{} _, err := util.ReadJson("../server/testdata/management.json", config) if err != nil { t.Fatal(err) } - config.Datadir = testDir - err = util.CopyFileContents("../server/testdata/store.json", filepath.Join(testDir, "store.json")) - if err != nil { - t.Fatal(err) - } lis, err := net.Listen("tcp", ":0") if err != nil { t.Fatal(err) } s := grpc.NewServer() - store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir) + store, cleanUp, err := NewSqliteTestStore(t, context.Background(), "../server/testdata/store.sqlite") if err != nil { t.Fatal(err) } @@ -521,3 +514,22 @@ func Test_GetPKCEAuthorizationFlow(t *testing.T) { assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientID, flowInfo.ProviderConfig.ClientID, "provider configured client ID should match") assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientSecret, flowInfo.ProviderConfig.ClientSecret, "provider configured client secret should match") } + +func NewSqliteTestStore(t *testing.T, ctx context.Context, testFile string) (mgmt.Store, func(), error) { + t.Helper() + dataDir := t.TempDir() + err := util.CopyFileContents(testFile, filepath.Join(dataDir, "store.db")) + if err != nil { + t.Fatal(err) + } + + store, err := mgmt.NewSqliteStore(ctx, dataDir, nil) + if err != nil { + return nil, nil, err + } + + return store, func() { + store.Close(ctx) + os.Remove(filepath.Join(dataDir, "store.db")) + }, nil +} diff --git a/management/cmd/management.go b/management/cmd/management.go index 78b1a8d63..719d1a78c 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -475,7 +475,7 @@ func handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler) http.Handle func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config, error) { loadedConfig := &server.Config{} - _, err := util.ReadJson(mgmtConfigPath, loadedConfig) + _, err := util.ReadJsonWithEnvSub(mgmtConfigPath, loadedConfig) if err != nil { return nil, err } diff --git a/management/server/account.go b/management/server/account.go index 1463ae033..6ee0015f8 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -20,6 +20,11 @@ import ( cacheStore "github.com/eko/gocache/v3/store" "github.com/hashicorp/go-multierror" "github.com/miekg/dns" + gocache "github.com/patrickmn/go-cache" + "github.com/rs/xid" + log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" + "github.com/netbirdio/netbird/base62" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/domain" @@ -36,10 +41,6 @@ import ( "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/route" - gocache "github.com/patrickmn/go-cache" - "github.com/rs/xid" - log "github.com/sirupsen/logrus" - "golang.org/x/exp/maps" ) const ( @@ -76,7 +77,8 @@ type AccountManager interface { SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error) - GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) + AccountExists(ctx context.Context, accountID string) (bool, error) + GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error) @@ -96,6 +98,7 @@ type AccountManager interface { DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) + UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error) GetGroup(ctx context.Context, accountId, groupID, userID string) (*nbgroup.Group, error) GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) @@ -842,55 +845,54 @@ func (a *Account) GetPeer(peerID string) *nbpeer.Peer { return a.Peers[peerID] } -// SetJWTGroups updates the user's auto groups by synchronizing JWT groups. -// Returns true if there are changes in the JWT group membership. -func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool { - user, ok := a.Users[userID] - if !ok { - return false - } - +// getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups. +// Returns a bool indicating if there are changes in the JWT group membership, the updated user AutoGroups, +// newly groups to create and an error if any occurred. +func (am *DefaultAccountManager) getJWTGroupsChanges(user *User, groups []*nbgroup.Group, groupNames []string) (bool, []string, []*nbgroup.Group, error) { existedGroupsByName := make(map[string]*nbgroup.Group) - for _, group := range a.Groups { + for _, group := range groups { existedGroupsByName[group.Name] = group } - newAutoGroups, jwtGroupsMap := separateGroups(user.AutoGroups, a.Groups) - groupsToAdd := difference(groupsNames, maps.Keys(jwtGroupsMap)) - groupsToRemove := difference(maps.Keys(jwtGroupsMap), groupsNames) + newUserAutoGroups, jwtGroupsMap := separateGroups(user.AutoGroups, groups) + + groupsToAdd := difference(groupNames, maps.Keys(jwtGroupsMap)) + groupsToRemove := difference(maps.Keys(jwtGroupsMap), groupNames) // If no groups are added or removed, we should not sync account if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 { - return false + return false, nil, nil, nil } + newGroupsToCreate := make([]*nbgroup.Group, 0) + var modified bool for _, name := range groupsToAdd { group, exists := existedGroupsByName[name] if !exists { group = &nbgroup.Group{ - ID: xid.New().String(), - Name: name, - Issued: nbgroup.GroupIssuedJWT, + ID: xid.New().String(), + AccountID: user.AccountID, + Name: name, + Issued: nbgroup.GroupIssuedJWT, } - a.Groups[group.ID] = group + newGroupsToCreate = append(newGroupsToCreate, group) } if group.Issued == nbgroup.GroupIssuedJWT { - newAutoGroups = append(newAutoGroups, group.ID) + newUserAutoGroups = append(newUserAutoGroups, group.ID) modified = true } } for name, id := range jwtGroupsMap { if !slices.Contains(groupsToRemove, name) { - newAutoGroups = append(newAutoGroups, id) + newUserAutoGroups = append(newUserAutoGroups, id) continue } modified = true } - user.AutoGroups = newAutoGroups - return modified + return modified, newUserAutoGroups, newGroupsToCreate, nil } // UserGroupsAddToPeers adds groups to all peers of user @@ -1261,37 +1263,36 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u return nil } -// GetAccountIDByUserOrAccountID retrieves the account ID based on either the userID or accountID provided. -// If an accountID is provided, it checks if the account exists and returns it. -// If no accountID is provided, but a userID is given, it tries to retrieve the account by userID. +// AccountExists checks if an account exists. +func (am *DefaultAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) { + return am.Store.AccountExists(ctx, LockingStrengthShare, accountID) +} + +// GetAccountIDByUserID retrieves the account ID based on the userID provided. +// If user does have an account, it returns the user's account ID. // If the user doesn't have an account, it creates one using the provided domain. // Returns the account ID or an error if none is found or created. -func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) { - if accountID != "" { - exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, accountID) - if err != nil { - return "", err - } - if !exists { - return "", status.Errorf(status.NotFound, "account %s does not exist", accountID) - } - return accountID, nil +func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) { + if userID == "" { + return "", status.Errorf(status.NotFound, "no valid userID provided") } - if userID != "" { - account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) - if err != nil { - return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID) - } + accountID, err := am.Store.GetAccountIDByUserID(userID) + if err != nil { + if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { + account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) + if err != nil { + return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID) + } - if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil { - return "", err + if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil { + return "", err + } + return account.Id, nil } - - return account.Id, nil + return "", err } - - return "", status.Errorf(status.NotFound, "no valid userID or accountID provided") + return accountID, nil } func isNil(i idp.Manager) bool { @@ -1764,7 +1765,7 @@ func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID s return nil, err } - if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) { + if user.AccountID != accountID { return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data") } @@ -1795,6 +1796,10 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId) } + if user.AccountID != accountID { + return "", "", status.Errorf(status.PermissionDenied, "user %s is not part of the account %s", claims.UserId, accountID) + } + if !user.IsServiceUser && claims.Invited { err = am.redeemInvite(ctx, accountID, user.Id) if err != nil { @@ -1802,7 +1807,7 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai } } - if err = am.syncJWTGroups(ctx, accountID, user, claims); err != nil { + if err = am.syncJWTGroups(ctx, accountID, claims); err != nil { return "", "", err } @@ -1811,7 +1816,7 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai // syncJWTGroups processes the JWT groups for a user, updates the account based on the groups, // and propagates changes to peers if group propagation is enabled. -func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, user *User, claims jwtclaims.AuthorizationClaims) error { +func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims) error { settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) if err != nil { return err @@ -1822,69 +1827,134 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } if settings.JWTGroupsClaimName == "" { - log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set") + log.WithContext(ctx).Debugf("JWT groups are enabled but no claim name is set") return nil } - // TODO: Remove GetAccount after refactoring account peer's update - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return err - } - jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims) - oldGroups := make([]string, len(user.AutoGroups)) - copy(oldGroups, user.AutoGroups) + unlockPeer := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer func() { + if unlockPeer != nil { + unlockPeer() + } + }() - // Update the account if group membership changes - if account.SetJWTGroups(claims.UserId, jwtGroupsNames) { - addNewGroups := difference(user.AutoGroups, oldGroups) - removeOldGroups := difference(oldGroups, user.AutoGroups) - - if settings.GroupsPropagationEnabled { - account.UserGroupsAddToPeers(claims.UserId, addNewGroups...) - account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...) - account.Network.IncSerial() + var addNewGroups []string + var removeOldGroups []string + var hasChanges bool + var user *User + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + user, err = am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId) + if err != nil { + return fmt.Errorf("error getting user: %w", err) } - if err := am.Store.SaveAccount(ctx, account); err != nil { - log.WithContext(ctx).Errorf("failed to save account: %v", err) + groups, err := am.Store.GetAccountGroups(ctx, accountID) + if err != nil { + return fmt.Errorf("error getting account groups: %w", err) + } + + changed, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(user, groups, jwtGroupsNames) + if err != nil { + return fmt.Errorf("error getting JWT groups changes: %w", err) + } + + hasChanges = changed + // skip update if no changes + if !changed { return nil } + if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil { + return fmt.Errorf("error saving groups: %w", err) + } + + addNewGroups = difference(updatedAutoGroups, user.AutoGroups) + removeOldGroups = difference(user.AutoGroups, updatedAutoGroups) + + user.AutoGroups = updatedAutoGroups + if err = transaction.SaveUser(ctx, LockingStrengthUpdate, user); err != nil { + return fmt.Errorf("error saving user: %w", err) + } + // Propagate changes to peers if group propagation is enabled if settings.GroupsPropagationEnabled { - log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) - if areGroupChangesAffectPeers(account, addNewGroups) || areGroupChangesAffectPeers(account, removeOldGroups) { - am.updateAccountPeers(ctx, account) + groups, err = transaction.GetAccountGroups(ctx, accountID) + if err != nil { + return fmt.Errorf("error getting account groups: %w", err) } + + groupsMap := make(map[string]*nbgroup.Group, len(groups)) + for _, group := range groups { + groupsMap[group.ID] = group + } + + peers, err := transaction.GetUserPeers(ctx, LockingStrengthShare, accountID, claims.UserId) + if err != nil { + return fmt.Errorf("error getting user peers: %w", err) + } + + updatedGroups, err := am.updateUserPeersInGroups(groupsMap, peers, addNewGroups, removeOldGroups) + if err != nil { + return fmt.Errorf("error modifying user peers in groups: %w", err) + } + + if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, updatedGroups); err != nil { + return fmt.Errorf("error saving groups: %w", err) + } + + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { + return fmt.Errorf("error incrementing network serial: %w", err) + } + } + unlockPeer() + unlockPeer = nil + + return nil + }) + if err != nil { + return err + } + + if !hasChanges { + return nil + } + + for _, g := range addNewGroups { + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID) + if err != nil { + log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) + } else { + meta := map[string]any{ + "group": group.Name, "group_id": group.ID, + "is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName, + } + am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupAddedToUser, meta) + } + } + + for _, g := range removeOldGroups { + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID) + if err != nil { + log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) + } else { + meta := map[string]any{ + "group": group.Name, "group_id": group.ID, + "is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName, + } + am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupRemovedFromUser, meta) + } + } + + if settings.GroupsPropagationEnabled { + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf("error getting account: %w", err) } - for _, g := range addNewGroups { - if group := account.GetGroup(g); group != nil { - am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser, - map[string]any{ - "group": group.Name, - "group_id": group.ID, - "is_service_user": user.IsServiceUser, - "user_name": user.ServiceUserName}) - } - } - - for _, g := range removeOldGroups { - if group := account.GetGroup(g); group != nil { - am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser, - map[string]any{ - "group": group.Name, - "group_id": group.ID, - "is_service_user": user.IsServiceUser, - "user_name": user.ServiceUserName}) - } - } + log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) + am.updateAccountPeers(ctx, account) } return nil @@ -1917,7 +1987,17 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context // if Account ID is part of the claims // it means that we've already classified the domain and user has an account if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) { - return am.GetAccountIDByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain) + if claims.AccountId != "" { + exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, claims.AccountId) + if err != nil { + return "", err + } + if !exists { + return "", status.Errorf(status.NotFound, "account %s does not exist", claims.AccountId) + } + return claims.AccountId, nil + } + return am.GetAccountIDByUserID(ctx, claims.UserId, claims.Domain) } else if claims.AccountId != "" { userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) if err != nil { @@ -2230,7 +2310,11 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac routes := make(map[route.ID]*route.Route) setupKeys := map[string]*SetupKey{} nameServersGroups := make(map[string]*nbdns.NameServerGroup) - users[userID] = NewOwnerUser(userID) + + owner := NewOwnerUser(userID) + owner.AccountID = accountID + users[userID] = owner + dnsSettings := DNSSettings{ DisabledManagementGroups: make([]string, 0), } @@ -2298,12 +2382,17 @@ func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool { // separateGroups separates user's auto groups into non-JWT and JWT groups. // Returns the list of standard auto groups and a map of JWT auto groups, // where the keys are the group names and the values are the group IDs. -func separateGroups(autoGroups []string, allGroups map[string]*nbgroup.Group) ([]string, map[string]string) { +func separateGroups(autoGroups []string, allGroups []*nbgroup.Group) ([]string, map[string]string) { newAutoGroups := make([]string, 0) jwtAutoGroups := make(map[string]string) // map of group name to group ID + allGroupsMap := make(map[string]*nbgroup.Group, len(allGroups)) + for _, group := range allGroups { + allGroupsMap[group.ID] = group + } + for _, id := range autoGroups { - if group, ok := allGroups[id]; ok { + if group, ok := allGroupsMap[id]; ok { if group.Issued == nbgroup.GroupIssuedJWT { jwtAutoGroups[group.Name] = id } else { @@ -2311,5 +2400,6 @@ func separateGroups(autoGroups []string, allGroups map[string]*nbgroup.Group) ([ } } } + return newAutoGroups, jwtAutoGroups } diff --git a/management/server/account_test.go b/management/server/account_test.go index e0fc94c88..9877a5510 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -633,7 +633,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain) + accountID, err := manager.GetAccountIDByUserID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.Domain) require.NoError(t, err, "create init user failed") initAccount, err := manager.Store.GetAccount(context.Background(), accountID) @@ -671,17 +671,16 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { userId := "user-id" domain := "test.domain" - initAccount := newAccountWithId(context.Background(), "", userId, domain) + _ = newAccountWithId(context.Background(), "", userId, domain) manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - accountID := initAccount.Id - accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userId, accountID, domain) + accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain) require.NoError(t, err, "create init user failed") // as initAccount was created without account id we have to take the id after account initialization - // that happens inside the GetAccountIDByUserOrAccountID where the id is getting generated + // that happens inside the GetAccountIDByUserID where the id is getting generated // it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it - initAccount, err = manager.Store.GetAccount(context.Background(), accountID) + initAccount, err := manager.Store.GetAccount(context.Background(), accountID) require.NoError(t, err, "get init account failed") claims := jwtclaims.AuthorizationClaims{ @@ -885,7 +884,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) { } } -func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) { +func TestAccountManager_GetAccountByUserID(t *testing.T) { manager, err := createManager(t) if err != nil { t.Fatal(err) @@ -894,7 +893,7 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) { userId := "test_user" - accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userId, "", "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, "") if err != nil { t.Fatal(err) } @@ -903,14 +902,13 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) { return } - _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "") - if err != nil { - t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountID) - } + exists, err := manager.Store.AccountExists(context.Background(), LockingStrengthShare, accountID) + assert.NoError(t, err) + assert.True(t, exists, "expected to get existing account after creation using userid") - _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", "", "") + _, err = manager.GetAccountIDByUserID(context.Background(), "", "") if err == nil { - t.Errorf("expected an error when user and account IDs are empty") + t.Errorf("expected an error when user ID is empty") } } @@ -1731,7 +1729,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) @@ -1746,7 +1744,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + _, err = manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") key, err := wgtypes.GenerateKey() @@ -1758,7 +1756,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { }) require.NoError(t, err, "unable to add peer") - accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to get the account") account, err := manager.Store.GetAccount(context.Background(), accountID) @@ -1804,7 +1802,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") key, err := wgtypes.GenerateKey() @@ -1832,7 +1830,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. }, } - accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + accountID, err = manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to get the account") account, err := manager.Store.GetAccount(context.Background(), accountID) @@ -1852,7 +1850,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + _, err = manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") key, err := wgtypes.GenerateKey() @@ -1864,7 +1862,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test }) require.NoError(t, err, "unable to add peer") - accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to get the account") account, err := manager.Store.GetAccount(context.Background(), accountID) @@ -1912,7 +1910,7 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ @@ -1923,9 +1921,6 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { assert.False(t, updated.Settings.PeerLoginExpirationEnabled) assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour) - accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "") - require.NoError(t, err, "unable to get account by ID") - settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) require.NoError(t, err, "unable to get account settings") @@ -2261,8 +2256,12 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) { } func TestAccount_SetJWTGroups(t *testing.T) { + manager, err := createManager(t) + require.NoError(t, err, "unable to create account manager") + // create a new account account := &Account{ + Id: "accountID", Peers: map[string]*nbpeer.Peer{ "peer1": {ID: "peer1", Key: "key1", UserID: "user1"}, "peer2": {ID: "peer2", Key: "key2", UserID: "user1"}, @@ -2273,62 +2272,120 @@ func TestAccount_SetJWTGroups(t *testing.T) { Groups: map[string]*group.Group{ "group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{}}, }, - Settings: &Settings{GroupsPropagationEnabled: true}, + Settings: &Settings{GroupsPropagationEnabled: true, JWTGroupsEnabled: true, JWTGroupsClaimName: "groups"}, Users: map[string]*User{ - "user1": {Id: "user1"}, - "user2": {Id: "user2"}, + "user1": {Id: "user1", AccountID: "accountID"}, + "user2": {Id: "user2", AccountID: "accountID"}, }, } + assert.NoError(t, manager.Store.SaveAccount(context.Background(), account), "unable to save account") + t.Run("empty jwt groups", func(t *testing.T) { - updated := account.SetJWTGroups("user1", []string{}) - assert.False(t, updated, "account should not be updated") - assert.Empty(t, account.Users["user1"].AutoGroups, "auto groups must be empty") + claims := jwtclaims.AuthorizationClaims{ + UserId: "user1", + Raw: jwt.MapClaims{"groups": []interface{}{}}, + } + err := manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + assert.NoError(t, err, "unable to get user") + assert.Empty(t, user.AutoGroups, "auto groups must be empty") }) t.Run("jwt match existing api group", func(t *testing.T) { - updated := account.SetJWTGroups("user1", []string{"group1"}) - assert.False(t, updated, "account should not be updated") - assert.Equal(t, 0, len(account.Users["user1"].AutoGroups)) - assert.Equal(t, account.Groups["group1"].Issued, group.GroupIssuedAPI, "group should be api issued") + claims := jwtclaims.AuthorizationClaims{ + UserId: "user1", + Raw: jwt.MapClaims{"groups": []interface{}{"group1"}}, + } + err := manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + assert.NoError(t, err, "unable to get user") + assert.Len(t, user.AutoGroups, 0) + + group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID") + assert.NoError(t, err, "unable to get group") + assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") }) t.Run("jwt match existing api group in user auto groups", func(t *testing.T) { account.Users["user1"].AutoGroups = []string{"group1"} + assert.NoError(t, manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, account.Users["user1"])) - updated := account.SetJWTGroups("user1", []string{"group1"}) - assert.False(t, updated, "account should not be updated") - assert.Equal(t, 1, len(account.Users["user1"].AutoGroups)) - assert.Equal(t, account.Groups["group1"].Issued, group.GroupIssuedAPI, "group should be api issued") + claims := jwtclaims.AuthorizationClaims{ + UserId: "user1", + Raw: jwt.MapClaims{"groups": []interface{}{"group1"}}, + } + err = manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + assert.NoError(t, err, "unable to get user") + assert.Len(t, user.AutoGroups, 1) + + group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID") + assert.NoError(t, err, "unable to get group") + assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") }) t.Run("add jwt group", func(t *testing.T) { - updated := account.SetJWTGroups("user1", []string{"group1", "group2"}) - assert.True(t, updated, "account should be updated") - assert.Len(t, account.Groups, 2, "new group should be added") - assert.Len(t, account.Users["user1"].AutoGroups, 2, "new group should be added") - assert.Contains(t, account.Groups, account.Users["user1"].AutoGroups[0], "groups must contain group2 from user groups") + claims := jwtclaims.AuthorizationClaims{ + UserId: "user1", + Raw: jwt.MapClaims{"groups": []interface{}{"group1", "group2"}}, + } + err = manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + assert.NoError(t, err, "unable to get user") + assert.Len(t, user.AutoGroups, 2, "groups count should not be change") }) t.Run("existed group not update", func(t *testing.T) { - updated := account.SetJWTGroups("user1", []string{"group2"}) - assert.False(t, updated, "account should not be updated") - assert.Len(t, account.Groups, 2, "groups count should not be changed") + claims := jwtclaims.AuthorizationClaims{ + UserId: "user1", + Raw: jwt.MapClaims{"groups": []interface{}{"group2"}}, + } + err = manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + assert.NoError(t, err, "unable to get user") + assert.Len(t, user.AutoGroups, 2, "groups count should not be change") }) t.Run("add new group", func(t *testing.T) { - updated := account.SetJWTGroups("user2", []string{"group1", "group3"}) - assert.True(t, updated, "account should be updated") - assert.Len(t, account.Groups, 3, "new group should be added") - assert.Len(t, account.Users["user2"].AutoGroups, 1, "new group should be added") - assert.Contains(t, account.Groups, account.Users["user2"].AutoGroups[0], "groups must contain group3 from user groups") + claims := jwtclaims.AuthorizationClaims{ + UserId: "user2", + Raw: jwt.MapClaims{"groups": []interface{}{"group1", "group3"}}, + } + err = manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + groups, err := manager.Store.GetAccountGroups(context.Background(), "accountID") + assert.NoError(t, err) + assert.Len(t, groups, 3, "new group3 should be added") + + user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user2") + assert.NoError(t, err, "unable to get user") + assert.Len(t, user.AutoGroups, 1, "new group should be added") }) t.Run("remove all JWT groups", func(t *testing.T) { - updated := account.SetJWTGroups("user1", []string{}) - assert.True(t, updated, "account should be updated") - assert.Len(t, account.Users["user1"].AutoGroups, 1, "only non-JWT groups should remain") - assert.Contains(t, account.Users["user1"].AutoGroups, "group1", " group1 should still be present") + claims := jwtclaims.AuthorizationClaims{ + UserId: "user1", + Raw: jwt.MapClaims{"groups": []interface{}{}}, + } + err = manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + assert.NoError(t, err, "unable to get user") + assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain") + assert.Contains(t, user.AutoGroups, "group1", " group1 should still be present") }) } @@ -2428,7 +2485,7 @@ func createManager(t TB) (*DefaultAccountManager, error) { func createStore(t TB) (Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir) + store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir) if err != nil { return nil, err } diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 53ab1eaaf..631a19785 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -212,7 +212,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { func createDNSStore(t *testing.T) (Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir) + store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir) if err != nil { return nil, err } diff --git a/management/server/file_store.go b/management/server/file_store.go index 994a4b1ee..df3e9bb77 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -2,24 +2,18 @@ package server import ( "context" - "errors" - "net" "os" "path/filepath" "strings" "sync" "time" - "github.com/netbirdio/netbird/dns" - nbgroup "github.com/netbirdio/netbird/management/server/group" - nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/posture" - "github.com/netbirdio/netbird/management/server/status" - "github.com/netbirdio/netbird/management/server/telemetry" - "github.com/netbirdio/netbird/route" "github.com/rs/xid" log "github.com/sirupsen/logrus" + nbgroup "github.com/netbirdio/netbird/management/server/group" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/util" ) @@ -42,167 +36,9 @@ type FileStore struct { mux sync.Mutex `json:"-"` storeFile string `json:"-"` - // sync.Mutex indexed by resource ID - resourceLocks sync.Map `json:"-"` - globalAccountLock sync.Mutex `json:"-"` - metrics telemetry.AppMetrics `json:"-"` } -func (s *FileStore) ExecuteInTransaction(ctx context.Context, f func(store Store) error) error { - return f(s) -} - -func (s *FileStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKeyID)] - if !ok { - return status.NewSetupKeyNotFoundError() - } - - account, err := s.getAccount(accountID) - if err != nil { - return err - } - - account.SetupKeys[setupKeyID].UsedTimes++ - - return s.SaveAccount(ctx, account) -} - -func (s *FileStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return err - } - - allGroup, err := account.GetGroupAll() - if err != nil || allGroup == nil { - return errors.New("all group not found") - } - - allGroup.Peers = append(allGroup.Peers, peerID) - - return nil -} - -func (s *FileStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountId) - if err != nil { - return err - } - - account.Groups[groupID].Peers = append(account.Groups[groupID].Peers, peerId) - - return nil -} - -func (s *FileStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { - s.mux.Lock() - defer s.mux.Unlock() - - account, ok := s.Accounts[peer.AccountID] - if !ok { - return status.NewAccountNotFoundError(peer.AccountID) - } - - account.Peers[peer.ID] = peer - return s.SaveAccount(ctx, account) -} - -func (s *FileStore) IncrementNetworkSerial(ctx context.Context, accountId string) error { - s.mux.Lock() - defer s.mux.Unlock() - - account, ok := s.Accounts[accountId] - if !ok { - return status.NewAccountNotFoundError(accountId) - } - - account.Network.Serial++ - - return s.SaveAccount(ctx, account) -} - -func (s *FileStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(key)] - if !ok { - return nil, status.NewSetupKeyNotFoundError() - } - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - setupKey, ok := account.SetupKeys[key] - if !ok { - return nil, status.Errorf(status.NotFound, "setup key not found") - } - - return setupKey, nil -} - -func (s *FileStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - var takenIps []net.IP - for _, existingPeer := range account.Peers { - takenIps = append(takenIps, existingPeer.IP) - } - - return takenIps, nil -} - -func (s *FileStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - existingLabels := []string{} - for _, peer := range account.Peers { - if peer.DNSLabel != "" { - existingLabels = append(existingLabels, peer.DNSLabel) - } - } - return existingLabels, nil -} - -func (s *FileStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - return account.Network, nil -} - -type StoredAccount struct{} - // NewFileStore restores a store from the file located in the datadir func NewFileStore(ctx context.Context, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) { fs, err := restore(ctx, filepath.Join(dataDir, storeFileName)) @@ -213,25 +49,6 @@ func NewFileStore(ctx context.Context, dataDir string, metrics telemetry.AppMetr return fs, nil } -// NewFilestoreFromSqliteStore restores a store from Sqlite and stores to Filestore json in the file located in datadir -func NewFilestoreFromSqliteStore(ctx context.Context, sqlStore *SqlStore, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) { - store, err := NewFileStore(ctx, dataDir, metrics) - if err != nil { - return nil, err - } - - err = store.SaveInstallationID(ctx, sqlStore.GetInstallationID()) - if err != nil { - return nil, err - } - - for _, account := range sqlStore.GetAllAccounts(ctx) { - store.Accounts[account.Id] = account - } - - return store, store.persist(ctx, store.storeFile) -} - // restore the state of the store from the file. // Creates a new empty store file if doesn't exist func restore(ctx context.Context, file string) (*FileStore, error) { @@ -240,7 +57,6 @@ func restore(ctx context.Context, file string) (*FileStore, error) { s := &FileStore{ Accounts: make(map[string]*Account), mux: sync.Mutex{}, - globalAccountLock: sync.Mutex{}, SetupKeyID2AccountID: make(map[string]string), PeerKeyID2AccountID: make(map[string]string), UserID2AccountID: make(map[string]string), @@ -416,252 +232,6 @@ func (s *FileStore) persist(ctx context.Context, file string) error { return nil } -// AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock -func (s *FileStore) AcquireGlobalLock(ctx context.Context) (unlock func()) { - log.WithContext(ctx).Debugf("acquiring global lock") - start := time.Now() - s.globalAccountLock.Lock() - - unlock = func() { - s.globalAccountLock.Unlock() - log.WithContext(ctx).Debugf("released global lock in %v", time.Since(start)) - } - - took := time.Since(start) - log.WithContext(ctx).Debugf("took %v to acquire global lock", took) - if s.metrics != nil { - s.metrics.StoreMetrics().CountGlobalLockAcquisitionDuration(took) - } - - return unlock -} - -// AcquireWriteLockByUID acquires an ID lock for writing to a resource and returns a function that releases the lock -func (s *FileStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (unlock func()) { - log.WithContext(ctx).Debugf("acquiring lock for ID %s", uniqueID) - start := time.Now() - value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.Mutex{}) - mtx := value.(*sync.Mutex) - mtx.Lock() - - unlock = func() { - mtx.Unlock() - log.WithContext(ctx).Debugf("released lock for ID %s in %v", uniqueID, time.Since(start)) - } - - return unlock -} - -// AcquireReadLockByUID acquires an ID lock for reading a resource and returns a function that releases the lock -// This method is still returns a write lock as file store can't handle read locks -func (s *FileStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (unlock func()) { - return s.AcquireWriteLockByUID(ctx, uniqueID) -} - -func (s *FileStore) SaveAccount(ctx context.Context, account *Account) error { - s.mux.Lock() - defer s.mux.Unlock() - - if account.Id == "" { - return status.Errorf(status.InvalidArgument, "account id should not be empty") - } - - accountCopy := account.Copy() - - s.Accounts[accountCopy.Id] = accountCopy - - // todo check that account.Id and keyId are not exist already - // because if keyId exists for other accounts this can be bad - for keyID := range accountCopy.SetupKeys { - s.SetupKeyID2AccountID[strings.ToUpper(keyID)] = accountCopy.Id - } - - // enforce peer to account index and delete peer to route indexes for rebuild - for _, peer := range accountCopy.Peers { - s.PeerKeyID2AccountID[peer.Key] = accountCopy.Id - s.PeerID2AccountID[peer.ID] = accountCopy.Id - } - - for _, user := range accountCopy.Users { - s.UserID2AccountID[user.Id] = accountCopy.Id - for _, pat := range user.PATs { - s.TokenID2UserID[pat.ID] = user.Id - s.HashedPAT2TokenID[pat.HashedToken] = pat.ID - } - } - - if accountCopy.DomainCategory == PrivateCategory && accountCopy.IsDomainPrimaryAccount { - s.PrivateDomain2AccountID[accountCopy.Domain] = accountCopy.Id - } - - return s.persist(ctx, s.storeFile) -} - -func (s *FileStore) DeleteAccount(ctx context.Context, account *Account) error { - s.mux.Lock() - defer s.mux.Unlock() - - if account.Id == "" { - return status.Errorf(status.InvalidArgument, "account id should not be empty") - } - - for keyID := range account.SetupKeys { - delete(s.SetupKeyID2AccountID, strings.ToUpper(keyID)) - } - - // enforce peer to account index and delete peer to route indexes for rebuild - for _, peer := range account.Peers { - delete(s.PeerKeyID2AccountID, peer.Key) - delete(s.PeerID2AccountID, peer.ID) - } - - for _, user := range account.Users { - for _, pat := range user.PATs { - delete(s.TokenID2UserID, pat.ID) - delete(s.HashedPAT2TokenID, pat.HashedToken) - } - delete(s.UserID2AccountID, user.Id) - } - - if account.DomainCategory == PrivateCategory && account.IsDomainPrimaryAccount { - delete(s.PrivateDomain2AccountID, account.Domain) - } - - delete(s.Accounts, account.Id) - - return s.persist(ctx, s.storeFile) -} - -// DeleteHashedPAT2TokenIDIndex removes an entry from the indexing map HashedPAT2TokenID -func (s *FileStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error { - s.mux.Lock() - defer s.mux.Unlock() - - delete(s.HashedPAT2TokenID, hashedToken) - - return nil -} - -// DeleteTokenID2UserIDIndex removes an entry from the indexing map TokenID2UserID -func (s *FileStore) DeleteTokenID2UserIDIndex(tokenID string) error { - s.mux.Lock() - defer s.mux.Unlock() - - delete(s.TokenID2UserID, tokenID) - - return nil -} - -// GetAccountByPrivateDomain returns account by private domain -func (s *FileStore) GetAccountByPrivateDomain(_ context.Context, domain string) (*Account, error) { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.PrivateDomain2AccountID[strings.ToLower(domain)] - if !ok { - return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private") - } - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - return account.Copy(), nil -} - -// GetAccountBySetupKey returns account by setup key id -func (s *FileStore) GetAccountBySetupKey(_ context.Context, setupKey string) (*Account, error) { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)] - if !ok { - return nil, status.NewSetupKeyNotFoundError() - } - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - return account.Copy(), nil -} - -// GetTokenIDByHashedToken returns the id of a personal access token by its hashed secret -func (s *FileStore) GetTokenIDByHashedToken(_ context.Context, token string) (string, error) { - s.mux.Lock() - defer s.mux.Unlock() - - tokenID, ok := s.HashedPAT2TokenID[token] - if !ok { - return "", status.Errorf(status.NotFound, "tokenID not found: provided token doesn't exists") - } - - return tokenID, nil -} - -// GetUserByTokenID returns a User object a tokenID belongs to -func (s *FileStore) GetUserByTokenID(_ context.Context, tokenID string) (*User, error) { - s.mux.Lock() - defer s.mux.Unlock() - - userID, ok := s.TokenID2UserID[tokenID] - if !ok { - return nil, status.Errorf(status.NotFound, "user not found: provided tokenID doesn't exists") - } - - accountID, ok := s.UserID2AccountID[userID] - if !ok { - return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists") - } - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - return account.Users[userID].Copy(), nil -} - -func (s *FileStore) GetUserByUserID(_ context.Context, _ LockingStrength, userID string) (*User, error) { - accountID, ok := s.UserID2AccountID[userID] - if !ok { - return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists") - } - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - user := account.Users[userID].Copy() - pat := make([]PersonalAccessToken, 0, len(user.PATs)) - for _, token := range user.PATs { - if token != nil { - pat = append(pat, *token) - } - } - user.PATsG = pat - - return user, nil -} - -func (s *FileStore) GetAccountGroups(_ context.Context, accountID string) ([]*nbgroup.Group, error) { - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - groupsSlice := make([]*nbgroup.Group, 0, len(account.Groups)) - - for _, group := range account.Groups { - groupsSlice = append(groupsSlice, group) - } - - return groupsSlice, nil -} - // GetAllAccounts returns all accounts func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) { s.mux.Lock() @@ -673,278 +243,6 @@ func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) { return all } -// getAccount returns a reference to the Account. Should not return a copy. -func (s *FileStore) getAccount(accountID string) (*Account, error) { - account, ok := s.Accounts[accountID] - if !ok { - return nil, status.NewAccountNotFoundError(accountID) - } - - return account, nil -} - -// GetAccount returns an account for ID -func (s *FileStore) GetAccount(_ context.Context, accountID string) (*Account, error) { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - return account.Copy(), nil -} - -// GetAccountByUser returns a user account -func (s *FileStore) GetAccountByUser(_ context.Context, userID string) (*Account, error) { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.UserID2AccountID[userID] - if !ok { - return nil, status.NewUserNotFoundError(userID) - } - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - return account.Copy(), nil -} - -// GetAccountByPeerID returns an account for a given peer ID -func (s *FileStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.PeerID2AccountID[peerID] - if !ok { - return nil, status.Errorf(status.NotFound, "provided peer ID doesn't exists %s", peerID) - } - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - // this protection is needed because when we delete a peer, we don't really remove index peerID -> accountID. - // check Account.Peers for a match - if _, ok := account.Peers[peerID]; !ok { - delete(s.PeerID2AccountID, peerID) - log.WithContext(ctx).Warnf("removed stale peerID %s to accountID %s index", peerID, accountID) - return nil, status.NewPeerNotFoundError(peerID) - } - - return account.Copy(), nil -} - -// GetAccountByPeerPubKey returns an account for a given peer WireGuard public key -func (s *FileStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.PeerKeyID2AccountID[peerKey] - if !ok { - return nil, status.NewPeerNotFoundError(peerKey) - } - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - // this protection is needed because when we delete a peer, we don't really remove index peerKey -> accountID. - // check Account.Peers for a match - stale := true - for _, peer := range account.Peers { - if peer.Key == peerKey { - stale = false - break - } - } - if stale { - delete(s.PeerKeyID2AccountID, peerKey) - log.WithContext(ctx).Warnf("removed stale peerKey %s to accountID %s index", peerKey, accountID) - return nil, status.NewPeerNotFoundError(peerKey) - } - - return account.Copy(), nil -} - -func (s *FileStore) GetAccountIDByPeerPubKey(_ context.Context, peerKey string) (string, error) { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.PeerKeyID2AccountID[peerKey] - if !ok { - return "", status.NewPeerNotFoundError(peerKey) - } - - return accountID, nil -} - -func (s *FileStore) GetAccountIDByUserID(userID string) (string, error) { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.UserID2AccountID[userID] - if !ok { - return "", status.NewUserNotFoundError(userID) - } - - return accountID, nil -} - -func (s *FileStore) GetAccountIDBySetupKey(_ context.Context, setupKey string) (string, error) { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)] - if !ok { - return "", status.NewSetupKeyNotFoundError() - } - - return accountID, nil -} - -func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, _ LockingStrength, peerKey string) (*nbpeer.Peer, error) { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.PeerKeyID2AccountID[peerKey] - if !ok { - return nil, status.NewPeerNotFoundError(peerKey) - } - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - for _, peer := range account.Peers { - if peer.Key == peerKey { - return peer.Copy(), nil - } - } - - return nil, status.NewPeerNotFoundError(peerKey) -} - -func (s *FileStore) GetAccountSettings(_ context.Context, _ LockingStrength, accountID string) (*Settings, error) { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - return account.Settings.Copy(), nil -} - -// GetInstallationID returns the installation ID from the store -func (s *FileStore) GetInstallationID() string { - return s.InstallationID -} - -// SaveInstallationID saves the installation ID -func (s *FileStore) SaveInstallationID(ctx context.Context, ID string) error { - s.mux.Lock() - defer s.mux.Unlock() - - s.InstallationID = ID - - return s.persist(ctx, s.storeFile) -} - -// SavePeer saves the peer in the account -func (s *FileStore) SavePeer(_ context.Context, accountID string, peer *nbpeer.Peer) error { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return err - } - - newPeer := peer.Copy() - - account.Peers[peer.ID] = newPeer - - s.PeerKeyID2AccountID[peer.Key] = accountID - s.PeerID2AccountID[peer.ID] = accountID - - return nil -} - -// SavePeerStatus stores the PeerStatus in memory. It doesn't attempt to persist data to speed up things. -// PeerStatus will be saved eventually when some other changes occur. -func (s *FileStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return err - } - - peer := account.Peers[peerID] - if peer == nil { - return status.Errorf(status.NotFound, "peer %s not found", peerID) - } - - peer.Status = &peerStatus - - return nil -} - -// SavePeerLocation stores the PeerStatus in memory. It doesn't attempt to persist data to speed up things. -// Peer.Location will be saved eventually when some other changes occur. -func (s *FileStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.Peer) error { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return err - } - - peer := account.Peers[peerWithLocation.ID] - if peer == nil { - return status.Errorf(status.NotFound, "peer %s not found", peerWithLocation.ID) - } - - peer.Location = peerWithLocation.Location - - return nil -} - -// SaveUserLastLogin stores the last login time for a user in memory. It doesn't attempt to persist data to speed up things. -func (s *FileStore) SaveUserLastLogin(_ context.Context, accountID, userID string, lastLogin time.Time) error { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return err - } - - peer := account.Users[userID] - if peer == nil { - return status.Errorf(status.NotFound, "user %s not found", userID) - } - - peer.LastLogin = lastLogin - - return nil -} - -func (s *FileStore) GetPostureCheckByChecksDefinition(_ string, _ *posture.ChecksDefinition) (*posture.Checks, error) { - return nil, status.Errorf(status.Internal, "GetPostureCheckByChecksDefinition is not implemented") -} - // Close the FileStore persisting data to disk func (s *FileStore) Close(ctx context.Context) error { s.mux.Lock() @@ -959,86 +257,3 @@ func (s *FileStore) Close(ctx context.Context) error { func (s *FileStore) GetStoreEngine() StoreEngine { return FileStoreEngine } - -func (s *FileStore) SaveUsers(_ string, _ map[string]*User) error { - return status.Errorf(status.Internal, "SaveUsers is not implemented") -} - -func (s *FileStore) SaveGroups(_ string, _ map[string]*nbgroup.Group) error { - return status.Errorf(status.Internal, "SaveGroups is not implemented") -} - -func (s *FileStore) GetAccountIDByPrivateDomain(_ context.Context, _ LockingStrength, _ string) (string, error) { - return "", status.Errorf(status.Internal, "GetAccountIDByPrivateDomain is not implemented") -} - -func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, _ LockingStrength, accountID string) (string, string, error) { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return "", "", err - } - - return account.Domain, account.DomainCategory, nil -} - -// AccountExists checks whether an account exists by the given ID. -func (s *FileStore) AccountExists(_ context.Context, _ LockingStrength, id string) (bool, error) { - _, exists := s.Accounts[id] - return exists, nil -} - -func (s *FileStore) GetAccountDNSSettings(_ context.Context, _ LockingStrength, _ string) (*DNSSettings, error) { - return nil, status.Errorf(status.Internal, "GetAccountDNSSettings is not implemented") -} - -func (s *FileStore) GetGroupByID(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) { - return nil, status.Errorf(status.Internal, "GetGroupByID is not implemented") -} - -func (s *FileStore) GetGroupByName(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) { - return nil, status.Errorf(status.Internal, "GetGroupByName is not implemented") -} - -func (s *FileStore) GetAccountPolicies(_ context.Context, _ LockingStrength, _ string) ([]*Policy, error) { - return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented") -} - -func (s *FileStore) GetPolicyByID(_ context.Context, _ LockingStrength, _ string, _ string) (*Policy, error) { - return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented") - -} - -func (s *FileStore) GetAccountPostureChecks(_ context.Context, _ LockingStrength, _ string) ([]*posture.Checks, error) { - return nil, status.Errorf(status.Internal, "GetAccountPostureChecks is not implemented") -} - -func (s *FileStore) GetPostureChecksByID(_ context.Context, _ LockingStrength, _ string, _ string) (*posture.Checks, error) { - return nil, status.Errorf(status.Internal, "GetPostureChecksByID is not implemented") -} - -func (s *FileStore) GetAccountRoutes(_ context.Context, _ LockingStrength, _ string) ([]*route.Route, error) { - return nil, status.Errorf(status.Internal, "GetAccountRoutes is not implemented") -} - -func (s *FileStore) GetRouteByID(_ context.Context, _ LockingStrength, _ string, _ string) (*route.Route, error) { - return nil, status.Errorf(status.Internal, "GetRouteByID is not implemented") -} - -func (s *FileStore) GetAccountSetupKeys(_ context.Context, _ LockingStrength, _ string) ([]*SetupKey, error) { - return nil, status.Errorf(status.Internal, "GetAccountSetupKeys is not implemented") -} - -func (s *FileStore) GetSetupKeyByID(_ context.Context, _ LockingStrength, _ string, _ string) (*SetupKey, error) { - return nil, status.Errorf(status.Internal, "GetSetupKeyByID is not implemented") -} - -func (s *FileStore) GetAccountNameServerGroups(_ context.Context, _ LockingStrength, _ string) ([]*dns.NameServerGroup, error) { - return nil, status.Errorf(status.Internal, "GetAccountNameServerGroups is not implemented") -} - -func (s *FileStore) GetNameServerGroupByID(_ context.Context, _ LockingStrength, _ string, _ string) (*dns.NameServerGroup, error) { - return nil, status.Errorf(status.Internal, "GetNameServerGroupByID is not implemented") -} diff --git a/management/server/file_store_test.go b/management/server/file_store_test.go deleted file mode 100644 index 56e46b696..000000000 --- a/management/server/file_store_test.go +++ /dev/null @@ -1,655 +0,0 @@ -package server - -import ( - "context" - "crypto/sha256" - "net" - "path/filepath" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/netbirdio/netbird/management/server/group" - nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/util" -) - -type accounts struct { - Accounts map[string]*Account -} - -func TestStalePeerIndices(t *testing.T) { - storeDir := t.TempDir() - - err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json")) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - return - } - - account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") - require.NoError(t, err) - - peerID := "some_peer" - peerKey := "some_peer_key" - account.Peers[peerID] = &nbpeer.Peer{ - ID: peerID, - Key: peerKey, - } - - err = store.SaveAccount(context.Background(), account) - require.NoError(t, err) - - account.DeletePeer(peerID) - - err = store.SaveAccount(context.Background(), account) - require.NoError(t, err) - - _, err = store.GetAccountByPeerID(context.Background(), peerID) - require.Error(t, err, "expecting to get an error when found stale index") - - _, err = store.GetAccountByPeerPubKey(context.Background(), peerKey) - require.Error(t, err, "expecting to get an error when found stale index") -} - -func TestNewStore(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) - - if store.Accounts == nil || len(store.Accounts) != 0 { - t.Errorf("expected to create a new empty Accounts map when creating a new FileStore") - } - - if store.SetupKeyID2AccountID == nil || len(store.SetupKeyID2AccountID) != 0 { - t.Errorf("expected to create a new empty SetupKeyID2AccountID map when creating a new FileStore") - } - - if store.PeerKeyID2AccountID == nil || len(store.PeerKeyID2AccountID) != 0 { - t.Errorf("expected to create a new empty PeerKeyID2AccountID map when creating a new FileStore") - } - - if store.UserID2AccountID == nil || len(store.UserID2AccountID) != 0 { - t.Errorf("expected to create a new empty UserID2AccountID map when creating a new FileStore") - } - - if store.HashedPAT2TokenID == nil || len(store.HashedPAT2TokenID) != 0 { - t.Errorf("expected to create a new empty HashedPAT2TokenID map when creating a new FileStore") - } - - if store.TokenID2UserID == nil || len(store.TokenID2UserID) != 0 { - t.Errorf("expected to create a new empty TokenID2UserID map when creating a new FileStore") - } -} - -func TestSaveAccount(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) - - account := newAccountWithId(context.Background(), "account_id", "testuser", "") - setupKey := GenerateDefaultSetupKey() - account.SetupKeys[setupKey.Key] = setupKey - account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - SetupKey: "peerkeysetupkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, - } - - // SaveAccount should trigger persist - err := store.SaveAccount(context.Background(), account) - if err != nil { - return - } - - if store.Accounts[account.Id] == nil { - t.Errorf("expecting Account to be stored after SaveAccount()") - } - - if store.PeerKeyID2AccountID["peerkey"] == "" { - t.Errorf("expecting PeerKeyID2AccountID index updated after SaveAccount()") - } - - if store.UserID2AccountID["testuser"] == "" { - t.Errorf("expecting UserID2AccountID index updated after SaveAccount()") - } - - if store.SetupKeyID2AccountID[setupKey.Key] == "" { - t.Errorf("expecting SetupKeyID2AccountID index updated after SaveAccount()") - } -} - -func TestDeleteAccount(t *testing.T) { - storeDir := t.TempDir() - storeFile := filepath.Join(storeDir, "store.json") - err := util.CopyFileContents("testdata/store.json", storeFile) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - t.Fatal(err) - } - defer store.Close(context.Background()) - - var account *Account - for _, a := range store.Accounts { - account = a - break - } - - require.NotNil(t, account, "failed to restore a FileStore file and get at least one account") - - err = store.DeleteAccount(context.Background(), account) - require.NoError(t, err, "failed to delete account, error: %v", err) - - _, ok := store.Accounts[account.Id] - require.False(t, ok, "failed to delete account") - - for id := range account.Users { - _, ok := store.UserID2AccountID[id] - assert.False(t, ok, "failed to delete UserID2AccountID index") - for _, pat := range account.Users[id].PATs { - _, ok := store.HashedPAT2TokenID[pat.HashedToken] - assert.False(t, ok, "failed to delete HashedPAT2TokenID index") - _, ok = store.TokenID2UserID[pat.ID] - assert.False(t, ok, "failed to delete TokenID2UserID index") - } - } - - for _, p := range account.Peers { - _, ok := store.PeerKeyID2AccountID[p.Key] - assert.False(t, ok, "failed to delete PeerKeyID2AccountID index") - _, ok = store.PeerID2AccountID[p.ID] - assert.False(t, ok, "failed to delete PeerID2AccountID index") - } - - for id := range account.SetupKeys { - _, ok := store.SetupKeyID2AccountID[id] - assert.False(t, ok, "failed to delete SetupKeyID2AccountID index") - } - - _, ok = store.PrivateDomain2AccountID[account.Domain] - assert.False(t, ok, "failed to delete PrivateDomain2AccountID index") - -} - -func TestStore(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) - - account := newAccountWithId(context.Background(), "account_id", "testuser", "") - account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - SetupKey: "peerkeysetupkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, - } - account.Groups["all"] = &group.Group{ - ID: "all", - Name: "all", - Peers: []string{"testpeer"}, - } - account.Policies = append(account.Policies, &Policy{ - ID: "all", - Name: "all", - Enabled: true, - Rules: []*PolicyRule{ - { - ID: "all", - Name: "all", - Sources: []string{"all"}, - Destinations: []string{"all"}, - }, - }, - }) - account.Policies = append(account.Policies, &Policy{ - ID: "dmz", - Name: "dmz", - Enabled: true, - Rules: []*PolicyRule{ - { - ID: "dmz", - Name: "dmz", - Enabled: true, - Sources: []string{"all"}, - Destinations: []string{"all"}, - }, - }, - }) - - // SaveAccount should trigger persist - err := store.SaveAccount(context.Background(), account) - if err != nil { - return - } - - restored, err := NewFileStore(context.Background(), store.storeFile, nil) - if err != nil { - return - } - - restoredAccount := restored.Accounts[account.Id] - if restoredAccount == nil { - t.Errorf("failed to restore a FileStore file - missing Account %s", account.Id) - return - } - - if restoredAccount.Peers["testpeer"] == nil { - t.Errorf("failed to restore a FileStore file - missing Peer testpeer") - } - - if restoredAccount.CreatedBy != "testuser" { - t.Errorf("failed to restore a FileStore file - missing Account CreatedBy") - } - - if restoredAccount.Users["testuser"] == nil { - t.Errorf("failed to restore a FileStore file - missing User testuser") - } - - if restoredAccount.Network == nil { - t.Errorf("failed to restore a FileStore file - missing Network") - } - - if restoredAccount.Groups["all"] == nil { - t.Errorf("failed to restore a FileStore file - missing Group all") - } - - if len(restoredAccount.Policies) != 2 { - t.Errorf("failed to restore a FileStore file - missing Policies") - return - } - - assert.Equal(t, account.Policies[0], restoredAccount.Policies[0], "failed to restore a FileStore file - missing Policy all") - assert.Equal(t, account.Policies[1], restoredAccount.Policies[1], "failed to restore a FileStore file - missing Policy dmz") -} - -func TestRestore(t *testing.T) { - storeDir := t.TempDir() - - err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json")) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - return - } - - account := store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"] - - require.NotNil(t, account, "failed to restore a FileStore file - missing account bf1c8084-ba50-4ce7-9439-34653001fc3b") - - require.NotNil(t, account.Users["edafee4e-63fb-11ec-90d6-0242ac120003"], "failed to restore a FileStore file - missing Account User edafee4e-63fb-11ec-90d6-0242ac120003") - - require.NotNil(t, account.Users["f4f6d672-63fb-11ec-90d6-0242ac120003"], "failed to restore a FileStore file - missing Account User f4f6d672-63fb-11ec-90d6-0242ac120003") - - require.NotNil(t, account.Network, "failed to restore a FileStore file - missing Account Network") - - require.NotNil(t, account.SetupKeys["A2C8E62B-38F5-4553-B31E-DD66C696CEBB"], "failed to restore a FileStore file - missing Account SetupKey A2C8E62B-38F5-4553-B31E-DD66C696CEBB") - - require.NotNil(t, account.Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"], "failed to restore a FileStore wrong PATs length") - - require.Len(t, store.UserID2AccountID, 2, "failed to restore a FileStore wrong UserID2AccountID mapping length") - - require.Len(t, store.SetupKeyID2AccountID, 1, "failed to restore a FileStore wrong SetupKeyID2AccountID mapping length") - - require.Len(t, store.PrivateDomain2AccountID, 1, "failed to restore a FileStore wrong PrivateDomain2AccountID mapping length") - - require.Len(t, store.HashedPAT2TokenID, 1, "failed to restore a FileStore wrong HashedPAT2TokenID mapping length") - - require.Len(t, store.TokenID2UserID, 1, "failed to restore a FileStore wrong TokenID2UserID mapping length") -} - -func TestRestoreGroups_Migration(t *testing.T) { - storeDir := t.TempDir() - - err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json")) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - return - } - - // create default group - account := store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"] - account.Groups = map[string]*group.Group{ - "cfefqs706sqkneg59g3g": { - ID: "cfefqs706sqkneg59g3g", - Name: "All", - }, - } - err = store.SaveAccount(context.Background(), account) - require.NoError(t, err, "failed to save account") - - // restore account with default group with empty Issue field - if store, err = NewFileStore(context.Background(), storeDir, nil); err != nil { - return - } - account = store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"] - - require.Contains(t, account.Groups, "cfefqs706sqkneg59g3g", "failed to restore a FileStore file - missing Account Groups") - require.Equal(t, group.GroupIssuedAPI, account.Groups["cfefqs706sqkneg59g3g"].Issued, "default group should has API issued mark") -} - -func TestGetAccountByPrivateDomain(t *testing.T) { - storeDir := t.TempDir() - - err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json")) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - return - } - - existingDomain := "test.com" - - account, err := store.GetAccountByPrivateDomain(context.Background(), existingDomain) - require.NoError(t, err, "should found account") - require.Equal(t, existingDomain, account.Domain, "domains should match") - - _, err = store.GetAccountByPrivateDomain(context.Background(), "missing-domain.com") - require.Error(t, err, "should return error on domain lookup") -} - -func TestFileStore_GetAccount(t *testing.T) { - storeDir := t.TempDir() - storeFile := filepath.Join(storeDir, "store.json") - err := util.CopyFileContents("testdata/store.json", storeFile) - if err != nil { - t.Fatal(err) - } - - accounts := &accounts{} - _, err = util.ReadJson(storeFile, accounts) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - t.Fatal(err) - } - - expected := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"] - if expected == nil { - t.Fatalf("expected account doesn't exist") - return - } - - account, err := store.GetAccount(context.Background(), expected.Id) - if err != nil { - t.Fatal(err) - } - - assert.Equal(t, expected.IsDomainPrimaryAccount, account.IsDomainPrimaryAccount) - assert.Equal(t, expected.DomainCategory, account.DomainCategory) - assert.Equal(t, expected.Domain, account.Domain) - assert.Equal(t, expected.CreatedBy, account.CreatedBy) - assert.Equal(t, expected.Network.Identifier, account.Network.Identifier) - assert.Len(t, account.Peers, len(expected.Peers)) - assert.Len(t, account.Users, len(expected.Users)) - assert.Len(t, account.SetupKeys, len(expected.SetupKeys)) - assert.Len(t, account.Routes, len(expected.Routes)) - assert.Len(t, account.NameServerGroups, len(expected.NameServerGroups)) -} - -func TestFileStore_GetTokenIDByHashedToken(t *testing.T) { - storeDir := t.TempDir() - storeFile := filepath.Join(storeDir, "store.json") - err := util.CopyFileContents("testdata/store.json", storeFile) - if err != nil { - t.Fatal(err) - } - - accounts := &accounts{} - _, err = util.ReadJson(storeFile, accounts) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - t.Fatal(err) - } - - hashedToken := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].HashedToken - tokenID, err := store.GetTokenIDByHashedToken(context.Background(), hashedToken) - if err != nil { - t.Fatal(err) - } - - expectedTokenID := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].ID - assert.Equal(t, expectedTokenID, tokenID) -} - -func TestFileStore_DeleteHashedPAT2TokenIDIndex(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) - store.HashedPAT2TokenID["someHashedToken"] = "someTokenId" - - err := store.DeleteHashedPAT2TokenIDIndex("someHashedToken") - if err != nil { - t.Fatal(err) - } - - assert.Empty(t, store.HashedPAT2TokenID["someHashedToken"]) -} - -func TestFileStore_DeleteTokenID2UserIDIndex(t *testing.T) { - store := newStore(t) - store.TokenID2UserID["someTokenId"] = "someUserId" - - err := store.DeleteTokenID2UserIDIndex("someTokenId") - if err != nil { - t.Fatal(err) - } - - assert.Empty(t, store.TokenID2UserID["someTokenId"]) -} - -func TestFileStore_GetTokenIDByHashedToken_Failure(t *testing.T) { - storeDir := t.TempDir() - storeFile := filepath.Join(storeDir, "store.json") - err := util.CopyFileContents("testdata/store.json", storeFile) - if err != nil { - t.Fatal(err) - } - - accounts := &accounts{} - _, err = util.ReadJson(storeFile, accounts) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - t.Fatal(err) - } - - wrongToken := sha256.Sum256([]byte("someNotValidTokenThatFails1234")) - _, err = store.GetTokenIDByHashedToken(context.Background(), string(wrongToken[:])) - - assert.Error(t, err, "GetTokenIDByHashedToken should throw error if token invalid") -} - -func TestFileStore_GetUserByTokenID(t *testing.T) { - storeDir := t.TempDir() - storeFile := filepath.Join(storeDir, "store.json") - err := util.CopyFileContents("testdata/store.json", storeFile) - if err != nil { - t.Fatal(err) - } - - accounts := &accounts{} - _, err = util.ReadJson(storeFile, accounts) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - t.Fatal(err) - } - - tokenID := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].ID - user, err := store.GetUserByTokenID(context.Background(), tokenID) - if err != nil { - t.Fatal(err) - } - - assert.Equal(t, "f4f6d672-63fb-11ec-90d6-0242ac120003", user.Id) -} - -func TestFileStore_GetUserByTokenID_Failure(t *testing.T) { - storeDir := t.TempDir() - storeFile := filepath.Join(storeDir, "store.json") - err := util.CopyFileContents("testdata/store.json", storeFile) - if err != nil { - t.Fatal(err) - } - - accounts := &accounts{} - _, err = util.ReadJson(storeFile, accounts) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - t.Fatal(err) - } - - wrongTokenID := "someNonExistingTokenID" - _, err = store.GetUserByTokenID(context.Background(), wrongTokenID) - - assert.Error(t, err, "GetUserByTokenID should throw error if tokenID invalid") -} - -func TestFileStore_SavePeerStatus(t *testing.T) { - storeDir := t.TempDir() - - err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json")) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - return - } - - account, err := store.getAccount("bf1c8084-ba50-4ce7-9439-34653001fc3b") - if err != nil { - t.Fatal(err) - } - - // save status of non-existing peer - newStatus := nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()} - err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus) - assert.Error(t, err) - - // save new status of existing peer - account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - ID: "testpeer", - SetupKey: "peerkeysetupkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()}, - } - - err = store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatal(err) - } - - err = store.SavePeerStatus(account.Id, "testpeer", newStatus) - if err != nil { - t.Fatal(err) - } - account, err = store.getAccount(account.Id) - if err != nil { - t.Fatal(err) - } - - actual := account.Peers["testpeer"].Status - assert.Equal(t, newStatus, *actual) -} - -func TestFileStore_SavePeerLocation(t *testing.T) { - storeDir := t.TempDir() - - err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json")) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - return - } - account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") - require.NoError(t, err) - - peer := &nbpeer.Peer{ - AccountID: account.Id, - ID: "testpeer", - Location: nbpeer.Location{ - ConnectionIP: net.ParseIP("10.0.0.0"), - CountryCode: "YY", - CityName: "City", - GeoNameID: 1, - }, - Meta: nbpeer.PeerSystemMeta{}, - } - // error is expected as peer is not in store yet - err = store.SavePeerLocation(account.Id, peer) - assert.Error(t, err) - - account.Peers[peer.ID] = peer - err = store.SaveAccount(context.Background(), account) - require.NoError(t, err) - - peer.Location.ConnectionIP = net.ParseIP("35.1.1.1") - peer.Location.CountryCode = "DE" - peer.Location.CityName = "Berlin" - peer.Location.GeoNameID = 2950159 - - err = store.SavePeerLocation(account.Id, account.Peers[peer.ID]) - assert.NoError(t, err) - - account, err = store.GetAccount(context.Background(), account.Id) - require.NoError(t, err) - - actual := account.Peers[peer.ID].Location - assert.Equal(t, peer.Location, actual) -} - -func newStore(t *testing.T) *FileStore { - t.Helper() - store, err := NewFileStore(context.Background(), t.TempDir(), nil) - if err != nil { - t.Errorf("failed creating a new store") - } - - return store -} diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index ff09129bd..f8ab46d81 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -6,7 +6,6 @@ import ( "io" "net" "os" - "path/filepath" "runtime" "sync" "sync/atomic" @@ -89,14 +88,7 @@ func getServerKey(client mgmtProto.ManagementServiceClient) (*wgtypes.Key, error func Test_SyncProtocol(t *testing.T) { dir := t.TempDir() - err := util.CopyFileContents("testdata/store_with_expired_peers.json", filepath.Join(dir, "store.json")) - if err != nil { - t.Fatal(err) - } - defer func() { - os.Remove(filepath.Join(dir, "store.json")) //nolint - }() - mgmtServer, _, mgmtAddr, err := startManagementForTest(t, &Config{ + mgmtServer, _, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sqlite", &Config{ Stuns: []*Host{{ Proto: "udp", URI: "stun:stun.wiretrustee.com:3468", @@ -117,6 +109,7 @@ func Test_SyncProtocol(t *testing.T) { Datadir: dir, HttpConfig: nil, }) + defer cleanup() if err != nil { t.Fatal(err) return @@ -412,18 +405,18 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) { } } -func startManagementForTest(t TestingT, config *Config) (*grpc.Server, *DefaultAccountManager, string, error) { +func startManagementForTest(t *testing.T, testFile string, config *Config) (*grpc.Server, *DefaultAccountManager, string, func(), error) { t.Helper() lis, err := net.Listen("tcp", "localhost:0") if err != nil { - return nil, nil, "", err + return nil, nil, "", nil, err } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, cleanUp, err := NewTestStoreFromJson(context.Background(), config.Datadir) + + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), testFile) if err != nil { - return nil, nil, "", err + t.Fatal(err) } - t.Cleanup(cleanUp) peersUpdateManager := NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} @@ -437,7 +430,8 @@ func startManagementForTest(t TestingT, config *Config) (*grpc.Server, *DefaultA eventStore, nil, false, MocIntegratedValidator{}, metrics) if err != nil { - return nil, nil, "", err + cleanup() + return nil, nil, "", cleanup, err } secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) @@ -445,7 +439,7 @@ func startManagementForTest(t TestingT, config *Config) (*grpc.Server, *DefaultA ephemeralMgr := NewEphemeralManager(store, accountManager) mgmtServer, err := NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, ephemeralMgr) if err != nil { - return nil, nil, "", err + return nil, nil, "", cleanup, err } mgmtProto.RegisterManagementServiceServer(s, mgmtServer) @@ -455,7 +449,7 @@ func startManagementForTest(t TestingT, config *Config) (*grpc.Server, *DefaultA } }() - return s, accountManager, lis.Addr().String(), nil + return s, accountManager, lis.Addr().String(), cleanup, nil } func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.ClientConn, error) { @@ -475,6 +469,7 @@ func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.Clie return mgmtProto.NewManagementServiceClient(conn), conn, nil } + func Test_SyncStatusRace(t *testing.T) { if os.Getenv("CI") == "true" && os.Getenv("NETBIRD_STORE_ENGINE") == "postgres" { t.Skip("Skipping on CI and Postgres store") @@ -488,15 +483,8 @@ func Test_SyncStatusRace(t *testing.T) { func testSyncStatusRace(t *testing.T) { t.Helper() dir := t.TempDir() - err := util.CopyFileContents("testdata/store_with_expired_peers.json", filepath.Join(dir, "store.json")) - if err != nil { - t.Fatal(err) - } - defer func() { - os.Remove(filepath.Join(dir, "store.json")) //nolint - }() - mgmtServer, am, mgmtAddr, err := startManagementForTest(t, &Config{ + mgmtServer, am, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sqlite", &Config{ Stuns: []*Host{{ Proto: "udp", URI: "stun:stun.wiretrustee.com:3468", @@ -517,6 +505,7 @@ func testSyncStatusRace(t *testing.T) { Datadir: dir, HttpConfig: nil, }) + defer cleanup() if err != nil { t.Fatal(err) return @@ -665,15 +654,8 @@ func Test_LoginPerformance(t *testing.T) { t.Run(bc.name, func(t *testing.T) { t.Helper() dir := t.TempDir() - err := util.CopyFileContents("testdata/store_with_expired_peers.json", filepath.Join(dir, "store.json")) - if err != nil { - t.Fatal(err) - } - defer func() { - os.Remove(filepath.Join(dir, "store.json")) //nolint - }() - mgmtServer, am, _, err := startManagementForTest(t, &Config{ + mgmtServer, am, _, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sqlite", &Config{ Stuns: []*Host{{ Proto: "udp", URI: "stun:stun.wiretrustee.com:3468", @@ -694,6 +676,7 @@ func Test_LoginPerformance(t *testing.T) { Datadir: dir, HttpConfig: nil, }) + defer cleanup() if err != nil { t.Fatal(err) return diff --git a/management/server/management_test.go b/management/server/management_test.go index 3956d96b1..ba27dc5e8 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -5,7 +5,6 @@ import ( "math/rand" "net" "os" - "path/filepath" "runtime" sync2 "sync" "time" @@ -52,8 +51,6 @@ var _ = Describe("Management service", func() { dataDir, err = os.MkdirTemp("", "wiretrustee_mgmt_test_tmp_*") Expect(err).NotTo(HaveOccurred()) - err = util.CopyFileContents("testdata/store.json", filepath.Join(dataDir, "store.json")) - Expect(err).NotTo(HaveOccurred()) var listener net.Listener config := &server.Config{} @@ -61,7 +58,7 @@ var _ = Describe("Management service", func() { Expect(err).NotTo(HaveOccurred()) config.Datadir = dataDir - s, listener = startServer(config) + s, listener = startServer(config, dataDir, "testdata/store.sqlite") addr = listener.Addr().String() client, conn = createRawClient(addr) @@ -530,12 +527,12 @@ func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.Clie return mgmtProto.NewManagementServiceClient(conn), conn } -func startServer(config *server.Config) (*grpc.Server, net.Listener) { +func startServer(config *server.Config, dataDir string, testFile string) (*grpc.Server, net.Listener) { lis, err := net.Listen("tcp", ":0") Expect(err).NotTo(HaveOccurred()) s := grpc.NewServer() - store, _, err := server.NewTestStoreFromJson(context.Background(), config.Datadir) + store, _, err := server.NewTestStoreFromSqlite(context.Background(), testFile, dataDir) if err != nil { log.Fatalf("failed creating a store: %s: %v", config.Datadir, err) } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 8dc8f6a4f..681bf533a 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -27,7 +27,8 @@ type MockAccountManager struct { CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error) GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error) - GetAccountIDByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (string, error) + AccountExistsFunc func(ctx context.Context, accountID string) (bool, error) + GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error) GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error) GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) @@ -57,7 +58,7 @@ type MockAccountManager struct { MarkPATUsedFunc func(ctx context.Context, pat string) error UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) - CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups,accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) + CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error @@ -193,14 +194,22 @@ func (am *MockAccountManager) CreateSetupKey( return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented") } -// GetAccountIDByUserOrAccountID mock implementation of GetAccountIDByUserOrAccountID from server.AccountManager interface -func (am *MockAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userId, accountId, domain string) (string, error) { - if am.GetAccountIDByUserOrAccountIdFunc != nil { - return am.GetAccountIDByUserOrAccountIdFunc(ctx, userId, accountId, domain) +// AccountExists mock implementation of AccountExists from server.AccountManager interface +func (am *MockAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) { + if am.AccountExistsFunc != nil { + return am.AccountExistsFunc(ctx, accountID) + } + return false, status.Errorf(codes.Unimplemented, "method AccountExists is not implemented") +} + +// GetAccountIDByUserID mock implementation of GetAccountIDByUserID from server.AccountManager interface +func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userId, domain string) (string, error) { + if am.GetAccountIDByUserIdFunc != nil { + return am.GetAccountIDByUserIdFunc(ctx, userId, domain) } return "", status.Errorf( codes.Unimplemented, - "method GetAccountIDByUserOrAccountID is not implemented", + "method GetAccountIDByUserID is not implemented", ) } @@ -435,7 +444,7 @@ func (am *MockAccountManager) UpdatePeer(ctx context.Context, accountID, userID // CreateRoute mock implementation of CreateRoute from server.AccountManager interface func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupID []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) { if am.CreateRouteFunc != nil { - return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups,accessControlGroupID, enabled, userID, keepRoute) + return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, accessControlGroupID, enabled, userID, keepRoute) } return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented") } diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 300f79c42..d9c359fba 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -775,7 +775,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { func createNSStore(t *testing.T) (Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir) + store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir) if err != nil { return nil, err } diff --git a/management/server/peer.go b/management/server/peer.go index ad4d2658a..a85e8c6b2 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -707,6 +707,11 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) updateRemotePeers := false if login.UserID != "" { + if peer.UserID != login.UserID { + log.Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, login.UserID) + return nil, nil, nil, status.Errorf(status.Unauthenticated, "invalid user") + } + changed, err := am.handleUserPeer(ctx, peer, settings) if err != nil { return nil, nil, nil, err diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 0eb782ed0..08592de2e 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -1004,7 +1004,11 @@ func Test_RegisterPeerByUser(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/extended-store.json") + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + if err != nil { + t.Fatal(err) + } + defer cleanup() eventStore := &activity.InMemoryEventStore{} @@ -1065,7 +1069,11 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/extended-store.json") + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + if err != nil { + t.Fatal(err) + } + defer cleanup() eventStore := &activity.InMemoryEventStore{} @@ -1127,7 +1135,11 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/extended-store.json") + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + if err != nil { + t.Fatal(err) + } + defer cleanup() eventStore := &activity.InMemoryEventStore{} diff --git a/management/server/route_test.go b/management/server/route_test.go index 74bf9c3ec..ca2e99b8a 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -1258,7 +1258,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) { func createRouterStore(t *testing.T) (Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir) + store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir) if err != nil { return nil, err } @@ -1738,7 +1738,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { } assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules) - //peerD is also the routing peer for route1, should contain same routes firewall rules as peerA + // peerD is also the routing peer for route1, should contain same routes firewall rules as peerA routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers) assert.Len(t, routesFirewallRules, 2) assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules) diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 85c68ef44..fe4dcafdb 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -10,6 +10,7 @@ import ( "path/filepath" "runtime" "runtime/debug" + "strconv" "strings" "sync" "time" @@ -63,8 +64,14 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine StoreEngine, metr if err != nil { return nil, err } - conns := runtime.NumCPU() - sql.SetMaxOpenConns(conns) // TODO: make it configurable + + conns, err := strconv.Atoi(os.Getenv("NB_SQL_MAX_OPEN_CONNS")) + if err != nil { + conns = runtime.NumCPU() + } + sql.SetMaxOpenConns(conns) + + log.Infof("Set max open db connections to %d", conns) if err := migrate(ctx, db); err != nil { return nil, fmt.Errorf("migrate: %w", err) @@ -378,15 +385,26 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error { Create(&usersToSave).Error } -// SaveGroups saves the given list of groups to the database. -// It updates existing groups if a conflict occurs. -func (s *SqlStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error { - groupsToSave := make([]nbgroup.Group, 0, len(groups)) - for _, group := range groups { - group.AccountID = accountID - groupsToSave = append(groupsToSave, *group) +// SaveUser saves the given user to the database. +func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error { + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user) + if result.Error != nil { + return status.Errorf(status.Internal, "failed to save user to store: %v", result.Error) } - return s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&groupsToSave).Error + return nil +} + +// SaveGroups saves the given list of groups to the database. +func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error { + if len(groups) == 0 { + return nil + } + + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups) + if result.Error != nil { + return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error) + } + return nil } // DeleteHashedPAT2TokenIDIndex is noop in SqlStore @@ -420,7 +438,7 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private") } log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error) - return "", status.Errorf(status.Internal, "issue getting account from store") + return "", status.NewGetAccountFromStoreError(result.Error) } return accountID, nil @@ -433,7 +451,7 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (* if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - return nil, status.NewSetupKeyNotFoundError() + return nil, status.NewSetupKeyNotFoundError(result.Error) } if key.AccountID == "" { @@ -451,7 +469,7 @@ func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken stri return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error) - return "", status.Errorf(status.Internal, "issue getting account from store") + return "", status.NewGetAccountFromStoreError(result.Error) } return token.ID, nil @@ -465,7 +483,7 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error) - return nil, status.Errorf(status.Internal, "issue getting account from store") + return nil, status.NewGetAccountFromStoreError(result.Error) } if token.UserID == "" { @@ -549,7 +567,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.NewAccountNotFoundError(accountID) } - return nil, status.Errorf(status.Internal, "issue getting account from store") + return nil, status.NewGetAccountFromStoreError(result.Error) } // we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us @@ -612,7 +630,7 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Accoun if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - return nil, status.Errorf(status.Internal, "issue getting account from store") + return nil, status.NewGetAccountFromStoreError(result.Error) } if user.AccountID == "" { @@ -629,7 +647,7 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acco if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - return nil, status.Errorf(status.Internal, "issue getting account from store") + return nil, status.NewGetAccountFromStoreError(result.Error) } if peer.AccountID == "" { @@ -647,7 +665,7 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) ( if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - return nil, status.Errorf(status.Internal, "issue getting account from store") + return nil, status.NewGetAccountFromStoreError(result.Error) } if peer.AccountID == "" { @@ -665,7 +683,7 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } - return "", status.Errorf(status.Internal, "issue getting account from store") + return "", status.NewGetAccountFromStoreError(result.Error) } return accountID, nil @@ -678,7 +696,7 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } - return "", status.Errorf(status.Internal, "issue getting account from store") + return "", status.NewGetAccountFromStoreError(result.Error) } return accountID, nil @@ -691,7 +709,7 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } - return "", status.NewSetupKeyNotFoundError() + return "", status.NewSetupKeyNotFoundError(result.Error) } if accountID == "" { @@ -712,7 +730,7 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "no peers found for the account") } - return nil, status.Errorf(status.Internal, "issue getting IPs from store") + return nil, status.Errorf(status.Internal, "issue getting IPs from store: %s", result.Error) } // Convert the JSON strings to net.IP objects @@ -740,7 +758,7 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock return nil, status.Errorf(status.NotFound, "no peers found for the account") } log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error) - return nil, status.Errorf(status.Internal, "issue getting dns labels from store") + return nil, status.Errorf(status.Internal, "issue getting dns labels from store: %s", result.Error) } return labels, nil @@ -753,7 +771,7 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.NewAccountNotFoundError(accountID) } - return nil, status.Errorf(status.Internal, "issue getting network from store") + return nil, status.Errorf(status.Internal, "issue getting network from store: %s", err) } return accountNetwork.Network, nil } @@ -765,7 +783,7 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "peer not found") } - return nil, status.Errorf(status.Internal, "issue getting peer from store") + return nil, status.Errorf(status.Internal, "issue getting peer from store: %s", result.Error) } return &peer, nil @@ -777,7 +795,7 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "settings not found") } - return nil, status.Errorf(status.Internal, "issue getting settings from store") + return nil, status.Errorf(status.Internal, "issue getting settings from store: %s", err) } return accountSettings.Settings, nil } @@ -915,6 +933,28 @@ func NewPostgresqlStoreFromFileStore(ctx context.Context, fileStore *FileStore, return store, nil } +// NewPostgresqlStoreFromSqlStore restores a store from SqlStore and stores Postgres DB. +func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) { + store, err := NewPostgresqlStore(ctx, dsn, metrics) + if err != nil { + return nil, err + } + + err = store.SaveInstallationID(ctx, sqliteStore.GetInstallationID()) + if err != nil { + return nil, err + } + + for _, account := range sqliteStore.GetAllAccounts(ctx) { + err := store.SaveAccount(ctx, account) + if err != nil { + return nil, err + } + } + + return store, nil +} + func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) { var setupKey SetupKey result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). @@ -923,7 +963,7 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "setup key not found") } - return nil, status.NewSetupKeyNotFoundError() + return nil, status.NewSetupKeyNotFoundError(result.Error) } return &setupKey, nil } @@ -955,7 +995,7 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer if errors.Is(result.Error, gorm.ErrRecordNotFound) { return status.Errorf(status.NotFound, "group 'All' not found for account") } - return status.Errorf(status.Internal, "issue finding group 'All'") + return status.Errorf(status.Internal, "issue finding group 'All': %s", result.Error) } for _, existingPeerID := range group.Peers { @@ -967,7 +1007,7 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer group.Peers = append(group.Peers, peerID) if err := s.db.Save(&group).Error; err != nil { - return status.Errorf(status.Internal, "issue updating group 'All'") + return status.Errorf(status.Internal, "issue updating group 'All': %s", err) } return nil @@ -981,7 +1021,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId if errors.Is(result.Error, gorm.ErrRecordNotFound) { return status.Errorf(status.NotFound, "group not found for account") } - return status.Errorf(status.Internal, "issue finding group") + return status.Errorf(status.Internal, "issue finding group: %s", result.Error) } for _, existingPeerID := range group.Peers { @@ -993,15 +1033,20 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId group.Peers = append(group.Peers, peerId) if err := s.db.Save(&group).Error; err != nil { - return status.Errorf(status.Internal, "issue updating group") + return status.Errorf(status.Internal, "issue updating group: %s", err) } return nil } +// GetUserPeers retrieves peers for a user. +func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) { + return getRecords[*nbpeer.Peer](s.db.WithContext(ctx).Where("user_id = ?", userID), lockStrength, accountID) +} + func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { if err := s.db.WithContext(ctx).Create(peer).Error; err != nil { - return status.Errorf(status.Internal, "issue adding peer to account") + return status.Errorf(status.Internal, "issue adding peer to account: %s", err) } return nil @@ -1010,7 +1055,7 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) erro func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error { result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) if result.Error != nil { - return status.Errorf(status.Internal, "issue incrementing network serial count") + return status.Errorf(status.Internal, "issue incrementing network serial count: %s", result.Error) } return nil } @@ -1105,6 +1150,15 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren return &group, nil } +// SaveGroup saves a group to the store. +func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error { + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group) + if result.Error != nil { + return status.Errorf(status.Internal, "failed to save group to store: %v", result.Error) + } + return nil +} + // GetAccountPolicies retrieves policies for an account. func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) { return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID) diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 64ef36831..4eed09c69 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -7,7 +7,6 @@ import ( "net" "net/netip" "os" - "path/filepath" "runtime" "testing" "time" @@ -25,7 +24,6 @@ import ( "github.com/netbirdio/netbird/management/server/status" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/util" ) func TestSqlite_NewStore(t *testing.T) { @@ -347,7 +345,11 @@ func TestSqlite_GetAccount(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/store.json") + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite") + if err != nil { + t.Fatal(err) + } + defer cleanup() id := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -367,7 +369,11 @@ func TestSqlite_SavePeer(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/store.json") + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite") + if err != nil { + t.Fatal(err) + } + defer cleanup() account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") require.NoError(t, err) @@ -415,7 +421,11 @@ func TestSqlite_SavePeerStatus(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/store.json") + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite") + defer cleanup() + if err != nil { + t.Fatal(err) + } account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") require.NoError(t, err) @@ -468,8 +478,11 @@ func TestSqlite_SavePeerLocation(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/store.json") - + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite") + defer cleanup() + if err != nil { + t.Fatal(err) + } account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") require.NoError(t, err) @@ -519,8 +532,11 @@ func TestSqlite_TestGetAccountByPrivateDomain(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/store.json") - + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite") + defer cleanup() + if err != nil { + t.Fatal(err) + } existingDomain := "test.com" account, err := store.GetAccountByPrivateDomain(context.Background(), existingDomain) @@ -539,8 +555,11 @@ func TestSqlite_GetTokenIDByHashedToken(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/store.json") - + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite") + defer cleanup() + if err != nil { + t.Fatal(err) + } hashed := "SoMeHaShEdToKeN" id := "9dj38s35-63fb-11ec-90d6-0242ac120003" @@ -560,8 +579,11 @@ func TestSqlite_GetUserByTokenID(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/store.json") - + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite") + defer cleanup() + if err != nil { + t.Fatal(err) + } id := "9dj38s35-63fb-11ec-90d6-0242ac120003" user, err := store.GetUserByTokenID(context.Background(), id) @@ -668,24 +690,9 @@ func newSqliteStore(t *testing.T) *SqlStore { t.Helper() store, err := NewSqliteStore(context.Background(), t.TempDir(), nil) - require.NoError(t, err) - require.NotNil(t, store) - - return store -} - -func newSqliteStoreFromFile(t *testing.T, filename string) *SqlStore { - t.Helper() - - storeDir := t.TempDir() - - err := util.CopyFileContents(filename, filepath.Join(storeDir, "store.json")) - require.NoError(t, err) - - fStore, err := NewFileStore(context.Background(), storeDir, nil) - require.NoError(t, err) - - store, err := NewSqliteStoreFromFileStore(context.Background(), fStore, storeDir, nil) + t.Cleanup(func() { + store.Close(context.Background()) + }) require.NoError(t, err) require.NotNil(t, store) @@ -733,32 +740,31 @@ func newPostgresqlStore(t *testing.T) *SqlStore { return store } -func newPostgresqlStoreFromFile(t *testing.T, filename string) *SqlStore { +func newPostgresqlStoreFromSqlite(t *testing.T, filename string) *SqlStore { t.Helper() - storeDir := t.TempDir() - err := util.CopyFileContents(filename, filepath.Join(storeDir, "store.json")) - require.NoError(t, err) + store, cleanUpQ, err := NewSqliteTestStore(context.Background(), t.TempDir(), filename) + t.Cleanup(cleanUpQ) + if err != nil { + return nil + } - fStore, err := NewFileStore(context.Background(), storeDir, nil) - require.NoError(t, err) - - cleanUp, err := testutil.CreatePGDB() + cleanUpP, err := testutil.CreatePGDB() if err != nil { t.Fatal(err) } - t.Cleanup(cleanUp) + t.Cleanup(cleanUpP) postgresDsn, ok := os.LookupEnv(postgresDsnEnv) if !ok { t.Fatalf("could not initialize postgresql store: %s is not set", postgresDsnEnv) } - store, err := NewPostgresqlStoreFromFileStore(context.Background(), fStore, postgresDsn, nil) + pstore, err := NewPostgresqlStoreFromSqlStore(context.Background(), store, postgresDsn, nil) require.NoError(t, err) require.NotNil(t, store) - return store + return pstore } func TestPostgresql_NewStore(t *testing.T) { @@ -924,7 +930,7 @@ func TestPostgresql_SavePeerStatus(t *testing.T) { t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS) } - store := newPostgresqlStoreFromFile(t, "testdata/store.json") + store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite") account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") require.NoError(t, err) @@ -963,7 +969,7 @@ func TestPostgresql_TestGetAccountByPrivateDomain(t *testing.T) { t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS) } - store := newPostgresqlStoreFromFile(t, "testdata/store.json") + store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite") existingDomain := "test.com" @@ -980,7 +986,7 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) { t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS) } - store := newPostgresqlStoreFromFile(t, "testdata/store.json") + store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite") hashed := "SoMeHaShEdToKeN" id := "9dj38s35-63fb-11ec-90d6-0242ac120003" @@ -995,7 +1001,7 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) { t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS) } - store := newPostgresqlStoreFromFile(t, "testdata/store.json") + store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite") id := "9dj38s35-63fb-11ec-90d6-0242ac120003" @@ -1009,12 +1015,15 @@ func TestSqlite_GetTakenIPs(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/extended-store.json") - defer store.Close(context.Background()) + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + defer cleanup() + if err != nil { + t.Fatal(err) + } existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - _, err := store.GetAccount(context.Background(), existingAccountID) + _, err = store.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) takenIPs, err := store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID) @@ -1054,12 +1063,15 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/extended-store.json") - defer store.Close(context.Background()) + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + if err != nil { + return + } + t.Cleanup(cleanup) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - _, err := store.GetAccount(context.Background(), existingAccountID) + _, err = store.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID) @@ -1096,12 +1108,15 @@ func TestSqlite_GetAccountNetwork(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/extended-store.json") - defer store.Close(context.Background()) + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + t.Cleanup(cleanup) + if err != nil { + t.Fatal(err) + } existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - _, err := store.GetAccount(context.Background(), existingAccountID) + _, err = store.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) network, err := store.GetAccountNetwork(context.Background(), LockingStrengthShare, existingAccountID) @@ -1118,12 +1133,15 @@ func TestSqlite_GetSetupKeyBySecret(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/extended-store.json") - defer store.Close(context.Background()) + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + t.Cleanup(cleanup) + if err != nil { + t.Fatal(err) + } existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - _, err := store.GetAccount(context.Background(), existingAccountID) + _, err = store.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB") @@ -1137,12 +1155,16 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/extended-store.json") - defer store.Close(context.Background()) + + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + t.Cleanup(cleanup) + if err != nil { + t.Fatal(err) + } existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - _, err := store.GetAccount(context.Background(), existingAccountID) + _, err = store.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB") @@ -1163,3 +1185,33 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) { require.NoError(t, err) assert.Equal(t, 2, setupKey.UsedTimes) } + +func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) { + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + t.Cleanup(cleanup) + if err != nil { + t.Fatal(err) + } + group := &nbgroup.Group{ + ID: "group-id", + AccountID: "account-id", + Name: "group-name", + Issued: "api", + Peers: nil, + } + err = store.ExecuteInTransaction(context.Background(), func(transaction Store) error { + err := transaction.SaveGroup(context.Background(), LockingStrengthUpdate, group) + if err != nil { + t.Fatal("failed to save group") + return err + } + group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.ID, group.AccountID) + if err != nil { + t.Fatal("failed to get group") + return err + } + t.Logf("group: %v", group) + return nil + }) + assert.NoError(t, err) +} diff --git a/management/server/status/error.go b/management/server/status/error.go index d7fde35b9..29d185216 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -102,8 +102,12 @@ func NewPeerLoginExpiredError() error { } // NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key -func NewSetupKeyNotFoundError() error { - return Errorf(NotFound, "setup key not found") +func NewSetupKeyNotFoundError(err error) error { + return Errorf(NotFound, "setup key not found: %s", err) +} + +func NewGetAccountFromStoreError(err error) error { + return Errorf(Internal, "issue getting account from store: %s", err) } // NewGetUserFromStoreError creates a new Error with Internal type for an issue getting user from store diff --git a/management/server/store.go b/management/server/store.go index f34a73c2d..50bc6afdf 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -12,10 +12,11 @@ import ( "strings" "time" - "github.com/netbirdio/netbird/dns" log "github.com/sirupsen/logrus" "gorm.io/gorm" + "github.com/netbirdio/netbird/dns" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/telemetry" @@ -59,6 +60,7 @@ type Store interface { GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) SaveUsers(accountID string, users map[string]*User) error + SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) DeleteHashedPAT2TokenIDIndex(hashedToken string) error @@ -67,7 +69,8 @@ type Store interface { GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) - SaveGroups(accountID string, groups map[string]*nbgroup.Group) error + SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error + SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) @@ -81,6 +84,7 @@ type Store interface { AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) + GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error SavePeerLocation(accountID string, peer *nbpeer.Peer) error @@ -236,23 +240,29 @@ func getMigrations(ctx context.Context) []migrationFunc { } } -// NewTestStoreFromJson is only used in tests -func NewTestStoreFromJson(ctx context.Context, dataDir string) (Store, func(), error) { - fstore, err := NewFileStore(ctx, dataDir, nil) - if err != nil { - return nil, nil, err - } - +// NewTestStoreFromSqlite is only used in tests +func NewTestStoreFromSqlite(ctx context.Context, filename string, dataDir string) (Store, func(), error) { // if store engine is not set in the config we first try to evaluate NETBIRD_STORE_ENGINE kind := getStoreEngineFromEnv() if kind == "" { kind = SqliteStoreEngine } - var ( - store Store - cleanUp func() - ) + var store *SqlStore + var err error + var cleanUp func() + + if filename == "" { + store, err = NewSqliteStore(ctx, dataDir, nil) + cleanUp = func() { + store.Close(ctx) + } + } else { + store, cleanUp, err = NewSqliteTestStore(ctx, dataDir, filename) + } + if err != nil { + return nil, nil, err + } if kind == PostgresStoreEngine { cleanUp, err = testutil.CreatePGDB() @@ -265,21 +275,32 @@ func NewTestStoreFromJson(ctx context.Context, dataDir string) (Store, func(), e return nil, nil, fmt.Errorf("%s is not set", postgresDsnEnv) } - store, err = NewPostgresqlStoreFromFileStore(ctx, fstore, dsn, nil) + store, err = NewPostgresqlStoreFromSqlStore(ctx, store, dsn, nil) if err != nil { return nil, nil, err } - } else { - store, err = NewSqliteStoreFromFileStore(ctx, fstore, dataDir, nil) - if err != nil { - return nil, nil, err - } - cleanUp = func() { store.Close(ctx) } } return store, cleanUp, nil } +func NewSqliteTestStore(ctx context.Context, dataDir string, testFile string) (*SqlStore, func(), error) { + err := util.CopyFileContents(testFile, filepath.Join(dataDir, "store.db")) + if err != nil { + return nil, nil, err + } + + store, err := NewSqliteStore(ctx, dataDir, nil) + if err != nil { + return nil, nil, err + } + + return store, func() { + store.Close(ctx) + os.Remove(filepath.Join(dataDir, "store.db")) + }, nil +} + // MigrateFileStoreToSqlite migrates the file store to the SQLite store. func MigrateFileStoreToSqlite(ctx context.Context, dataDir string) error { fileStorePath := path.Join(dataDir, storeFileName) diff --git a/management/server/store_test.go b/management/server/store_test.go index 40c36c9e0..fc821670d 100644 --- a/management/server/store_test.go +++ b/management/server/store_test.go @@ -14,12 +14,6 @@ type benchCase struct { size int } -var newFs = func(b *testing.B) Store { - b.Helper() - store, _ := NewFileStore(context.Background(), b.TempDir(), nil) - return store -} - var newSqlite = func(b *testing.B) Store { b.Helper() store, _ := NewSqliteStore(context.Background(), b.TempDir(), nil) @@ -28,13 +22,9 @@ var newSqlite = func(b *testing.B) Store { func BenchmarkTest_StoreWrite(b *testing.B) { cases := []benchCase{ - {name: "FileStore_Write", storeFn: newFs, size: 100}, {name: "SqliteStore_Write", storeFn: newSqlite, size: 100}, - {name: "FileStore_Write", storeFn: newFs, size: 500}, {name: "SqliteStore_Write", storeFn: newSqlite, size: 500}, - {name: "FileStore_Write", storeFn: newFs, size: 1000}, {name: "SqliteStore_Write", storeFn: newSqlite, size: 1000}, - {name: "FileStore_Write", storeFn: newFs, size: 2000}, {name: "SqliteStore_Write", storeFn: newSqlite, size: 2000}, } @@ -61,11 +51,8 @@ func BenchmarkTest_StoreWrite(b *testing.B) { func BenchmarkTest_StoreRead(b *testing.B) { cases := []benchCase{ - {name: "FileStore_Read", storeFn: newFs, size: 100}, {name: "SqliteStore_Read", storeFn: newSqlite, size: 100}, - {name: "FileStore_Read", storeFn: newFs, size: 500}, {name: "SqliteStore_Read", storeFn: newSqlite, size: 500}, - {name: "FileStore_Read", storeFn: newFs, size: 1000}, {name: "SqliteStore_Read", storeFn: newSqlite, size: 1000}, } @@ -89,3 +76,11 @@ func BenchmarkTest_StoreRead(b *testing.B) { }) } } + +func newStore(t *testing.T) Store { + t.Helper() + + store := newSqliteStore(t) + + return store +} diff --git a/management/server/testdata/extended-store.json b/management/server/testdata/extended-store.json deleted file mode 100644 index 7f96e57a8..000000000 --- a/management/server/testdata/extended-store.json +++ /dev/null @@ -1,120 +0,0 @@ -{ - "Accounts": { - "bf1c8084-ba50-4ce7-9439-34653001fc3b": { - "Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b", - "CreatedBy": "", - "Domain": "test.com", - "DomainCategory": "private", - "IsDomainPrimaryAccount": true, - "SetupKeys": { - "A2C8E62B-38F5-4553-B31E-DD66C696CEBB": { - "Id": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", - "AccountID": "", - "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", - "Name": "Default key", - "Type": "reusable", - "CreatedAt": "2021-08-19T20:46:20.005936822+02:00", - "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00", - "UpdatedAt": "0001-01-01T00:00:00Z", - "Revoked": false, - "UsedTimes": 0, - "LastUsed": "0001-01-01T00:00:00Z", - "AutoGroups": ["cfefqs706sqkneg59g2g"], - "UsageLimit": 0, - "Ephemeral": false - }, - "A2C8E62B-38F5-4553-B31E-DD66C696CEBC": { - "Id": "A2C8E62B-38F5-4553-B31E-DD66C696CEBC", - "AccountID": "", - "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBC", - "Name": "Faulty key with non existing group", - "Type": "reusable", - "CreatedAt": "2021-08-19T20:46:20.005936822+02:00", - "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00", - "UpdatedAt": "0001-01-01T00:00:00Z", - "Revoked": false, - "UsedTimes": 0, - "LastUsed": "0001-01-01T00:00:00Z", - "AutoGroups": ["abcd"], - "UsageLimit": 0, - "Ephemeral": false - } - }, - "Network": { - "id": "af1c8024-ha40-4ce2-9418-34653101fc3c", - "Net": { - "IP": "100.64.0.0", - "Mask": "//8AAA==" - }, - "Dns": "", - "Serial": 0 - }, - "Peers": {}, - "Users": { - "edafee4e-63fb-11ec-90d6-0242ac120003": { - "Id": "edafee4e-63fb-11ec-90d6-0242ac120003", - "AccountID": "", - "Role": "admin", - "IsServiceUser": false, - "ServiceUserName": "", - "AutoGroups": ["cfefqs706sqkneg59g3g"], - "PATs": {}, - "Blocked": false, - "LastLogin": "0001-01-01T00:00:00Z" - }, - "f4f6d672-63fb-11ec-90d6-0242ac120003": { - "Id": "f4f6d672-63fb-11ec-90d6-0242ac120003", - "AccountID": "", - "Role": "user", - "IsServiceUser": false, - "ServiceUserName": "", - "AutoGroups": null, - "PATs": { - "9dj38s35-63fb-11ec-90d6-0242ac120003": { - "ID": "9dj38s35-63fb-11ec-90d6-0242ac120003", - "UserID": "", - "Name": "", - "HashedToken": "SoMeHaShEdToKeN", - "ExpirationDate": "2023-02-27T00:00:00Z", - "CreatedBy": "user", - "CreatedAt": "2023-01-01T00:00:00Z", - "LastUsed": "2023-02-01T00:00:00Z" - } - }, - "Blocked": false, - "LastLogin": "0001-01-01T00:00:00Z" - } - }, - "Groups": { - "cfefqs706sqkneg59g4g": { - "ID": "cfefqs706sqkneg59g4g", - "Name": "All", - "Peers": [] - }, - "cfefqs706sqkneg59g3g": { - "ID": "cfefqs706sqkneg59g3g", - "Name": "AwesomeGroup1", - "Peers": [] - }, - "cfefqs706sqkneg59g2g": { - "ID": "cfefqs706sqkneg59g2g", - "Name": "AwesomeGroup2", - "Peers": [] - } - }, - "Rules": null, - "Policies": [], - "Routes": null, - "NameServerGroups": null, - "DNSSettings": null, - "Settings": { - "PeerLoginExpirationEnabled": false, - "PeerLoginExpiration": 86400000000000, - "GroupsPropagationEnabled": false, - "JWTGroupsEnabled": false, - "JWTGroupsClaimName": "" - } - } - }, - "InstallationID": "" -} diff --git a/management/server/testdata/extended-store.sqlite b/management/server/testdata/extended-store.sqlite new file mode 100644 index 000000000..81aea8118 Binary files /dev/null and b/management/server/testdata/extended-store.sqlite differ diff --git a/management/server/testdata/store.json b/management/server/testdata/store.json deleted file mode 100644 index 6a8fc0712..000000000 --- a/management/server/testdata/store.json +++ /dev/null @@ -1,120 +0,0 @@ -{ - "Accounts": { - "bf1c8084-ba50-4ce7-9439-34653001fc3b": { - "Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b", - "CreatedBy": "", - "Domain": "test.com", - "DomainCategory": "private", - "IsDomainPrimaryAccount": true, - "SetupKeys": { - "A2C8E62B-38F5-4553-B31E-DD66C696CEBB": { - "Id": "", - "AccountID": "", - "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", - "Name": "Default key", - "Type": "reusable", - "CreatedAt": "2021-08-19T20:46:20.005936822+02:00", - "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00", - "UpdatedAt": "0001-01-01T00:00:00Z", - "Revoked": false, - "UsedTimes": 0, - "LastUsed": "0001-01-01T00:00:00Z", - "AutoGroups": ["cq9bbkjjuspi5gd38epg"], - "UsageLimit": 0, - "Ephemeral": false - } - }, - "Network": { - "id": "af1c8024-ha40-4ce2-9418-34653101fc3c", - "Net": { - "IP": "100.64.0.0", - "Mask": "//8AAA==" - }, - "Dns": "", - "Serial": 0 - }, - "Peers": {}, - "Users": { - "edafee4e-63fb-11ec-90d6-0242ac120003": { - "Id": "edafee4e-63fb-11ec-90d6-0242ac120003", - "AccountID": "", - "Role": "admin", - "IsServiceUser": false, - "ServiceUserName": "", - "AutoGroups": null, - "PATs": {}, - "Blocked": false, - "LastLogin": "0001-01-01T00:00:00Z" - }, - "f4f6d672-63fb-11ec-90d6-0242ac120003": { - "Id": "f4f6d672-63fb-11ec-90d6-0242ac120003", - "AccountID": "", - "Role": "user", - "IsServiceUser": false, - "ServiceUserName": "", - "AutoGroups": null, - "PATs": { - "9dj38s35-63fb-11ec-90d6-0242ac120003": { - "ID": "9dj38s35-63fb-11ec-90d6-0242ac120003", - "UserID": "", - "Name": "", - "HashedToken": "SoMeHaShEdToKeN", - "ExpirationDate": "2023-02-27T00:00:00Z", - "CreatedBy": "user", - "CreatedAt": "2023-01-01T00:00:00Z", - "LastUsed": "2023-02-01T00:00:00Z" - } - }, - "Blocked": false, - "LastLogin": "0001-01-01T00:00:00Z" - } - }, - "Groups": { - "cq9bbkjjuspi5gd38epg": { - "ID": "cq9bbkjjuspi5gd38epg", - "Name": "All", - "Peers": [] - } - }, - "Rules": null, - "Policies": [ - { - "ID": "cq9bbkjjuspi5gd38eq0", - "Name": "Default", - "Description": "This is a default rule that allows connections between all the resources", - "Enabled": true, - "Rules": [ - { - "ID": "cq9bbkjjuspi5gd38eq0", - "Name": "Default", - "Description": "This is a default rule that allows connections between all the resources", - "Enabled": true, - "Action": "accept", - "Destinations": [ - "cq9bbkjjuspi5gd38epg" - ], - "Sources": [ - "cq9bbkjjuspi5gd38epg" - ], - "Bidirectional": true, - "Protocol": "all", - "Ports": null - } - ], - "SourcePostureChecks": null - } - ], - "Routes": null, - "NameServerGroups": null, - "DNSSettings": null, - "Settings": { - "PeerLoginExpirationEnabled": false, - "PeerLoginExpiration": 86400000000000, - "GroupsPropagationEnabled": false, - "JWTGroupsEnabled": false, - "JWTGroupsClaimName": "" - } - } - }, - "InstallationID": "" -} \ No newline at end of file diff --git a/management/server/testdata/store.sqlite b/management/server/testdata/store.sqlite new file mode 100644 index 000000000..5fc746285 Binary files /dev/null and b/management/server/testdata/store.sqlite differ diff --git a/management/server/testdata/store_policy_migrate.json b/management/server/testdata/store_policy_migrate.json deleted file mode 100644 index 1b046e632..000000000 --- a/management/server/testdata/store_policy_migrate.json +++ /dev/null @@ -1,116 +0,0 @@ -{ - "Accounts": { - "bf1c8084-ba50-4ce7-9439-34653001fc3b": { - "Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b", - "Domain": "test.com", - "DomainCategory": "private", - "IsDomainPrimaryAccount": true, - "SetupKeys": { - "A2C8E62B-38F5-4553-B31E-DD66C696CEBB": { - "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", - "Name": "Default key", - "Type": "reusable", - "CreatedAt": "2021-08-19T20:46:20.005936822+02:00", - "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00", - "Revoked": false, - "UsedTimes": 0 - } - }, - "Network": { - "Id": "af1c8024-ha40-4ce2-9418-34653101fc3c", - "Net": { - "IP": "100.64.0.0", - "Mask": "//8AAA==" - }, - "Dns": null - }, - "Peers": { - "cfefqs706sqkneg59g4g": { - "ID": "cfefqs706sqkneg59g4g", - "Key": "MI5mHfJhbggPfD3FqEIsXm8X5bSWeUI2LhO9MpEEtWA=", - "SetupKey": "", - "IP": "100.103.179.238", - "Meta": { - "Hostname": "Ubuntu-2204-jammy-amd64-base", - "GoOS": "linux", - "Kernel": "Linux", - "Core": "22.04", - "Platform": "x86_64", - "OS": "Ubuntu", - "WtVersion": "development", - "UIVersion": "" - }, - "Name": "crocodile", - "DNSLabel": "crocodile", - "Status": { - "LastSeen": "2023-02-13T12:37:12.635454796Z", - "Connected": true - }, - "UserID": "edafee4e-63fb-11ec-90d6-0242ac120003", - "SSHKey": "AAAAC3NzaC1lZDI1NTE5AAAAIJN1NM4bpB9K", - "SSHEnabled": false - }, - "cfeg6sf06sqkneg59g50": { - "ID": "cfeg6sf06sqkneg59g50", - "Key": "zMAOKUeIYIuun4n0xPR1b3IdYZPmsyjYmB2jWCuloC4=", - "SetupKey": "", - "IP": "100.103.26.180", - "Meta": { - "Hostname": "borg", - "GoOS": "linux", - "Kernel": "Linux", - "Core": "22.04", - "Platform": "x86_64", - "OS": "Ubuntu", - "WtVersion": "development", - "UIVersion": "" - }, - "Name": "dingo", - "DNSLabel": "dingo", - "Status": { - "LastSeen": "2023-02-21T09:37:42.565899199Z", - "Connected": false - }, - "UserID": "f4f6d672-63fb-11ec-90d6-0242ac120003", - "SSHKey": "AAAAC3NzaC1lZDI1NTE5AAAAILHW", - "SSHEnabled": true - } - }, - "Groups": { - "cfefqs706sqkneg59g3g": { - "ID": "cfefqs706sqkneg59g3g", - "Name": "All", - "Peers": [ - "cfefqs706sqkneg59g4g", - "cfeg6sf06sqkneg59g50" - ] - } - }, - "Rules": { - "cfefqs706sqkneg59g40": { - "ID": "cfefqs706sqkneg59g40", - "Name": "Default", - "Description": "This is a default rule that allows connections between all the resources", - "Disabled": false, - "Source": [ - "cfefqs706sqkneg59g3g" - ], - "Destination": [ - "cfefqs706sqkneg59g3g" - ], - "Flow": 0 - } - }, - "Users": { - "edafee4e-63fb-11ec-90d6-0242ac120003": { - "Id": "edafee4e-63fb-11ec-90d6-0242ac120003", - "Role": "admin" - }, - "f4f6d672-63fb-11ec-90d6-0242ac120003": { - "Id": "f4f6d672-63fb-11ec-90d6-0242ac120003", - "Role": "user" - } - } - } - } -} diff --git a/management/server/testdata/store_policy_migrate.sqlite b/management/server/testdata/store_policy_migrate.sqlite new file mode 100644 index 000000000..0c1a491a6 Binary files /dev/null and b/management/server/testdata/store_policy_migrate.sqlite differ diff --git a/management/server/testdata/store_with_expired_peers.json b/management/server/testdata/store_with_expired_peers.json deleted file mode 100644 index 44c225682..000000000 --- a/management/server/testdata/store_with_expired_peers.json +++ /dev/null @@ -1,130 +0,0 @@ -{ - "Accounts": { - "bf1c8084-ba50-4ce7-9439-34653001fc3b": { - "Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b", - "Domain": "test.com", - "DomainCategory": "private", - "IsDomainPrimaryAccount": true, - "Settings": { - "PeerLoginExpirationEnabled": true, - "PeerLoginExpiration": 3600000000000 - }, - "SetupKeys": { - "A2C8E62B-38F5-4553-B31E-DD66C696CEBB": { - "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", - "Name": "Default key", - "Type": "reusable", - "CreatedAt": "2021-08-19T20:46:20.005936822+02:00", - "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00", - "Revoked": false, - "UsedTimes": 0 - - } - }, - "Network": { - "Id": "af1c8024-ha40-4ce2-9418-34653101fc3c", - "Net": { - "IP": "100.64.0.0", - "Mask": "//8AAA==" - }, - "Dns": null - }, - "Peers": { - "cfvprsrlo1hqoo49ohog": { - "ID": "cfvprsrlo1hqoo49ohog", - "Key": "5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=", - "SetupKey": "72546A29-6BC8-4311-BCFC-9CDBF33F1A48", - "IP": "100.64.114.31", - "Meta": { - "Hostname": "f2a34f6a4731", - "GoOS": "linux", - "Kernel": "Linux", - "Core": "11", - "Platform": "unknown", - "OS": "Debian GNU/Linux", - "WtVersion": "0.12.0", - "UIVersion": "" - }, - "Name": "f2a34f6a4731", - "DNSLabel": "f2a34f6a4731", - "Status": { - "LastSeen": "2023-03-02T09:21:02.189035775+01:00", - "Connected": false, - "LoginExpired": false - }, - "UserID": "", - "SSHKey": "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk", - "SSHEnabled": false, - "LoginExpirationEnabled": true, - "LastLogin": "2023-03-01T19:48:19.817799698+01:00" - }, - "cg05lnblo1hkg2j514p0": { - "ID": "cg05lnblo1hkg2j514p0", - "Key": "RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=", - "SetupKey": "", - "IP": "100.64.39.54", - "Meta": { - "Hostname": "expiredhost", - "GoOS": "linux", - "Kernel": "Linux", - "Core": "22.04", - "Platform": "x86_64", - "OS": "Ubuntu", - "WtVersion": "development", - "UIVersion": "" - }, - "Name": "expiredhost", - "DNSLabel": "expiredhost", - "Status": { - "LastSeen": "2023-03-02T09:19:57.276717255+01:00", - "Connected": false, - "LoginExpired": true - }, - "UserID": "edafee4e-63fb-11ec-90d6-0242ac120003", - "SSHKey": "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK", - "SSHEnabled": false, - "LoginExpirationEnabled": true, - "LastLogin": "2023-03-02T09:14:21.791679181+01:00" - }, - "cg3161rlo1hs9cq94gdg": { - "ID": "cg3161rlo1hs9cq94gdg", - "Key": "mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=", - "SetupKey": "", - "IP": "100.64.117.96", - "Meta": { - "Hostname": "testhost", - "GoOS": "linux", - "Kernel": "Linux", - "Core": "22.04", - "Platform": "x86_64", - "OS": "Ubuntu", - "WtVersion": "development", - "UIVersion": "" - }, - "Name": "testhost", - "DNSLabel": "testhost", - "Status": { - "LastSeen": "2023-03-06T18:21:27.252010027+01:00", - "Connected": false, - "LoginExpired": false - }, - "UserID": "edafee4e-63fb-11ec-90d6-0242ac120003", - "SSHKey": "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM", - "SSHEnabled": false, - "LoginExpirationEnabled": false, - "LastLogin": "2023-03-07T09:02:47.442857106+01:00" - } - }, - "Users": { - "edafee4e-63fb-11ec-90d6-0242ac120003": { - "Id": "edafee4e-63fb-11ec-90d6-0242ac120003", - "Role": "admin" - }, - "f4f6d672-63fb-11ec-90d6-0242ac120003": { - "Id": "f4f6d672-63fb-11ec-90d6-0242ac120003", - "Role": "user" - } - } - } - } -} \ No newline at end of file diff --git a/management/server/testdata/store_with_expired_peers.sqlite b/management/server/testdata/store_with_expired_peers.sqlite new file mode 100644 index 000000000..ed1133211 Binary files /dev/null and b/management/server/testdata/store_with_expired_peers.sqlite differ diff --git a/management/server/testdata/storev1.json b/management/server/testdata/storev1.json deleted file mode 100644 index 674b2b87a..000000000 --- a/management/server/testdata/storev1.json +++ /dev/null @@ -1,154 +0,0 @@ -{ - "Accounts": { - "auth0|61bf82ddeab084006aa1bccd": { - "Id": "auth0|61bf82ddeab084006aa1bccd", - "SetupKeys": { - "1B2B50B0-B3E8-4B0C-A426-525EDB8481BD": { - "Id": "831727121", - "Key": "1B2B50B0-B3E8-4B0C-A426-525EDB8481BD", - "Name": "One-off key", - "Type": "one-off", - "CreatedAt": "2021-12-24T16:09:45.926075752+01:00", - "ExpiresAt": "2022-01-23T16:09:45.926075752+01:00", - "Revoked": false, - "UsedTimes": 1, - "LastUsed": "2021-12-24T16:12:45.763424077+01:00" - }, - "EB51E9EB-A11F-4F6E-8E49-C982891B405A": { - "Id": "1769568301", - "Key": "EB51E9EB-A11F-4F6E-8E49-C982891B405A", - "Name": "Default key", - "Type": "reusable", - "CreatedAt": "2021-12-24T16:09:45.926073628+01:00", - "ExpiresAt": "2022-01-23T16:09:45.926073628+01:00", - "Revoked": false, - "UsedTimes": 1, - "LastUsed": "2021-12-24T16:13:06.236748538+01:00" - } - }, - "Network": { - "Id": "a443c07a-5765-4a78-97fc-390d9c1d0e49", - "Net": { - "IP": "100.64.0.0", - "Mask": "/8AAAA==" - }, - "Dns": "" - }, - "Peers": { - "oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=": { - "Key": "oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=", - "SetupKey": "EB51E9EB-A11F-4F6E-8E49-C982891B405A", - "IP": "100.64.0.2", - "Meta": { - "Hostname": "braginini", - "GoOS": "linux", - "Kernel": "Linux", - "Core": "21.04", - "Platform": "x86_64", - "OS": "Ubuntu", - "WtVersion": "" - }, - "Name": "braginini", - "Status": { - "LastSeen": "2021-12-24T16:13:11.244342541+01:00", - "Connected": false - } - }, - "xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=": { - "Key": "xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=", - "SetupKey": "1B2B50B0-B3E8-4B0C-A426-525EDB8481BD", - "IP": "100.64.0.1", - "Meta": { - "Hostname": "braginini", - "GoOS": "linux", - "Kernel": "Linux", - "Core": "21.04", - "Platform": "x86_64", - "OS": "Ubuntu", - "WtVersion": "" - }, - "Name": "braginini", - "Status": { - "LastSeen": "2021-12-24T16:12:49.089339333+01:00", - "Connected": false - } - } - } - }, - "google-oauth2|103201118415301331038": { - "Id": "google-oauth2|103201118415301331038", - "SetupKeys": { - "5AFB60DB-61F2-4251-8E11-494847EE88E9": { - "Id": "2485964613", - "Key": "5AFB60DB-61F2-4251-8E11-494847EE88E9", - "Name": "Default key", - "Type": "reusable", - "CreatedAt": "2021-12-24T16:10:02.238476+01:00", - "ExpiresAt": "2022-01-23T16:10:02.238476+01:00", - "Revoked": false, - "UsedTimes": 1, - "LastUsed": "2021-12-24T16:12:05.994307717+01:00" - }, - "A72E4DC2-00DE-4542-8A24-62945438104E": { - "Id": "3504804807", - "Key": "A72E4DC2-00DE-4542-8A24-62945438104E", - "Name": "One-off key", - "Type": "one-off", - "CreatedAt": "2021-12-24T16:10:02.238478209+01:00", - "ExpiresAt": "2022-01-23T16:10:02.238478209+01:00", - "Revoked": false, - "UsedTimes": 1, - "LastUsed": "2021-12-24T16:11:27.015741738+01:00" - } - }, - "Network": { - "Id": "b6d0b152-364e-40c1-a8a1-fa7bcac2267f", - "Net": { - "IP": "100.64.0.0", - "Mask": "/8AAAA==" - }, - "Dns": "" - }, - "Peers": { - "6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=": { - "Key": "6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=", - "SetupKey": "5AFB60DB-61F2-4251-8E11-494847EE88E9", - "IP": "100.64.0.2", - "Meta": { - "Hostname": "braginini", - "GoOS": "linux", - "Kernel": "Linux", - "Core": "21.04", - "Platform": "x86_64", - "OS": "Ubuntu", - "WtVersion": "" - }, - "Name": "braginini", - "Status": { - "LastSeen": "2021-12-24T16:12:05.994305438+01:00", - "Connected": false - } - }, - "Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=": { - "Key": "Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=", - "SetupKey": "A72E4DC2-00DE-4542-8A24-62945438104E", - "IP": "100.64.0.1", - "Meta": { - "Hostname": "braginini", - "GoOS": "linux", - "Kernel": "Linux", - "Core": "21.04", - "Platform": "x86_64", - "OS": "Ubuntu", - "WtVersion": "" - }, - "Name": "braginini", - "Status": { - "LastSeen": "2021-12-24T16:11:27.015739803+01:00", - "Connected": false - } - } - } - } - } -} \ No newline at end of file diff --git a/management/server/testdata/storev1.sqlite b/management/server/testdata/storev1.sqlite new file mode 100644 index 000000000..9a376698e Binary files /dev/null and b/management/server/testdata/storev1.sqlite differ diff --git a/management/server/user.go b/management/server/user.go index 7acb0b487..e40fc67eb 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -9,14 +9,14 @@ import ( "time" "github.com/google/uuid" - log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/server/activity" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/integration_reference" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" + log "github.com/sirupsen/logrus" ) const ( @@ -1274,6 +1274,74 @@ func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, accoun return map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}, hadPeers, nil } +// updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them. +func (am *DefaultAccountManager) updateUserPeersInGroups(accountGroups map[string]*nbgroup.Group, peers []*nbpeer.Peer, groupsToAdd, + groupsToRemove []string) (groupsToUpdate []*nbgroup.Group, err error) { + + if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 { + return + } + + userPeerIDMap := make(map[string]struct{}, len(peers)) + for _, peer := range peers { + userPeerIDMap[peer.ID] = struct{}{} + } + + for _, gid := range groupsToAdd { + group, ok := accountGroups[gid] + if !ok { + return nil, errors.New("group not found") + } + addUserPeersToGroup(userPeerIDMap, group) + groupsToUpdate = append(groupsToUpdate, group) + } + + for _, gid := range groupsToRemove { + group, ok := accountGroups[gid] + if !ok { + return nil, errors.New("group not found") + } + removeUserPeersFromGroup(userPeerIDMap, group) + groupsToUpdate = append(groupsToUpdate, group) + } + + return groupsToUpdate, nil +} + +// addUserPeersToGroup adds the user's peers to the group. +func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) { + groupPeers := make(map[string]struct{}, len(group.Peers)) + for _, pid := range group.Peers { + groupPeers[pid] = struct{}{} + } + + for pid := range userPeerIDs { + groupPeers[pid] = struct{}{} + } + + group.Peers = make([]string, 0, len(groupPeers)) + for pid := range groupPeers { + group.Peers = append(group.Peers, pid) + } +} + +// removeUserPeersFromGroup removes user's peers from the group. +func removeUserPeersFromGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) { + // skip removing peers from group All + if group.Name == "All" { + return + } + + updatedPeers := make([]string, 0, len(group.Peers)) + for _, pid := range group.Peers { + if _, found := userPeerIDs[pid]; !found { + updatedPeers = append(updatedPeers, pid) + } + } + + group.Peers = updatedPeers +} + func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) { for _, user := range userData { if user.ID == userID { diff --git a/management/server/user_test.go b/management/server/user_test.go index c836ac98f..3f7f814a0 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -62,8 +62,10 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) { assert.Equal(t, pat.CreatedBy, mockUserID) - fileStore := am.Store.(*FileStore) - tokenID := fileStore.HashedPAT2TokenID[pat.HashedToken] + tokenID, err := am.Store.GetTokenIDByHashedToken(context.Background(), pat.HashedToken) + if err != nil { + t.Fatalf("Error when getting token ID by hashed token: %s", err) + } if tokenID == "" { t.Fatal("GetTokenIDByHashedToken failed after adding PAT") @@ -71,11 +73,12 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) { assert.Equal(t, pat.ID, tokenID) - userID := fileStore.TokenID2UserID[tokenID] - if userID == "" { - t.Fatal("GetUserByTokenId failed after adding PAT") + user, err := am.Store.GetUserByTokenID(context.Background(), tokenID) + if err != nil { + t.Fatalf("Error when getting user by token ID: %s", err) } - assert.Equal(t, mockUserID, userID) + + assert.Equal(t, mockUserID, user.Id) } func TestUser_CreatePAT_ForDifferentUser(t *testing.T) { @@ -192,9 +195,12 @@ func TestUser_DeletePAT(t *testing.T) { t.Fatalf("Error when adding PAT to user: %s", err) } - assert.Nil(t, store.Accounts[mockAccountID].Users[mockUserID].PATs[mockTokenID1]) - assert.Empty(t, store.HashedPAT2TokenID[mockToken1]) - assert.Empty(t, store.TokenID2UserID[mockTokenID1]) + account, err = store.GetAccount(context.Background(), mockAccountID) + if err != nil { + t.Fatalf("Error when getting account: %s", err) + } + + assert.Nil(t, account.Users[mockUserID].PATs[mockTokenID1]) } func TestUser_GetPAT(t *testing.T) { @@ -353,13 +359,16 @@ func TestUser_CreateServiceUser(t *testing.T) { t.Fatalf("Error when creating service user: %s", err) } - assert.Equal(t, 2, len(store.Accounts[mockAccountID].Users)) - assert.NotNil(t, store.Accounts[mockAccountID].Users[user.ID]) - assert.True(t, store.Accounts[mockAccountID].Users[user.ID].IsServiceUser) - assert.Equal(t, mockServiceUserName, store.Accounts[mockAccountID].Users[user.ID].ServiceUserName) - assert.Equal(t, UserRole(mockRole), store.Accounts[mockAccountID].Users[user.ID].Role) - assert.Equal(t, []string{"group1", "group2"}, store.Accounts[mockAccountID].Users[user.ID].AutoGroups) - assert.Equal(t, map[string]*PersonalAccessToken{}, store.Accounts[mockAccountID].Users[user.ID].PATs) + account, err = store.GetAccount(context.Background(), mockAccountID) + assert.NoError(t, err) + + assert.Equal(t, 2, len(account.Users)) + assert.NotNil(t, account.Users[user.ID]) + assert.True(t, account.Users[user.ID].IsServiceUser) + assert.Equal(t, mockServiceUserName, account.Users[user.ID].ServiceUserName) + assert.Equal(t, UserRole(mockRole), account.Users[user.ID].Role) + assert.Equal(t, []string{"group1", "group2"}, account.Users[user.ID].AutoGroups) + assert.Equal(t, map[string]*PersonalAccessToken{}, account.Users[user.ID].PATs) assert.Zero(t, user.Email) assert.True(t, user.IsServiceUser) @@ -397,12 +406,15 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) { t.Fatalf("Error when creating user: %s", err) } + account, err = store.GetAccount(context.Background(), mockAccountID) + assert.NoError(t, err) + assert.True(t, user.IsServiceUser) - assert.Equal(t, 2, len(store.Accounts[mockAccountID].Users)) - assert.True(t, store.Accounts[mockAccountID].Users[user.ID].IsServiceUser) - assert.Equal(t, mockServiceUserName, store.Accounts[mockAccountID].Users[user.ID].ServiceUserName) - assert.Equal(t, UserRole(mockRole), store.Accounts[mockAccountID].Users[user.ID].Role) - assert.Equal(t, []string{"group1", "group2"}, store.Accounts[mockAccountID].Users[user.ID].AutoGroups) + assert.Equal(t, 2, len(account.Users)) + assert.True(t, account.Users[user.ID].IsServiceUser) + assert.Equal(t, mockServiceUserName, account.Users[user.ID].ServiceUserName) + assert.Equal(t, UserRole(mockRole), account.Users[user.ID].Role) + assert.Equal(t, []string{"group1", "group2"}, account.Users[user.ID].AutoGroups) assert.Equal(t, mockServiceUserName, user.Name) assert.Equal(t, mockRole, user.Role) @@ -553,12 +565,15 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) { err = am.DeleteUser(context.Background(), mockAccountID, mockUserID, mockServiceUserID) tt.assertErrFunc(t, err, tt.assertErrMessage) + account, err2 := store.GetAccount(context.Background(), mockAccountID) + assert.NoError(t, err2) + if err != nil { - assert.Equal(t, 2, len(store.Accounts[mockAccountID].Users)) - assert.NotNil(t, store.Accounts[mockAccountID].Users[mockServiceUserID]) + assert.Equal(t, 2, len(account.Users)) + assert.NotNil(t, account.Users[mockServiceUserID]) } else { - assert.Equal(t, 1, len(store.Accounts[mockAccountID].Users)) - assert.Nil(t, store.Accounts[mockAccountID].Users[mockServiceUserID]) + assert.Equal(t, 1, len(account.Users)) + assert.Nil(t, account.Users[mockServiceUserID]) } }) } @@ -801,10 +816,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) { assert.NoError(t, err) } - accID, err := am.GetAccountIDByUserOrAccountID(context.Background(), "", account.Id, "") - assert.NoError(t, err) - - acc, err := am.Store.GetAccount(context.Background(), accID) + acc, err := am.Store.GetAccount(context.Background(), account.Id) assert.NoError(t, err) for _, id := range tc.expectedDeleted { diff --git a/util/file.go b/util/file.go index 8355488c9..ecaecd222 100644 --- a/util/file.go +++ b/util/file.go @@ -1,11 +1,15 @@ package util import ( + "bytes" "context" "encoding/json" + "fmt" "io" "os" "path/filepath" + "strings" + "text/template" log "github.com/sirupsen/logrus" ) @@ -160,6 +164,55 @@ func ReadJson(file string, res interface{}) (interface{}, error) { return res, nil } +// ReadJsonWithEnvSub reads JSON config file and maps to a provided interface with environment variable substitution +func ReadJsonWithEnvSub(file string, res interface{}) (interface{}, error) { + envVars := getEnvMap() + + f, err := os.Open(file) + if err != nil { + return nil, err + } + defer f.Close() + + bs, err := io.ReadAll(f) + if err != nil { + return nil, err + } + + t, err := template.New("").Parse(string(bs)) + if err != nil { + return nil, fmt.Errorf("error parsing template: %v", err) + } + + var output bytes.Buffer + // Execute the template, substituting environment variables + err = t.Execute(&output, envVars) + if err != nil { + return nil, fmt.Errorf("error executing template: %v", err) + } + + err = json.Unmarshal(output.Bytes(), &res) + if err != nil { + return nil, fmt.Errorf("failed parsing Json file after template was executed, err: %v", err) + } + + return res, nil +} + +// getEnvMap Convert the output of os.Environ() to a map +func getEnvMap() map[string]string { + envMap := make(map[string]string) + + for _, env := range os.Environ() { + parts := strings.SplitN(env, "=", 2) + if len(parts) == 2 { + envMap[parts[0]] = parts[1] + } + } + + return envMap +} + // CopyFileContents copies contents of the given src file to the dst file func CopyFileContents(src, dst string) (err error) { in, err := os.Open(src) diff --git a/util/file_suite_test.go b/util/file_suite_test.go new file mode 100644 index 000000000..3de7db49b --- /dev/null +++ b/util/file_suite_test.go @@ -0,0 +1,126 @@ +package util_test + +import ( + "crypto/md5" + "encoding/hex" + "io" + "os" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + + "github.com/netbirdio/netbird/util" +) + +var _ = Describe("Client", func() { + + var ( + tmpDir string + ) + + type TestConfig struct { + SomeMap map[string]string + SomeArray []string + SomeField int + } + + BeforeEach(func() { + var err error + tmpDir, err = os.MkdirTemp("", "wiretrustee_util_test_tmp_*") + Expect(err).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + err := os.RemoveAll(tmpDir) + Expect(err).NotTo(HaveOccurred()) + }) + + Describe("Config", func() { + Context("in JSON format", func() { + It("should be written and read successfully", func() { + + m := make(map[string]string) + m["key1"] = "value1" + m["key2"] = "value2" + + arr := []string{"value1", "value2"} + + written := &TestConfig{ + SomeMap: m, + SomeArray: arr, + SomeField: 99, + } + + err := util.WriteJson(tmpDir+"/testconfig.json", written) + Expect(err).NotTo(HaveOccurred()) + + read, err := util.ReadJson(tmpDir+"/testconfig.json", &TestConfig{}) + Expect(err).NotTo(HaveOccurred()) + Expect(read).NotTo(BeNil()) + Expect(read.(*TestConfig).SomeMap["key1"]).To(BeEquivalentTo(written.SomeMap["key1"])) + Expect(read.(*TestConfig).SomeMap["key2"]).To(BeEquivalentTo(written.SomeMap["key2"])) + Expect(read.(*TestConfig).SomeArray).To(ContainElements(arr)) + Expect(read.(*TestConfig).SomeField).To(BeEquivalentTo(written.SomeField)) + + }) + }) + }) + + Describe("Copying file contents", func() { + Context("from one file to another", func() { + It("should be successful", func() { + + src := tmpDir + "/copytest_src" + dst := tmpDir + "/copytest_dst" + + err := util.WriteJson(src, []string{"1", "2", "3"}) + Expect(err).NotTo(HaveOccurred()) + + err = util.CopyFileContents(src, dst) + Expect(err).NotTo(HaveOccurred()) + + hashSrc := md5.New() + hashDst := md5.New() + + srcFile, err := os.Open(src) + Expect(err).NotTo(HaveOccurred()) + + dstFile, err := os.Open(dst) + Expect(err).NotTo(HaveOccurred()) + + _, err = io.Copy(hashSrc, srcFile) + Expect(err).NotTo(HaveOccurred()) + + _, err = io.Copy(hashDst, dstFile) + Expect(err).NotTo(HaveOccurred()) + + err = srcFile.Close() + Expect(err).NotTo(HaveOccurred()) + + err = dstFile.Close() + Expect(err).NotTo(HaveOccurred()) + + Expect(hex.EncodeToString(hashSrc.Sum(nil)[:16])).To(BeEquivalentTo(hex.EncodeToString(hashDst.Sum(nil)[:16]))) + }) + }) + }) + + Describe("Handle config file without full path", func() { + Context("config file handling", func() { + It("should be successful", func() { + written := &TestConfig{ + SomeField: 123, + } + cfgFile := "test_cfg.json" + defer os.Remove(cfgFile) + + err := util.WriteJson(cfgFile, written) + Expect(err).NotTo(HaveOccurred()) + + read, err := util.ReadJson(cfgFile, &TestConfig{}) + Expect(err).NotTo(HaveOccurred()) + Expect(read).NotTo(BeNil()) + }) + }) + }) +}) diff --git a/util/file_test.go b/util/file_test.go index 3de7db49b..1330e738e 100644 --- a/util/file_test.go +++ b/util/file_test.go @@ -1,126 +1,198 @@ -package util_test +package util import ( - "crypto/md5" - "encoding/hex" - "io" "os" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" - - "github.com/netbirdio/netbird/util" + "reflect" + "strings" + "testing" ) -var _ = Describe("Client", func() { - - var ( - tmpDir string - ) - - type TestConfig struct { - SomeMap map[string]string - SomeArray []string - SomeField int +func TestReadJsonWithEnvSub(t *testing.T) { + type Config struct { + CertFile string `json:"CertFile"` + Credentials string `json:"Credentials"` + NestedOption struct { + URL string `json:"URL"` + } `json:"NestedOption"` } - BeforeEach(func() { - var err error - tmpDir, err = os.MkdirTemp("", "wiretrustee_util_test_tmp_*") - Expect(err).NotTo(HaveOccurred()) - }) + type testCase struct { + name string + envVars map[string]string + jsonTemplate string + expectedResult Config + expectError bool + errorContains string + } - AfterEach(func() { - err := os.RemoveAll(tmpDir) - Expect(err).NotTo(HaveOccurred()) - }) + tests := []testCase{ + { + name: "All environment variables set", + envVars: map[string]string{ + "CERT_FILE": "/etc/certs/env_cert.crt", + "CREDENTIALS": "env_credentials", + "URL": "https://env.testing.com", + }, + jsonTemplate: `{ + "CertFile": "{{ .CERT_FILE }}", + "Credentials": "{{ .CREDENTIALS }}", + "NestedOption": { + "URL": "{{ .URL }}" + } + }`, + expectedResult: Config{ + CertFile: "/etc/certs/env_cert.crt", + Credentials: "env_credentials", + NestedOption: struct { + URL string `json:"URL"` + }{ + URL: "https://env.testing.com", + }, + }, + expectError: false, + }, + { + name: "Missing environment variable", + envVars: map[string]string{ + "CERT_FILE": "/etc/certs/env_cert.crt", + "CREDENTIALS": "env_credentials", + // "URL" is intentionally missing + }, + jsonTemplate: `{ + "CertFile": "{{ .CERT_FILE }}", + "Credentials": "{{ .CREDENTIALS }}", + "NestedOption": { + "URL": "{{ .URL }}" + } + }`, + expectedResult: Config{ + CertFile: "/etc/certs/env_cert.crt", + Credentials: "env_credentials", + NestedOption: struct { + URL string `json:"URL"` + }{ + URL: "", + }, + }, + expectError: false, + }, + { + name: "Invalid JSON template", + envVars: map[string]string{ + "CERT_FILE": "/etc/certs/env_cert.crt", + "CREDENTIALS": "env_credentials", + "URL": "https://env.testing.com", + }, + jsonTemplate: `{ + "CertFile": "{{ .CERT_FILE }}", + "Credentials": "{{ .CREDENTIALS }", + "NestedOption": { + "URL": "{{ .URL }}" + } + }`, // Note the missing closing brace in "{{ .CREDENTIALS }" + expectedResult: Config{}, + expectError: true, + errorContains: "unexpected \"}\" in operand", + }, + { + name: "No substitutions", + envVars: map[string]string{ + "CERT_FILE": "/etc/certs/env_cert.crt", + "CREDENTIALS": "env_credentials", + "URL": "https://env.testing.com", + }, + jsonTemplate: `{ + "CertFile": "/etc/certs/cert.crt", + "Credentials": "admnlknflkdasdf", + "NestedOption" : { + "URL": "https://testing.com" + } + }`, + expectedResult: Config{ + CertFile: "/etc/certs/cert.crt", + Credentials: "admnlknflkdasdf", + NestedOption: struct { + URL string `json:"URL"` + }{ + URL: "https://testing.com", + }, + }, + expectError: false, + }, + { + name: "Should fail when Invalid characters in variables", + envVars: map[string]string{ + "CERT_FILE": `"/etc/certs/"cert".crt"`, + "CREDENTIALS": `env_credentia{ls}`, + "URL": `https://env.testing.com?param={{value}}`, + }, + jsonTemplate: `{ + "CertFile": "{{ .CERT_FILE }}", + "Credentials": "{{ .CREDENTIALS }}", + "NestedOption": { + "URL": "{{ .URL }}" + } + }`, + expectedResult: Config{ + CertFile: `"/etc/certs/"cert".crt"`, + Credentials: `env_credentia{ls}`, + NestedOption: struct { + URL string `json:"URL"` + }{ + URL: `https://env.testing.com?param={{value}}`, + }, + }, + expectError: true, + }, + } - Describe("Config", func() { - Context("in JSON format", func() { - It("should be written and read successfully", func() { + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + for key, value := range tc.envVars { + t.Setenv(key, value) + } - m := make(map[string]string) - m["key1"] = "value1" - m["key2"] = "value2" + tempFile, err := os.CreateTemp("", "config*.json") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } - arr := []string{"value1", "value2"} - - written := &TestConfig{ - SomeMap: m, - SomeArray: arr, - SomeField: 99, + defer func() { + err = os.Remove(tempFile.Name()) + if err != nil { + t.Logf("Failed to remove temp file: %v", err) } + }() - err := util.WriteJson(tmpDir+"/testconfig.json", written) - Expect(err).NotTo(HaveOccurred()) + _, err = tempFile.WriteString(tc.jsonTemplate) + if err != nil { + t.Fatalf("Failed to write to temp file: %v", err) + } + err = tempFile.Close() + if err != nil { + t.Fatalf("Failed to close temp file: %v", err) + } - read, err := util.ReadJson(tmpDir+"/testconfig.json", &TestConfig{}) - Expect(err).NotTo(HaveOccurred()) - Expect(read).NotTo(BeNil()) - Expect(read.(*TestConfig).SomeMap["key1"]).To(BeEquivalentTo(written.SomeMap["key1"])) - Expect(read.(*TestConfig).SomeMap["key2"]).To(BeEquivalentTo(written.SomeMap["key2"])) - Expect(read.(*TestConfig).SomeArray).To(ContainElements(arr)) - Expect(read.(*TestConfig).SomeField).To(BeEquivalentTo(written.SomeField)) + var result Config - }) - }) - }) + _, err = ReadJsonWithEnvSub(tempFile.Name(), &result) - Describe("Copying file contents", func() { - Context("from one file to another", func() { - It("should be successful", func() { - - src := tmpDir + "/copytest_src" - dst := tmpDir + "/copytest_dst" - - err := util.WriteJson(src, []string{"1", "2", "3"}) - Expect(err).NotTo(HaveOccurred()) - - err = util.CopyFileContents(src, dst) - Expect(err).NotTo(HaveOccurred()) - - hashSrc := md5.New() - hashDst := md5.New() - - srcFile, err := os.Open(src) - Expect(err).NotTo(HaveOccurred()) - - dstFile, err := os.Open(dst) - Expect(err).NotTo(HaveOccurred()) - - _, err = io.Copy(hashSrc, srcFile) - Expect(err).NotTo(HaveOccurred()) - - _, err = io.Copy(hashDst, dstFile) - Expect(err).NotTo(HaveOccurred()) - - err = srcFile.Close() - Expect(err).NotTo(HaveOccurred()) - - err = dstFile.Close() - Expect(err).NotTo(HaveOccurred()) - - Expect(hex.EncodeToString(hashSrc.Sum(nil)[:16])).To(BeEquivalentTo(hex.EncodeToString(hashDst.Sum(nil)[:16]))) - }) - }) - }) - - Describe("Handle config file without full path", func() { - Context("config file handling", func() { - It("should be successful", func() { - written := &TestConfig{ - SomeField: 123, + if tc.expectError { + if err == nil { + t.Fatalf("Expected error but got none") } - cfgFile := "test_cfg.json" - defer os.Remove(cfgFile) - - err := util.WriteJson(cfgFile, written) - Expect(err).NotTo(HaveOccurred()) - - read, err := util.ReadJson(cfgFile, &TestConfig{}) - Expect(err).NotTo(HaveOccurred()) - Expect(read).NotTo(BeNil()) - }) + if !strings.Contains(err.Error(), tc.errorContains) { + t.Errorf("Expected error containing '%s', but got '%v'", tc.errorContains, err) + } + } else { + if err != nil { + t.Fatalf("ReadJsonWithEnvSub failed: %v", err) + } + if !reflect.DeepEqual(result, tc.expectedResult) { + t.Errorf("Result does not match expected.\nGot: %+v\nExpected: %+v", result, tc.expectedResult) + } + } }) - }) -}) + } +}