diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml
index 524f35f6f..d6adcb27a 100644
--- a/.github/workflows/golang-test-linux.yml
+++ b/.github/workflows/golang-test-linux.yml
@@ -16,7 +16,7 @@ jobs:
matrix:
arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres']
- runs-on: ubuntu-latest
+ runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 7af6d3e4d..b2e2437e6 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -20,7 +20,7 @@ concurrency:
jobs:
release:
- runs-on: ubuntu-latest
+ runs-on: ubuntu-22.04
env:
flags: ""
steps:
diff --git a/README.md b/README.md
index aa3ec41e5..270c9ad87 100644
--- a/README.md
+++ b/README.md
@@ -49,6 +49,8 @@
![netbird_2](https://github.com/netbirdio/netbird/assets/700848/46bc3b73-508d-4a0e-bb9a-f465d68646ab)
+### NetBird on Lawrence Systems (Video)
+[![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw)
### Key features
@@ -62,6 +64,7 @@
| | |
- - \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn)
| | |
| | | - \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication) | | - - \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas)
|
| | | | | |
+
### Quickstart with NetBird Cloud
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install)
diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go
index f0dc8bf21..033d1bb6a 100644
--- a/client/cmd/testutil_test.go
+++ b/client/cmd/testutil_test.go
@@ -3,7 +3,6 @@ package cmd
import (
"context"
"net"
- "path/filepath"
"testing"
"time"
@@ -34,18 +33,12 @@ func startTestingServices(t *testing.T) string {
if err != nil {
t.Fatal(err)
}
- testDir := t.TempDir()
- config.Datadir = testDir
- err = util.CopyFileContents("../testdata/store.json", filepath.Join(testDir, "store.json"))
- if err != nil {
- t.Fatal(err)
- }
_, signalLis := startSignal(t)
signalAddr := signalLis.Addr().String()
config.Signal.URI = signalAddr
- _, mgmLis := startManagement(t, config)
+ _, mgmLis := startManagement(t, config, "../testdata/store.sqlite")
mgmAddr := mgmLis.Addr().String()
return mgmAddr
}
@@ -70,7 +63,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
return s, lis
}
-func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Listener) {
+func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.Server, net.Listener) {
t.Helper()
lis, err := net.Listen("tcp", ":0")
@@ -78,7 +71,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
t.Fatal(err)
}
s := grpc.NewServer()
- store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir)
+ store, cleanUp, err := mgmt.NewTestStoreFromSqlite(context.Background(), testFile, t.TempDir())
if err != nil {
t.Fatal(err)
}
diff --git a/client/internal/connect.go b/client/internal/connect.go
index c77f95603..74dc1f1b5 100644
--- a/client/internal/connect.go
+++ b/client/internal/connect.go
@@ -269,12 +269,6 @@ func (c *ConnectClient) run(
checks := loginResp.GetChecks()
c.engineMutex.Lock()
- if c.engine != nil && c.engine.ctx.Err() != nil {
- log.Info("Stopping Netbird Engine")
- if err := c.engine.Stop(); err != nil {
- log.Errorf("Failed to stop engine: %v", err)
- }
- }
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks)
c.engineMutex.Unlock()
@@ -294,6 +288,15 @@ func (c *ConnectClient) run(
}
<-engineCtx.Done()
+ c.engineMutex.Lock()
+ if c.engine != nil && c.engine.wgInterface != nil {
+ log.Infof("ensuring %s is removed, Netbird engine context cancelled", c.engine.wgInterface.Name())
+ if err := c.engine.Stop(); err != nil {
+ log.Errorf("Failed to stop engine: %v", err)
+ }
+ c.engine = nil
+ }
+ c.engineMutex.Unlock()
c.statusRecorder.ClientTeardown()
backOff.Reset()
diff --git a/client/internal/engine.go b/client/internal/engine.go
index c51901a22..eac8ec098 100644
--- a/client/internal/engine.go
+++ b/client/internal/engine.go
@@ -251,6 +251,13 @@ func (e *Engine) Stop() error {
}
log.Info("Network monitor: stopped")
+ // stop/restore DNS first so dbus and friends don't complain because of a missing interface
+ e.stopDNSServer()
+
+ if e.routeManager != nil {
+ e.routeManager.Stop()
+ }
+
err := e.removeAllPeers()
if err != nil {
return fmt.Errorf("failed to remove all peers: %s", err)
@@ -1116,18 +1123,12 @@ func (e *Engine) close() {
}
}
- // stop/restore DNS first so dbus and friends don't complain because of a missing interface
- e.stopDNSServer()
-
- if e.routeManager != nil {
- e.routeManager.Stop()
- }
-
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
if e.wgInterface != nil {
if err := e.wgInterface.Close(); err != nil {
log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err)
}
+ e.wgInterface = nil
}
if !isNil(e.sshServer) {
@@ -1395,7 +1396,7 @@ func (e *Engine) startNetworkMonitor() {
}
// Set a new timer to debounce rapid network changes
- debounceTimer = time.AfterFunc(1*time.Second, func() {
+ debounceTimer = time.AfterFunc(2*time.Second, func() {
// This function is called after the debounce period
mu.Lock()
defer mu.Unlock()
@@ -1426,6 +1427,11 @@ func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
}
func (e *Engine) stopDNSServer() {
+ if e.dnsServer == nil {
+ return
+ }
+ e.dnsServer.Stop()
+ e.dnsServer = nil
err := fmt.Errorf("DNS server stopped")
nsGroupStates := e.statusRecorder.GetDNSStates()
for i := range nsGroupStates {
@@ -1433,10 +1439,6 @@ func (e *Engine) stopDNSServer() {
nsGroupStates[i].Error = err
}
e.statusRecorder.UpdateDNSStates(nsGroupStates)
- if e.dnsServer != nil {
- e.dnsServer.Stop()
- e.dnsServer = nil
- }
}
// isChecksEqual checks if two slices of checks are equal.
diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go
index 29a8439a2..3d1983c6b 100644
--- a/client/internal/engine_test.go
+++ b/client/internal/engine_test.go
@@ -6,7 +6,6 @@ import (
"net"
"net/netip"
"os"
- "path/filepath"
"runtime"
"strings"
"sync"
@@ -824,20 +823,6 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
func TestEngine_MultiplePeers(t *testing.T) {
// log.SetLevel(log.DebugLevel)
- dir := t.TempDir()
-
- err := util.CopyFileContents("../testdata/store.json", filepath.Join(dir, "store.json"))
- if err != nil {
- t.Fatal(err)
- }
- defer func() {
- err = os.Remove(filepath.Join(dir, "store.json")) //nolint
- if err != nil {
- t.Fatal(err)
- return
- }
- }()
-
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel()
@@ -847,7 +832,7 @@ func TestEngine_MultiplePeers(t *testing.T) {
return
}
defer sigServer.Stop()
- mgmtServer, mgmtAddr, err := startManagement(t, dir)
+ mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sqlite")
if err != nil {
t.Fatal(err)
return
@@ -1070,7 +1055,7 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) {
return s, lis.Addr().String(), nil
}
-func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error) {
+func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) {
t.Helper()
config := &server.Config{
@@ -1095,7 +1080,7 @@ func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error)
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
- store, cleanUp, err := server.NewTestStoreFromJson(context.Background(), config.Datadir)
+ store, cleanUp, err := server.NewTestStoreFromSqlite(context.Background(), testFile, config.Datadir)
if err != nil {
return nil, "", err
}
diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go
index ad84bd700..0d4ad2396 100644
--- a/client/internal/peer/conn.go
+++ b/client/internal/peer/conn.go
@@ -32,6 +32,8 @@ const (
connPriorityRelay ConnPriority = 1
connPriorityICETurn ConnPriority = 1
connPriorityICEP2P ConnPriority = 2
+
+ reconnectMaxElapsedTime = 30 * time.Minute
)
type WgConfig struct {
@@ -83,6 +85,7 @@ type Conn struct {
wgProxyICE wgproxy.Proxy
wgProxyRelay wgproxy.Proxy
signaler *Signaler
+ iFaceDiscover stdnet.ExternalIFaceDiscover
relayManager *relayClient.Manager
allowedIPsIP string
handshaker *Handshaker
@@ -108,6 +111,8 @@ type Conn struct {
// for reconnection operations
iCEDisconnected chan bool
relayDisconnected chan bool
+ connMonitor *ConnMonitor
+ reconnectCh <-chan struct{}
}
// NewConn creates a new not opened Conn to the remote peer.
@@ -123,21 +128,31 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
connLog := log.WithField("peer", config.Key)
var conn = &Conn{
- log: connLog,
- ctx: ctx,
- ctxCancel: ctxCancel,
- config: config,
- statusRecorder: statusRecorder,
- wgProxyFactory: wgProxyFactory,
- signaler: signaler,
- relayManager: relayManager,
- allowedIPsIP: allowedIPsIP.String(),
- statusRelay: NewAtomicConnStatus(),
- statusICE: NewAtomicConnStatus(),
+ log: connLog,
+ ctx: ctx,
+ ctxCancel: ctxCancel,
+ config: config,
+ statusRecorder: statusRecorder,
+ wgProxyFactory: wgProxyFactory,
+ signaler: signaler,
+ iFaceDiscover: iFaceDiscover,
+ relayManager: relayManager,
+ allowedIPsIP: allowedIPsIP.String(),
+ statusRelay: NewAtomicConnStatus(),
+ statusICE: NewAtomicConnStatus(),
+
iCEDisconnected: make(chan bool, 1),
relayDisconnected: make(chan bool, 1),
}
+ conn.connMonitor, conn.reconnectCh = NewConnMonitor(
+ signaler,
+ iFaceDiscover,
+ config,
+ conn.relayDisconnected,
+ conn.iCEDisconnected,
+ )
+
rFns := WorkerRelayCallbacks{
OnConnReady: conn.relayConnectionIsReady,
OnDisconnected: conn.onWorkerRelayStateDisconnected,
@@ -200,6 +215,8 @@ func (conn *Conn) startHandshakeAndReconnect() {
conn.log.Errorf("failed to send initial offer: %v", err)
}
+ go conn.connMonitor.Start(conn.ctx)
+
if conn.workerRelay.IsController() {
conn.reconnectLoopWithRetry()
} else {
@@ -309,12 +326,14 @@ func (conn *Conn) reconnectLoopWithRetry() {
// With it, we can decrease to send necessary offer
select {
case <-conn.ctx.Done():
+ return
case <-time.After(3 * time.Second):
}
ticker := conn.prepareExponentTicker()
defer ticker.Stop()
time.Sleep(1 * time.Second)
+
for {
select {
case t := <-ticker.C:
@@ -342,20 +361,11 @@ func (conn *Conn) reconnectLoopWithRetry() {
if err != nil {
conn.log.Errorf("failed to do handshake: %v", err)
}
- case changed := <-conn.relayDisconnected:
- if !changed {
- continue
- }
- conn.log.Debugf("Relay state changed, reset reconnect timer")
- ticker.Stop()
- ticker = conn.prepareExponentTicker()
- case changed := <-conn.iCEDisconnected:
- if !changed {
- continue
- }
- conn.log.Debugf("ICE state changed, reset reconnect timer")
+
+ case <-conn.reconnectCh:
ticker.Stop()
ticker = conn.prepareExponentTicker()
+
case <-conn.ctx.Done():
conn.log.Debugf("context is done, stop reconnect loop")
return
@@ -366,10 +376,10 @@ func (conn *Conn) reconnectLoopWithRetry() {
func (conn *Conn) prepareExponentTicker() *backoff.Ticker {
bo := backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond,
- RandomizationFactor: 0.01,
+ RandomizationFactor: 0.1,
Multiplier: 2,
MaxInterval: conn.config.Timeout,
- MaxElapsedTime: 0,
+ MaxElapsedTime: reconnectMaxElapsedTime,
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, conn.ctx)
diff --git a/client/internal/peer/conn_monitor.go b/client/internal/peer/conn_monitor.go
new file mode 100644
index 000000000..75722c990
--- /dev/null
+++ b/client/internal/peer/conn_monitor.go
@@ -0,0 +1,212 @@
+package peer
+
+import (
+ "context"
+ "fmt"
+ "sync"
+ "time"
+
+ "github.com/pion/ice/v3"
+ log "github.com/sirupsen/logrus"
+
+ "github.com/netbirdio/netbird/client/internal/stdnet"
+)
+
+const (
+ signalerMonitorPeriod = 5 * time.Second
+ candidatesMonitorPeriod = 5 * time.Minute
+ candidateGatheringTimeout = 5 * time.Second
+)
+
+type ConnMonitor struct {
+ signaler *Signaler
+ iFaceDiscover stdnet.ExternalIFaceDiscover
+ config ConnConfig
+ relayDisconnected chan bool
+ iCEDisconnected chan bool
+ reconnectCh chan struct{}
+ currentCandidates []ice.Candidate
+ candidatesMu sync.Mutex
+}
+
+func NewConnMonitor(signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, config ConnConfig, relayDisconnected, iCEDisconnected chan bool) (*ConnMonitor, <-chan struct{}) {
+ reconnectCh := make(chan struct{}, 1)
+ cm := &ConnMonitor{
+ signaler: signaler,
+ iFaceDiscover: iFaceDiscover,
+ config: config,
+ relayDisconnected: relayDisconnected,
+ iCEDisconnected: iCEDisconnected,
+ reconnectCh: reconnectCh,
+ }
+ return cm, reconnectCh
+}
+
+func (cm *ConnMonitor) Start(ctx context.Context) {
+ signalerReady := make(chan struct{}, 1)
+ go cm.monitorSignalerReady(ctx, signalerReady)
+
+ localCandidatesChanged := make(chan struct{}, 1)
+ go cm.monitorLocalCandidatesChanged(ctx, localCandidatesChanged)
+
+ for {
+ select {
+ case changed := <-cm.relayDisconnected:
+ if !changed {
+ continue
+ }
+ log.Debugf("Relay state changed, triggering reconnect")
+ cm.triggerReconnect()
+
+ case changed := <-cm.iCEDisconnected:
+ if !changed {
+ continue
+ }
+ log.Debugf("ICE state changed, triggering reconnect")
+ cm.triggerReconnect()
+
+ case <-signalerReady:
+ log.Debugf("Signaler became ready, triggering reconnect")
+ cm.triggerReconnect()
+
+ case <-localCandidatesChanged:
+ log.Debugf("Local candidates changed, triggering reconnect")
+ cm.triggerReconnect()
+
+ case <-ctx.Done():
+ return
+ }
+ }
+}
+
+func (cm *ConnMonitor) monitorSignalerReady(ctx context.Context, signalerReady chan<- struct{}) {
+ if cm.signaler == nil {
+ return
+ }
+
+ ticker := time.NewTicker(signalerMonitorPeriod)
+ defer ticker.Stop()
+
+ lastReady := true
+ for {
+ select {
+ case <-ticker.C:
+ currentReady := cm.signaler.Ready()
+ if !lastReady && currentReady {
+ select {
+ case signalerReady <- struct{}{}:
+ default:
+ }
+ }
+ lastReady = currentReady
+ case <-ctx.Done():
+ return
+ }
+ }
+}
+
+func (cm *ConnMonitor) monitorLocalCandidatesChanged(ctx context.Context, localCandidatesChanged chan<- struct{}) {
+ ufrag, pwd, err := generateICECredentials()
+ if err != nil {
+ log.Warnf("Failed to generate ICE credentials: %v", err)
+ return
+ }
+
+ ticker := time.NewTicker(candidatesMonitorPeriod)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ticker.C:
+ if err := cm.handleCandidateTick(ctx, localCandidatesChanged, ufrag, pwd); err != nil {
+ log.Warnf("Failed to handle candidate tick: %v", err)
+ }
+ case <-ctx.Done():
+ return
+ }
+ }
+}
+
+func (cm *ConnMonitor) handleCandidateTick(ctx context.Context, localCandidatesChanged chan<- struct{}, ufrag string, pwd string) error {
+ log.Debugf("Gathering ICE candidates")
+
+ transportNet, err := newStdNet(cm.iFaceDiscover, cm.config.ICEConfig.InterfaceBlackList)
+ if err != nil {
+ log.Errorf("failed to create pion's stdnet: %s", err)
+ }
+
+ agent, err := newAgent(cm.config, transportNet, candidateTypesP2P(), ufrag, pwd)
+ if err != nil {
+ return fmt.Errorf("create ICE agent: %w", err)
+ }
+ defer func() {
+ if err := agent.Close(); err != nil {
+ log.Warnf("Failed to close ICE agent: %v", err)
+ }
+ }()
+
+ gatherDone := make(chan struct{})
+ err = agent.OnCandidate(func(c ice.Candidate) {
+ log.Tracef("Got candidate: %v", c)
+ if c == nil {
+ close(gatherDone)
+ }
+ })
+ if err != nil {
+ return fmt.Errorf("set ICE candidate handler: %w", err)
+ }
+
+ if err := agent.GatherCandidates(); err != nil {
+ return fmt.Errorf("gather ICE candidates: %w", err)
+ }
+
+ ctx, cancel := context.WithTimeout(ctx, candidateGatheringTimeout)
+ defer cancel()
+
+ select {
+ case <-ctx.Done():
+ return fmt.Errorf("wait for gathering: %w", ctx.Err())
+ case <-gatherDone:
+ }
+
+ candidates, err := agent.GetLocalCandidates()
+ if err != nil {
+ return fmt.Errorf("get local candidates: %w", err)
+ }
+ log.Tracef("Got candidates: %v", candidates)
+
+ if changed := cm.updateCandidates(candidates); changed {
+ select {
+ case localCandidatesChanged <- struct{}{}:
+ default:
+ }
+ }
+
+ return nil
+}
+
+func (cm *ConnMonitor) updateCandidates(newCandidates []ice.Candidate) bool {
+ cm.candidatesMu.Lock()
+ defer cm.candidatesMu.Unlock()
+
+ if len(cm.currentCandidates) != len(newCandidates) {
+ cm.currentCandidates = newCandidates
+ return true
+ }
+
+ for i, candidate := range cm.currentCandidates {
+ if candidate.Address() != newCandidates[i].Address() {
+ cm.currentCandidates = newCandidates
+ return true
+ }
+ }
+
+ return false
+}
+
+func (cm *ConnMonitor) triggerReconnect() {
+ select {
+ case cm.reconnectCh <- struct{}{}:
+ default:
+ }
+}
diff --git a/client/internal/peer/stdnet.go b/client/internal/peer/stdnet.go
index ae31ebbf0..96d211dbc 100644
--- a/client/internal/peer/stdnet.go
+++ b/client/internal/peer/stdnet.go
@@ -6,6 +6,6 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet"
)
-func (w *WorkerICE) newStdNet() (*stdnet.Net, error) {
- return stdnet.NewNet(w.config.ICEConfig.InterfaceBlackList)
+func newStdNet(_ stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
+ return stdnet.NewNet(ifaceBlacklist)
}
diff --git a/client/internal/peer/stdnet_android.go b/client/internal/peer/stdnet_android.go
index b411405bb..a39a03b1c 100644
--- a/client/internal/peer/stdnet_android.go
+++ b/client/internal/peer/stdnet_android.go
@@ -2,6 +2,6 @@ package peer
import "github.com/netbirdio/netbird/client/internal/stdnet"
-func (w *WorkerICE) newStdNet() (*stdnet.Net, error) {
- return stdnet.NewNetWithDiscover(w.iFaceDiscover, w.config.ICEConfig.InterfaceBlackList)
+func newStdNet(iFaceDiscover stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
+ return stdnet.NewNetWithDiscover(iFaceDiscover, ifaceBlacklist)
}
diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go
index c4e9d1950..c86c1858f 100644
--- a/client/internal/peer/worker_ice.go
+++ b/client/internal/peer/worker_ice.go
@@ -233,41 +233,16 @@ func (w *WorkerICE) Close() {
}
func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, relaySupport []ice.CandidateType) (*ice.Agent, error) {
- transportNet, err := w.newStdNet()
+ transportNet, err := newStdNet(w.iFaceDiscover, w.config.ICEConfig.InterfaceBlackList)
if err != nil {
w.log.Errorf("failed to create pion's stdnet: %s", err)
}
- iceKeepAlive := iceKeepAlive()
- iceDisconnectedTimeout := iceDisconnectedTimeout()
- iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait()
-
- agentConfig := &ice.AgentConfig{
- MulticastDNSMode: ice.MulticastDNSModeDisabled,
- NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
- Urls: w.config.ICEConfig.StunTurn.Load().([]*stun.URI),
- CandidateTypes: relaySupport,
- InterfaceFilter: stdnet.InterfaceFilter(w.config.ICEConfig.InterfaceBlackList),
- UDPMux: w.config.ICEConfig.UDPMux,
- UDPMuxSrflx: w.config.ICEConfig.UDPMuxSrflx,
- NAT1To1IPs: w.config.ICEConfig.NATExternalIPs,
- Net: transportNet,
- FailedTimeout: &failedTimeout,
- DisconnectedTimeout: &iceDisconnectedTimeout,
- KeepaliveInterval: &iceKeepAlive,
- RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait,
- LocalUfrag: w.localUfrag,
- LocalPwd: w.localPwd,
- }
-
- if w.config.ICEConfig.DisableIPv6Discovery {
- agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4}
- }
-
w.sentExtraSrflx = false
- agent, err := ice.NewAgent(agentConfig)
+
+ agent, err := newAgent(w.config, transportNet, relaySupport, w.localUfrag, w.localPwd)
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("create agent: %w", err)
}
err = agent.OnCandidate(w.onICECandidate)
@@ -390,6 +365,36 @@ func (w *WorkerICE) turnAgentDial(ctx context.Context, remoteOfferAnswer *OfferA
}
}
+func newAgent(config ConnConfig, transportNet *stdnet.Net, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ice.Agent, error) {
+ iceKeepAlive := iceKeepAlive()
+ iceDisconnectedTimeout := iceDisconnectedTimeout()
+ iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait()
+
+ agentConfig := &ice.AgentConfig{
+ MulticastDNSMode: ice.MulticastDNSModeDisabled,
+ NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
+ Urls: config.ICEConfig.StunTurn.Load().([]*stun.URI),
+ CandidateTypes: candidateTypes,
+ InterfaceFilter: stdnet.InterfaceFilter(config.ICEConfig.InterfaceBlackList),
+ UDPMux: config.ICEConfig.UDPMux,
+ UDPMuxSrflx: config.ICEConfig.UDPMuxSrflx,
+ NAT1To1IPs: config.ICEConfig.NATExternalIPs,
+ Net: transportNet,
+ FailedTimeout: &failedTimeout,
+ DisconnectedTimeout: &iceDisconnectedTimeout,
+ KeepaliveInterval: &iceKeepAlive,
+ RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait,
+ LocalUfrag: ufrag,
+ LocalPwd: pwd,
+ }
+
+ if config.ICEConfig.DisableIPv6Discovery {
+ agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4}
+ }
+
+ return ice.NewAgent(agentConfig)
+}
+
func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) {
relatedAdd := candidate.RelatedAddress()
return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
diff --git a/client/server/server_test.go b/client/server/server_test.go
index 9b18df4d3..e534ad7e2 100644
--- a/client/server/server_test.go
+++ b/client/server/server_test.go
@@ -110,7 +110,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
return nil, "", err
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
- store, cleanUp, err := server.NewTestStoreFromJson(context.Background(), config.Datadir)
+ store, cleanUp, err := server.NewTestStoreFromSqlite(context.Background(), "", config.Datadir)
if err != nil {
return nil, "", err
}
diff --git a/client/testdata/store.json b/client/testdata/store.json
deleted file mode 100644
index 8236f2703..000000000
--- a/client/testdata/store.json
+++ /dev/null
@@ -1,38 +0,0 @@
-{
- "Accounts": {
- "bf1c8084-ba50-4ce7-9439-34653001fc3b": {
- "Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b",
- "SetupKeys": {
- "A2C8E62B-38F5-4553-B31E-DD66C696CEBB": {
- "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB",
- "Name": "Default key",
- "Type": "reusable",
- "CreatedAt": "2021-08-19T20:46:20.005936822+02:00",
- "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00",
- "Revoked": false,
- "UsedTimes": 0
-
- }
- },
- "Network": {
- "Id": "af1c8024-ha40-4ce2-9418-34653101fc3c",
- "Net": {
- "IP": "100.64.0.0",
- "Mask": "//8AAA=="
- },
- "Dns": null
- },
- "Peers": {},
- "Users": {
- "edafee4e-63fb-11ec-90d6-0242ac120003": {
- "Id": "edafee4e-63fb-11ec-90d6-0242ac120003",
- "Role": "admin"
- },
- "f4f6d672-63fb-11ec-90d6-0242ac120003": {
- "Id": "f4f6d672-63fb-11ec-90d6-0242ac120003",
- "Role": "user"
- }
- }
- }
- }
-}
\ No newline at end of file
diff --git a/client/testdata/store.sqlite b/client/testdata/store.sqlite
new file mode 100644
index 000000000..118c2bebc
Binary files /dev/null and b/client/testdata/store.sqlite differ
diff --git a/management/client/client_test.go b/management/client/client_test.go
index a082e354b..313a67617 100644
--- a/management/client/client_test.go
+++ b/management/client/client_test.go
@@ -47,25 +47,18 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
level, _ := log.ParseLevel("debug")
log.SetLevel(level)
- testDir := t.TempDir()
-
config := &mgmt.Config{}
_, err := util.ReadJson("../server/testdata/management.json", config)
if err != nil {
t.Fatal(err)
}
- config.Datadir = testDir
- err = util.CopyFileContents("../server/testdata/store.json", filepath.Join(testDir, "store.json"))
- if err != nil {
- t.Fatal(err)
- }
lis, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
}
s := grpc.NewServer()
- store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir)
+ store, cleanUp, err := NewSqliteTestStore(t, context.Background(), "../server/testdata/store.sqlite")
if err != nil {
t.Fatal(err)
}
@@ -521,3 +514,22 @@ func Test_GetPKCEAuthorizationFlow(t *testing.T) {
assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientID, flowInfo.ProviderConfig.ClientID, "provider configured client ID should match")
assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientSecret, flowInfo.ProviderConfig.ClientSecret, "provider configured client secret should match")
}
+
+func NewSqliteTestStore(t *testing.T, ctx context.Context, testFile string) (mgmt.Store, func(), error) {
+ t.Helper()
+ dataDir := t.TempDir()
+ err := util.CopyFileContents(testFile, filepath.Join(dataDir, "store.db"))
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ store, err := mgmt.NewSqliteStore(ctx, dataDir, nil)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ return store, func() {
+ store.Close(ctx)
+ os.Remove(filepath.Join(dataDir, "store.db"))
+ }, nil
+}
diff --git a/management/cmd/management.go b/management/cmd/management.go
index 78b1a8d63..719d1a78c 100644
--- a/management/cmd/management.go
+++ b/management/cmd/management.go
@@ -475,7 +475,7 @@ func handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler) http.Handle
func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config, error) {
loadedConfig := &server.Config{}
- _, err := util.ReadJson(mgmtConfigPath, loadedConfig)
+ _, err := util.ReadJsonWithEnvSub(mgmtConfigPath, loadedConfig)
if err != nil {
return nil, err
}
diff --git a/management/server/account.go b/management/server/account.go
index 1463ae033..6ee0015f8 100644
--- a/management/server/account.go
+++ b/management/server/account.go
@@ -20,6 +20,11 @@ import (
cacheStore "github.com/eko/gocache/v3/store"
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
+ gocache "github.com/patrickmn/go-cache"
+ "github.com/rs/xid"
+ log "github.com/sirupsen/logrus"
+ "golang.org/x/exp/maps"
+
"github.com/netbirdio/netbird/base62"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
@@ -36,10 +41,6 @@ import (
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route"
- gocache "github.com/patrickmn/go-cache"
- "github.com/rs/xid"
- log "github.com/sirupsen/logrus"
- "golang.org/x/exp/maps"
)
const (
@@ -76,7 +77,8 @@ type AccountManager interface {
SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error)
GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error)
GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error)
- GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error)
+ AccountExists(ctx context.Context, accountID string) (bool, error)
+ GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error)
GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error)
@@ -96,6 +98,7 @@ type AccountManager interface {
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error)
GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error)
+ UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error
GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error)
GetGroup(ctx context.Context, accountId, groupID, userID string) (*nbgroup.Group, error)
GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error)
@@ -842,55 +845,54 @@ func (a *Account) GetPeer(peerID string) *nbpeer.Peer {
return a.Peers[peerID]
}
-// SetJWTGroups updates the user's auto groups by synchronizing JWT groups.
-// Returns true if there are changes in the JWT group membership.
-func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool {
- user, ok := a.Users[userID]
- if !ok {
- return false
- }
-
+// getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups.
+// Returns a bool indicating if there are changes in the JWT group membership, the updated user AutoGroups,
+// newly groups to create and an error if any occurred.
+func (am *DefaultAccountManager) getJWTGroupsChanges(user *User, groups []*nbgroup.Group, groupNames []string) (bool, []string, []*nbgroup.Group, error) {
existedGroupsByName := make(map[string]*nbgroup.Group)
- for _, group := range a.Groups {
+ for _, group := range groups {
existedGroupsByName[group.Name] = group
}
- newAutoGroups, jwtGroupsMap := separateGroups(user.AutoGroups, a.Groups)
- groupsToAdd := difference(groupsNames, maps.Keys(jwtGroupsMap))
- groupsToRemove := difference(maps.Keys(jwtGroupsMap), groupsNames)
+ newUserAutoGroups, jwtGroupsMap := separateGroups(user.AutoGroups, groups)
+
+ groupsToAdd := difference(groupNames, maps.Keys(jwtGroupsMap))
+ groupsToRemove := difference(maps.Keys(jwtGroupsMap), groupNames)
// If no groups are added or removed, we should not sync account
if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 {
- return false
+ return false, nil, nil, nil
}
+ newGroupsToCreate := make([]*nbgroup.Group, 0)
+
var modified bool
for _, name := range groupsToAdd {
group, exists := existedGroupsByName[name]
if !exists {
group = &nbgroup.Group{
- ID: xid.New().String(),
- Name: name,
- Issued: nbgroup.GroupIssuedJWT,
+ ID: xid.New().String(),
+ AccountID: user.AccountID,
+ Name: name,
+ Issued: nbgroup.GroupIssuedJWT,
}
- a.Groups[group.ID] = group
+ newGroupsToCreate = append(newGroupsToCreate, group)
}
if group.Issued == nbgroup.GroupIssuedJWT {
- newAutoGroups = append(newAutoGroups, group.ID)
+ newUserAutoGroups = append(newUserAutoGroups, group.ID)
modified = true
}
}
for name, id := range jwtGroupsMap {
if !slices.Contains(groupsToRemove, name) {
- newAutoGroups = append(newAutoGroups, id)
+ newUserAutoGroups = append(newUserAutoGroups, id)
continue
}
modified = true
}
- user.AutoGroups = newAutoGroups
- return modified
+ return modified, newUserAutoGroups, newGroupsToCreate, nil
}
// UserGroupsAddToPeers adds groups to all peers of user
@@ -1261,37 +1263,36 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
return nil
}
-// GetAccountIDByUserOrAccountID retrieves the account ID based on either the userID or accountID provided.
-// If an accountID is provided, it checks if the account exists and returns it.
-// If no accountID is provided, but a userID is given, it tries to retrieve the account by userID.
+// AccountExists checks if an account exists.
+func (am *DefaultAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) {
+ return am.Store.AccountExists(ctx, LockingStrengthShare, accountID)
+}
+
+// GetAccountIDByUserID retrieves the account ID based on the userID provided.
+// If user does have an account, it returns the user's account ID.
// If the user doesn't have an account, it creates one using the provided domain.
// Returns the account ID or an error if none is found or created.
-func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) {
- if accountID != "" {
- exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, accountID)
- if err != nil {
- return "", err
- }
- if !exists {
- return "", status.Errorf(status.NotFound, "account %s does not exist", accountID)
- }
- return accountID, nil
+func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) {
+ if userID == "" {
+ return "", status.Errorf(status.NotFound, "no valid userID provided")
}
- if userID != "" {
- account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
- if err != nil {
- return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
- }
+ accountID, err := am.Store.GetAccountIDByUserID(userID)
+ if err != nil {
+ if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
+ account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
+ if err != nil {
+ return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
+ }
- if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil {
- return "", err
+ if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil {
+ return "", err
+ }
+ return account.Id, nil
}
-
- return account.Id, nil
+ return "", err
}
-
- return "", status.Errorf(status.NotFound, "no valid userID or accountID provided")
+ return accountID, nil
}
func isNil(i idp.Manager) bool {
@@ -1764,7 +1765,7 @@ func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID s
return nil, err
}
- if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) {
+ if user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data")
}
@@ -1795,6 +1796,10 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai
return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId)
}
+ if user.AccountID != accountID {
+ return "", "", status.Errorf(status.PermissionDenied, "user %s is not part of the account %s", claims.UserId, accountID)
+ }
+
if !user.IsServiceUser && claims.Invited {
err = am.redeemInvite(ctx, accountID, user.Id)
if err != nil {
@@ -1802,7 +1807,7 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai
}
}
- if err = am.syncJWTGroups(ctx, accountID, user, claims); err != nil {
+ if err = am.syncJWTGroups(ctx, accountID, claims); err != nil {
return "", "", err
}
@@ -1811,7 +1816,7 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
// and propagates changes to peers if group propagation is enabled.
-func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, user *User, claims jwtclaims.AuthorizationClaims) error {
+func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims) error {
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
@@ -1822,69 +1827,134 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
}
if settings.JWTGroupsClaimName == "" {
- log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set")
+ log.WithContext(ctx).Debugf("JWT groups are enabled but no claim name is set")
return nil
}
- // TODO: Remove GetAccount after refactoring account peer's update
- unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
- defer unlock()
-
- account, err := am.Store.GetAccount(ctx, accountID)
- if err != nil {
- return err
- }
-
jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
- oldGroups := make([]string, len(user.AutoGroups))
- copy(oldGroups, user.AutoGroups)
+ unlockPeer := am.Store.AcquireWriteLockByUID(ctx, accountID)
+ defer func() {
+ if unlockPeer != nil {
+ unlockPeer()
+ }
+ }()
- // Update the account if group membership changes
- if account.SetJWTGroups(claims.UserId, jwtGroupsNames) {
- addNewGroups := difference(user.AutoGroups, oldGroups)
- removeOldGroups := difference(oldGroups, user.AutoGroups)
-
- if settings.GroupsPropagationEnabled {
- account.UserGroupsAddToPeers(claims.UserId, addNewGroups...)
- account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...)
- account.Network.IncSerial()
+ var addNewGroups []string
+ var removeOldGroups []string
+ var hasChanges bool
+ var user *User
+ err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
+ user, err = am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId)
+ if err != nil {
+ return fmt.Errorf("error getting user: %w", err)
}
- if err := am.Store.SaveAccount(ctx, account); err != nil {
- log.WithContext(ctx).Errorf("failed to save account: %v", err)
+ groups, err := am.Store.GetAccountGroups(ctx, accountID)
+ if err != nil {
+ return fmt.Errorf("error getting account groups: %w", err)
+ }
+
+ changed, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(user, groups, jwtGroupsNames)
+ if err != nil {
+ return fmt.Errorf("error getting JWT groups changes: %w", err)
+ }
+
+ hasChanges = changed
+ // skip update if no changes
+ if !changed {
return nil
}
+ if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil {
+ return fmt.Errorf("error saving groups: %w", err)
+ }
+
+ addNewGroups = difference(updatedAutoGroups, user.AutoGroups)
+ removeOldGroups = difference(user.AutoGroups, updatedAutoGroups)
+
+ user.AutoGroups = updatedAutoGroups
+ if err = transaction.SaveUser(ctx, LockingStrengthUpdate, user); err != nil {
+ return fmt.Errorf("error saving user: %w", err)
+ }
+
// Propagate changes to peers if group propagation is enabled
if settings.GroupsPropagationEnabled {
- log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
- if areGroupChangesAffectPeers(account, addNewGroups) || areGroupChangesAffectPeers(account, removeOldGroups) {
- am.updateAccountPeers(ctx, account)
+ groups, err = transaction.GetAccountGroups(ctx, accountID)
+ if err != nil {
+ return fmt.Errorf("error getting account groups: %w", err)
}
+
+ groupsMap := make(map[string]*nbgroup.Group, len(groups))
+ for _, group := range groups {
+ groupsMap[group.ID] = group
+ }
+
+ peers, err := transaction.GetUserPeers(ctx, LockingStrengthShare, accountID, claims.UserId)
+ if err != nil {
+ return fmt.Errorf("error getting user peers: %w", err)
+ }
+
+ updatedGroups, err := am.updateUserPeersInGroups(groupsMap, peers, addNewGroups, removeOldGroups)
+ if err != nil {
+ return fmt.Errorf("error modifying user peers in groups: %w", err)
+ }
+
+ if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, updatedGroups); err != nil {
+ return fmt.Errorf("error saving groups: %w", err)
+ }
+
+ if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
+ return fmt.Errorf("error incrementing network serial: %w", err)
+ }
+ }
+ unlockPeer()
+ unlockPeer = nil
+
+ return nil
+ })
+ if err != nil {
+ return err
+ }
+
+ if !hasChanges {
+ return nil
+ }
+
+ for _, g := range addNewGroups {
+ group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
+ if err != nil {
+ log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
+ } else {
+ meta := map[string]any{
+ "group": group.Name, "group_id": group.ID,
+ "is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName,
+ }
+ am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupAddedToUser, meta)
+ }
+ }
+
+ for _, g := range removeOldGroups {
+ group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
+ if err != nil {
+ log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
+ } else {
+ meta := map[string]any{
+ "group": group.Name, "group_id": group.ID,
+ "is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName,
+ }
+ am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupRemovedFromUser, meta)
+ }
+ }
+
+ if settings.GroupsPropagationEnabled {
+ account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
+ if err != nil {
+ return fmt.Errorf("error getting account: %w", err)
}
- for _, g := range addNewGroups {
- if group := account.GetGroup(g); group != nil {
- am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser,
- map[string]any{
- "group": group.Name,
- "group_id": group.ID,
- "is_service_user": user.IsServiceUser,
- "user_name": user.ServiceUserName})
- }
- }
-
- for _, g := range removeOldGroups {
- if group := account.GetGroup(g); group != nil {
- am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser,
- map[string]any{
- "group": group.Name,
- "group_id": group.ID,
- "is_service_user": user.IsServiceUser,
- "user_name": user.ServiceUserName})
- }
- }
+ log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
+ am.updateAccountPeers(ctx, account)
}
return nil
@@ -1917,7 +1987,17 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
// if Account ID is part of the claims
// it means that we've already classified the domain and user has an account
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
- return am.GetAccountIDByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain)
+ if claims.AccountId != "" {
+ exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, claims.AccountId)
+ if err != nil {
+ return "", err
+ }
+ if !exists {
+ return "", status.Errorf(status.NotFound, "account %s does not exist", claims.AccountId)
+ }
+ return claims.AccountId, nil
+ }
+ return am.GetAccountIDByUserID(ctx, claims.UserId, claims.Domain)
} else if claims.AccountId != "" {
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
if err != nil {
@@ -2230,7 +2310,11 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac
routes := make(map[route.ID]*route.Route)
setupKeys := map[string]*SetupKey{}
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
- users[userID] = NewOwnerUser(userID)
+
+ owner := NewOwnerUser(userID)
+ owner.AccountID = accountID
+ users[userID] = owner
+
dnsSettings := DNSSettings{
DisabledManagementGroups: make([]string, 0),
}
@@ -2298,12 +2382,17 @@ func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool {
// separateGroups separates user's auto groups into non-JWT and JWT groups.
// Returns the list of standard auto groups and a map of JWT auto groups,
// where the keys are the group names and the values are the group IDs.
-func separateGroups(autoGroups []string, allGroups map[string]*nbgroup.Group) ([]string, map[string]string) {
+func separateGroups(autoGroups []string, allGroups []*nbgroup.Group) ([]string, map[string]string) {
newAutoGroups := make([]string, 0)
jwtAutoGroups := make(map[string]string) // map of group name to group ID
+ allGroupsMap := make(map[string]*nbgroup.Group, len(allGroups))
+ for _, group := range allGroups {
+ allGroupsMap[group.ID] = group
+ }
+
for _, id := range autoGroups {
- if group, ok := allGroups[id]; ok {
+ if group, ok := allGroupsMap[id]; ok {
if group.Issued == nbgroup.GroupIssuedJWT {
jwtAutoGroups[group.Name] = id
} else {
@@ -2311,5 +2400,6 @@ func separateGroups(autoGroups []string, allGroups map[string]*nbgroup.Group) ([
}
}
}
+
return newAutoGroups, jwtAutoGroups
}
diff --git a/management/server/account_test.go b/management/server/account_test.go
index e0fc94c88..9877a5510 100644
--- a/management/server/account_test.go
+++ b/management/server/account_test.go
@@ -633,7 +633,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
- accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
+ accountID, err := manager.GetAccountIDByUserID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.Domain)
require.NoError(t, err, "create init user failed")
initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
@@ -671,17 +671,16 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
userId := "user-id"
domain := "test.domain"
- initAccount := newAccountWithId(context.Background(), "", userId, domain)
+ _ = newAccountWithId(context.Background(), "", userId, domain)
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
- accountID := initAccount.Id
- accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userId, accountID, domain)
+ accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain)
require.NoError(t, err, "create init user failed")
// as initAccount was created without account id we have to take the id after account initialization
- // that happens inside the GetAccountIDByUserOrAccountID where the id is getting generated
+ // that happens inside the GetAccountIDByUserID where the id is getting generated
// it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it
- initAccount, err = manager.Store.GetAccount(context.Background(), accountID)
+ initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get init account failed")
claims := jwtclaims.AuthorizationClaims{
@@ -885,7 +884,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
}
}
-func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
+func TestAccountManager_GetAccountByUserID(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
@@ -894,7 +893,7 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
userId := "test_user"
- accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userId, "", "")
+ accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, "")
if err != nil {
t.Fatal(err)
}
@@ -903,14 +902,13 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
return
}
- _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "")
- if err != nil {
- t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountID)
- }
+ exists, err := manager.Store.AccountExists(context.Background(), LockingStrengthShare, accountID)
+ assert.NoError(t, err)
+ assert.True(t, exists, "expected to get existing account after creation using userid")
- _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", "", "")
+ _, err = manager.GetAccountIDByUserID(context.Background(), "", "")
if err == nil {
- t.Errorf("expected an error when user and account IDs are empty")
+ t.Errorf("expected an error when user ID is empty")
}
}
@@ -1731,7 +1729,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
- accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
+ accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account")
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
@@ -1746,7 +1744,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
- _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
+ _, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
@@ -1758,7 +1756,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
})
require.NoError(t, err, "unable to add peer")
- accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
+ accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to get the account")
account, err := manager.Store.GetAccount(context.Background(), accountID)
@@ -1804,7 +1802,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
- accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
+ accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
@@ -1832,7 +1830,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
},
}
- accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
+ accountID, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to get the account")
account, err := manager.Store.GetAccount(context.Background(), accountID)
@@ -1852,7 +1850,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
- _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
+ _, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
@@ -1864,7 +1862,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
})
require.NoError(t, err, "unable to add peer")
- accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
+ accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to get the account")
account, err := manager.Store.GetAccount(context.Background(), accountID)
@@ -1912,7 +1910,7 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
- accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
+ accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account")
updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
@@ -1923,9 +1921,6 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)
- accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "")
- require.NoError(t, err, "unable to get account by ID")
-
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err, "unable to get account settings")
@@ -2261,8 +2256,12 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
}
func TestAccount_SetJWTGroups(t *testing.T) {
+ manager, err := createManager(t)
+ require.NoError(t, err, "unable to create account manager")
+
// create a new account
account := &Account{
+ Id: "accountID",
Peers: map[string]*nbpeer.Peer{
"peer1": {ID: "peer1", Key: "key1", UserID: "user1"},
"peer2": {ID: "peer2", Key: "key2", UserID: "user1"},
@@ -2273,62 +2272,120 @@ func TestAccount_SetJWTGroups(t *testing.T) {
Groups: map[string]*group.Group{
"group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{}},
},
- Settings: &Settings{GroupsPropagationEnabled: true},
+ Settings: &Settings{GroupsPropagationEnabled: true, JWTGroupsEnabled: true, JWTGroupsClaimName: "groups"},
Users: map[string]*User{
- "user1": {Id: "user1"},
- "user2": {Id: "user2"},
+ "user1": {Id: "user1", AccountID: "accountID"},
+ "user2": {Id: "user2", AccountID: "accountID"},
},
}
+ assert.NoError(t, manager.Store.SaveAccount(context.Background(), account), "unable to save account")
+
t.Run("empty jwt groups", func(t *testing.T) {
- updated := account.SetJWTGroups("user1", []string{})
- assert.False(t, updated, "account should not be updated")
- assert.Empty(t, account.Users["user1"].AutoGroups, "auto groups must be empty")
+ claims := jwtclaims.AuthorizationClaims{
+ UserId: "user1",
+ Raw: jwt.MapClaims{"groups": []interface{}{}},
+ }
+ err := manager.syncJWTGroups(context.Background(), "accountID", claims)
+ assert.NoError(t, err, "unable to sync jwt groups")
+
+ user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
+ assert.NoError(t, err, "unable to get user")
+ assert.Empty(t, user.AutoGroups, "auto groups must be empty")
})
t.Run("jwt match existing api group", func(t *testing.T) {
- updated := account.SetJWTGroups("user1", []string{"group1"})
- assert.False(t, updated, "account should not be updated")
- assert.Equal(t, 0, len(account.Users["user1"].AutoGroups))
- assert.Equal(t, account.Groups["group1"].Issued, group.GroupIssuedAPI, "group should be api issued")
+ claims := jwtclaims.AuthorizationClaims{
+ UserId: "user1",
+ Raw: jwt.MapClaims{"groups": []interface{}{"group1"}},
+ }
+ err := manager.syncJWTGroups(context.Background(), "accountID", claims)
+ assert.NoError(t, err, "unable to sync jwt groups")
+
+ user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
+ assert.NoError(t, err, "unable to get user")
+ assert.Len(t, user.AutoGroups, 0)
+
+ group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
+ assert.NoError(t, err, "unable to get group")
+ assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
})
t.Run("jwt match existing api group in user auto groups", func(t *testing.T) {
account.Users["user1"].AutoGroups = []string{"group1"}
+ assert.NoError(t, manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, account.Users["user1"]))
- updated := account.SetJWTGroups("user1", []string{"group1"})
- assert.False(t, updated, "account should not be updated")
- assert.Equal(t, 1, len(account.Users["user1"].AutoGroups))
- assert.Equal(t, account.Groups["group1"].Issued, group.GroupIssuedAPI, "group should be api issued")
+ claims := jwtclaims.AuthorizationClaims{
+ UserId: "user1",
+ Raw: jwt.MapClaims{"groups": []interface{}{"group1"}},
+ }
+ err = manager.syncJWTGroups(context.Background(), "accountID", claims)
+ assert.NoError(t, err, "unable to sync jwt groups")
+
+ user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
+ assert.NoError(t, err, "unable to get user")
+ assert.Len(t, user.AutoGroups, 1)
+
+ group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
+ assert.NoError(t, err, "unable to get group")
+ assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
})
t.Run("add jwt group", func(t *testing.T) {
- updated := account.SetJWTGroups("user1", []string{"group1", "group2"})
- assert.True(t, updated, "account should be updated")
- assert.Len(t, account.Groups, 2, "new group should be added")
- assert.Len(t, account.Users["user1"].AutoGroups, 2, "new group should be added")
- assert.Contains(t, account.Groups, account.Users["user1"].AutoGroups[0], "groups must contain group2 from user groups")
+ claims := jwtclaims.AuthorizationClaims{
+ UserId: "user1",
+ Raw: jwt.MapClaims{"groups": []interface{}{"group1", "group2"}},
+ }
+ err = manager.syncJWTGroups(context.Background(), "accountID", claims)
+ assert.NoError(t, err, "unable to sync jwt groups")
+
+ user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
+ assert.NoError(t, err, "unable to get user")
+ assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
})
t.Run("existed group not update", func(t *testing.T) {
- updated := account.SetJWTGroups("user1", []string{"group2"})
- assert.False(t, updated, "account should not be updated")
- assert.Len(t, account.Groups, 2, "groups count should not be changed")
+ claims := jwtclaims.AuthorizationClaims{
+ UserId: "user1",
+ Raw: jwt.MapClaims{"groups": []interface{}{"group2"}},
+ }
+ err = manager.syncJWTGroups(context.Background(), "accountID", claims)
+ assert.NoError(t, err, "unable to sync jwt groups")
+
+ user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
+ assert.NoError(t, err, "unable to get user")
+ assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
})
t.Run("add new group", func(t *testing.T) {
- updated := account.SetJWTGroups("user2", []string{"group1", "group3"})
- assert.True(t, updated, "account should be updated")
- assert.Len(t, account.Groups, 3, "new group should be added")
- assert.Len(t, account.Users["user2"].AutoGroups, 1, "new group should be added")
- assert.Contains(t, account.Groups, account.Users["user2"].AutoGroups[0], "groups must contain group3 from user groups")
+ claims := jwtclaims.AuthorizationClaims{
+ UserId: "user2",
+ Raw: jwt.MapClaims{"groups": []interface{}{"group1", "group3"}},
+ }
+ err = manager.syncJWTGroups(context.Background(), "accountID", claims)
+ assert.NoError(t, err, "unable to sync jwt groups")
+
+ groups, err := manager.Store.GetAccountGroups(context.Background(), "accountID")
+ assert.NoError(t, err)
+ assert.Len(t, groups, 3, "new group3 should be added")
+
+ user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user2")
+ assert.NoError(t, err, "unable to get user")
+ assert.Len(t, user.AutoGroups, 1, "new group should be added")
})
t.Run("remove all JWT groups", func(t *testing.T) {
- updated := account.SetJWTGroups("user1", []string{})
- assert.True(t, updated, "account should be updated")
- assert.Len(t, account.Users["user1"].AutoGroups, 1, "only non-JWT groups should remain")
- assert.Contains(t, account.Users["user1"].AutoGroups, "group1", " group1 should still be present")
+ claims := jwtclaims.AuthorizationClaims{
+ UserId: "user1",
+ Raw: jwt.MapClaims{"groups": []interface{}{}},
+ }
+ err = manager.syncJWTGroups(context.Background(), "accountID", claims)
+ assert.NoError(t, err, "unable to sync jwt groups")
+
+ user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
+ assert.NoError(t, err, "unable to get user")
+ assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain")
+ assert.Contains(t, user.AutoGroups, "group1", " group1 should still be present")
})
}
@@ -2428,7 +2485,7 @@ func createManager(t TB) (*DefaultAccountManager, error) {
func createStore(t TB) (Store, error) {
t.Helper()
dataDir := t.TempDir()
- store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir)
+ store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir)
if err != nil {
return nil, err
}
diff --git a/management/server/dns_test.go b/management/server/dns_test.go
index 53ab1eaaf..631a19785 100644
--- a/management/server/dns_test.go
+++ b/management/server/dns_test.go
@@ -212,7 +212,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
func createDNSStore(t *testing.T) (Store, error) {
t.Helper()
dataDir := t.TempDir()
- store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir)
+ store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir)
if err != nil {
return nil, err
}
diff --git a/management/server/file_store.go b/management/server/file_store.go
index 994a4b1ee..df3e9bb77 100644
--- a/management/server/file_store.go
+++ b/management/server/file_store.go
@@ -2,24 +2,18 @@ package server
import (
"context"
- "errors"
- "net"
"os"
"path/filepath"
"strings"
"sync"
"time"
- "github.com/netbirdio/netbird/dns"
- nbgroup "github.com/netbirdio/netbird/management/server/group"
- nbpeer "github.com/netbirdio/netbird/management/server/peer"
- "github.com/netbirdio/netbird/management/server/posture"
- "github.com/netbirdio/netbird/management/server/status"
- "github.com/netbirdio/netbird/management/server/telemetry"
- "github.com/netbirdio/netbird/route"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
+ nbgroup "github.com/netbirdio/netbird/management/server/group"
+ nbpeer "github.com/netbirdio/netbird/management/server/peer"
+ "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/util"
)
@@ -42,167 +36,9 @@ type FileStore struct {
mux sync.Mutex `json:"-"`
storeFile string `json:"-"`
- // sync.Mutex indexed by resource ID
- resourceLocks sync.Map `json:"-"`
- globalAccountLock sync.Mutex `json:"-"`
-
metrics telemetry.AppMetrics `json:"-"`
}
-func (s *FileStore) ExecuteInTransaction(ctx context.Context, f func(store Store) error) error {
- return f(s)
-}
-
-func (s *FileStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
- s.mux.Lock()
- defer s.mux.Unlock()
-
- accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKeyID)]
- if !ok {
- return status.NewSetupKeyNotFoundError()
- }
-
- account, err := s.getAccount(accountID)
- if err != nil {
- return err
- }
-
- account.SetupKeys[setupKeyID].UsedTimes++
-
- return s.SaveAccount(ctx, account)
-}
-
-func (s *FileStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
- s.mux.Lock()
- defer s.mux.Unlock()
-
- account, err := s.getAccount(accountID)
- if err != nil {
- return err
- }
-
- allGroup, err := account.GetGroupAll()
- if err != nil || allGroup == nil {
- return errors.New("all group not found")
- }
-
- allGroup.Peers = append(allGroup.Peers, peerID)
-
- return nil
-}
-
-func (s *FileStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error {
- s.mux.Lock()
- defer s.mux.Unlock()
-
- account, err := s.getAccount(accountId)
- if err != nil {
- return err
- }
-
- account.Groups[groupID].Peers = append(account.Groups[groupID].Peers, peerId)
-
- return nil
-}
-
-func (s *FileStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
- s.mux.Lock()
- defer s.mux.Unlock()
-
- account, ok := s.Accounts[peer.AccountID]
- if !ok {
- return status.NewAccountNotFoundError(peer.AccountID)
- }
-
- account.Peers[peer.ID] = peer
- return s.SaveAccount(ctx, account)
-}
-
-func (s *FileStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
- s.mux.Lock()
- defer s.mux.Unlock()
-
- account, ok := s.Accounts[accountId]
- if !ok {
- return status.NewAccountNotFoundError(accountId)
- }
-
- account.Network.Serial++
-
- return s.SaveAccount(ctx, account)
-}
-
-func (s *FileStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) {
- s.mux.Lock()
- defer s.mux.Unlock()
-
- accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(key)]
- if !ok {
- return nil, status.NewSetupKeyNotFoundError()
- }
-
- account, err := s.getAccount(accountID)
- if err != nil {
- return nil, err
- }
-
- setupKey, ok := account.SetupKeys[key]
- if !ok {
- return nil, status.Errorf(status.NotFound, "setup key not found")
- }
-
- return setupKey, nil
-}
-
-func (s *FileStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
- s.mux.Lock()
- defer s.mux.Unlock()
-
- account, err := s.getAccount(accountID)
- if err != nil {
- return nil, err
- }
-
- var takenIps []net.IP
- for _, existingPeer := range account.Peers {
- takenIps = append(takenIps, existingPeer.IP)
- }
-
- return takenIps, nil
-}
-
-func (s *FileStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
- s.mux.Lock()
- defer s.mux.Unlock()
-
- account, err := s.getAccount(accountID)
- if err != nil {
- return nil, err
- }
-
- existingLabels := []string{}
- for _, peer := range account.Peers {
- if peer.DNSLabel != "" {
- existingLabels = append(existingLabels, peer.DNSLabel)
- }
- }
- return existingLabels, nil
-}
-
-func (s *FileStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) {
- s.mux.Lock()
- defer s.mux.Unlock()
-
- account, err := s.getAccount(accountID)
- if err != nil {
- return nil, err
- }
-
- return account.Network, nil
-}
-
-type StoredAccount struct{}
-
// NewFileStore restores a store from the file located in the datadir
func NewFileStore(ctx context.Context, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) {
fs, err := restore(ctx, filepath.Join(dataDir, storeFileName))
@@ -213,25 +49,6 @@ func NewFileStore(ctx context.Context, dataDir string, metrics telemetry.AppMetr
return fs, nil
}
-// NewFilestoreFromSqliteStore restores a store from Sqlite and stores to Filestore json in the file located in datadir
-func NewFilestoreFromSqliteStore(ctx context.Context, sqlStore *SqlStore, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) {
- store, err := NewFileStore(ctx, dataDir, metrics)
- if err != nil {
- return nil, err
- }
-
- err = store.SaveInstallationID(ctx, sqlStore.GetInstallationID())
- if err != nil {
- return nil, err
- }
-
- for _, account := range sqlStore.GetAllAccounts(ctx) {
- store.Accounts[account.Id] = account
- }
-
- return store, store.persist(ctx, store.storeFile)
-}
-
// restore the state of the store from the file.
// Creates a new empty store file if doesn't exist
func restore(ctx context.Context, file string) (*FileStore, error) {
@@ -240,7 +57,6 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
s := &FileStore{
Accounts: make(map[string]*Account),
mux: sync.Mutex{},
- globalAccountLock: sync.Mutex{},
SetupKeyID2AccountID: make(map[string]string),
PeerKeyID2AccountID: make(map[string]string),
UserID2AccountID: make(map[string]string),
@@ -416,252 +232,6 @@ func (s *FileStore) persist(ctx context.Context, file string) error {
return nil
}
-// AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock
-func (s *FileStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
- log.WithContext(ctx).Debugf("acquiring global lock")
- start := time.Now()
- s.globalAccountLock.Lock()
-
- unlock = func() {
- s.globalAccountLock.Unlock()
- log.WithContext(ctx).Debugf("released global lock in %v", time.Since(start))
- }
-
- took := time.Since(start)
- log.WithContext(ctx).Debugf("took %v to acquire global lock", took)
- if s.metrics != nil {
- s.metrics.StoreMetrics().CountGlobalLockAcquisitionDuration(took)
- }
-
- return unlock
-}
-
-// AcquireWriteLockByUID acquires an ID lock for writing to a resource and returns a function that releases the lock
-func (s *FileStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
- log.WithContext(ctx).Debugf("acquiring lock for ID %s", uniqueID)
- start := time.Now()
- value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.Mutex{})
- mtx := value.(*sync.Mutex)
- mtx.Lock()
-
- unlock = func() {
- mtx.Unlock()
- log.WithContext(ctx).Debugf("released lock for ID %s in %v", uniqueID, time.Since(start))
- }
-
- return unlock
-}
-
-// AcquireReadLockByUID acquires an ID lock for reading a resource and returns a function that releases the lock
-// This method is still returns a write lock as file store can't handle read locks
-func (s *FileStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
- return s.AcquireWriteLockByUID(ctx, uniqueID)
-}
-
-func (s *FileStore) SaveAccount(ctx context.Context, account *Account) error {
- s.mux.Lock()
- defer s.mux.Unlock()
-
- if account.Id == "" {
- return status.Errorf(status.InvalidArgument, "account id should not be empty")
- }
-
- accountCopy := account.Copy()
-
- s.Accounts[accountCopy.Id] = accountCopy
-
- // todo check that account.Id and keyId are not exist already
- // because if keyId exists for other accounts this can be bad
- for keyID := range accountCopy.SetupKeys {
- s.SetupKeyID2AccountID[strings.ToUpper(keyID)] = accountCopy.Id
- }
-
- // enforce peer to account index and delete peer to route indexes for rebuild
- for _, peer := range accountCopy.Peers {
- s.PeerKeyID2AccountID[peer.Key] = accountCopy.Id
- s.PeerID2AccountID[peer.ID] = accountCopy.Id
- }
-
- for _, user := range accountCopy.Users {
- s.UserID2AccountID[user.Id] = accountCopy.Id
- for _, pat := range user.PATs {
- s.TokenID2UserID[pat.ID] = user.Id
- s.HashedPAT2TokenID[pat.HashedToken] = pat.ID
- }
- }
-
- if accountCopy.DomainCategory == PrivateCategory && accountCopy.IsDomainPrimaryAccount {
- s.PrivateDomain2AccountID[accountCopy.Domain] = accountCopy.Id
- }
-
- return s.persist(ctx, s.storeFile)
-}
-
-func (s *FileStore) DeleteAccount(ctx context.Context, account *Account) error {
- s.mux.Lock()
- defer s.mux.Unlock()
-
- if account.Id == "" {
- return status.Errorf(status.InvalidArgument, "account id should not be empty")
- }
-
- for keyID := range account.SetupKeys {
- delete(s.SetupKeyID2AccountID, strings.ToUpper(keyID))
- }
-
- // enforce peer to account index and delete peer to route indexes for rebuild
- for _, peer := range account.Peers {
- delete(s.PeerKeyID2AccountID, peer.Key)
- delete(s.PeerID2AccountID, peer.ID)
- }
-
- for _, user := range account.Users {
- for _, pat := range user.PATs {
- delete(s.TokenID2UserID, pat.ID)
- delete(s.HashedPAT2TokenID, pat.HashedToken)
- }
- delete(s.UserID2AccountID, user.Id)
- }
-
- if account.DomainCategory == PrivateCategory && account.IsDomainPrimaryAccount {
- delete(s.PrivateDomain2AccountID, account.Domain)
- }
-
- delete(s.Accounts, account.Id)
-
- return s.persist(ctx, s.storeFile)
-}
-
-// DeleteHashedPAT2TokenIDIndex removes an entry from the indexing map HashedPAT2TokenID
-func (s *FileStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error {
- s.mux.Lock()
- defer s.mux.Unlock()
-
- delete(s.HashedPAT2TokenID, hashedToken)
-
- return nil
-}
-
-// DeleteTokenID2UserIDIndex removes an entry from the indexing map TokenID2UserID
-func (s *FileStore) DeleteTokenID2UserIDIndex(tokenID string) error {
- s.mux.Lock()
- defer s.mux.Unlock()
-
- delete(s.TokenID2UserID, tokenID)
-
- return nil
-}
-
-// GetAccountByPrivateDomain returns account by private domain
-func (s *FileStore) GetAccountByPrivateDomain(_ context.Context, domain string) (*Account, error) {
- s.mux.Lock()
- defer s.mux.Unlock()
-
- accountID, ok := s.PrivateDomain2AccountID[strings.ToLower(domain)]
- if !ok {
- return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
- }
-
- account, err := s.getAccount(accountID)
- if err != nil {
- return nil, err
- }
-
- return account.Copy(), nil
-}
-
-// GetAccountBySetupKey returns account by setup key id
-func (s *FileStore) GetAccountBySetupKey(_ context.Context, setupKey string) (*Account, error) {
- s.mux.Lock()
- defer s.mux.Unlock()
-
- accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
- if !ok {
- return nil, status.NewSetupKeyNotFoundError()
- }
-
- account, err := s.getAccount(accountID)
- if err != nil {
- return nil, err
- }
-
- return account.Copy(), nil
-}
-
-// GetTokenIDByHashedToken returns the id of a personal access token by its hashed secret
-func (s *FileStore) GetTokenIDByHashedToken(_ context.Context, token string) (string, error) {
- s.mux.Lock()
- defer s.mux.Unlock()
-
- tokenID, ok := s.HashedPAT2TokenID[token]
- if !ok {
- return "", status.Errorf(status.NotFound, "tokenID not found: provided token doesn't exists")
- }
-
- return tokenID, nil
-}
-
-// GetUserByTokenID returns a User object a tokenID belongs to
-func (s *FileStore) GetUserByTokenID(_ context.Context, tokenID string) (*User, error) {
- s.mux.Lock()
- defer s.mux.Unlock()
-
- userID, ok := s.TokenID2UserID[tokenID]
- if !ok {
- return nil, status.Errorf(status.NotFound, "user not found: provided tokenID doesn't exists")
- }
-
- accountID, ok := s.UserID2AccountID[userID]
- if !ok {
- return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists")
- }
-
- account, err := s.getAccount(accountID)
- if err != nil {
- return nil, err
- }
-
- return account.Users[userID].Copy(), nil
-}
-
-func (s *FileStore) GetUserByUserID(_ context.Context, _ LockingStrength, userID string) (*User, error) {
- accountID, ok := s.UserID2AccountID[userID]
- if !ok {
- return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists")
- }
-
- account, err := s.getAccount(accountID)
- if err != nil {
- return nil, err
- }
-
- user := account.Users[userID].Copy()
- pat := make([]PersonalAccessToken, 0, len(user.PATs))
- for _, token := range user.PATs {
- if token != nil {
- pat = append(pat, *token)
- }
- }
- user.PATsG = pat
-
- return user, nil
-}
-
-func (s *FileStore) GetAccountGroups(_ context.Context, accountID string) ([]*nbgroup.Group, error) {
- account, err := s.getAccount(accountID)
- if err != nil {
- return nil, err
- }
-
- groupsSlice := make([]*nbgroup.Group, 0, len(account.Groups))
-
- for _, group := range account.Groups {
- groupsSlice = append(groupsSlice, group)
- }
-
- return groupsSlice, nil
-}
-
// GetAllAccounts returns all accounts
func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) {
s.mux.Lock()
@@ -673,278 +243,6 @@ func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) {
return all
}
-// getAccount returns a reference to the Account. Should not return a copy.
-func (s *FileStore) getAccount(accountID string) (*Account, error) {
- account, ok := s.Accounts[accountID]
- if !ok {
- return nil, status.NewAccountNotFoundError(accountID)
- }
-
- return account, nil
-}
-
-// GetAccount returns an account for ID
-func (s *FileStore) GetAccount(_ context.Context, accountID string) (*Account, error) {
- s.mux.Lock()
- defer s.mux.Unlock()
-
- account, err := s.getAccount(accountID)
- if err != nil {
- return nil, err
- }
-
- return account.Copy(), nil
-}
-
-// GetAccountByUser returns a user account
-func (s *FileStore) GetAccountByUser(_ context.Context, userID string) (*Account, error) {
- s.mux.Lock()
- defer s.mux.Unlock()
-
- accountID, ok := s.UserID2AccountID[userID]
- if !ok {
- return nil, status.NewUserNotFoundError(userID)
- }
-
- account, err := s.getAccount(accountID)
- if err != nil {
- return nil, err
- }
-
- return account.Copy(), nil
-}
-
-// GetAccountByPeerID returns an account for a given peer ID
-func (s *FileStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) {
- s.mux.Lock()
- defer s.mux.Unlock()
-
- accountID, ok := s.PeerID2AccountID[peerID]
- if !ok {
- return nil, status.Errorf(status.NotFound, "provided peer ID doesn't exists %s", peerID)
- }
-
- account, err := s.getAccount(accountID)
- if err != nil {
- return nil, err
- }
-
- // this protection is needed because when we delete a peer, we don't really remove index peerID -> accountID.
- // check Account.Peers for a match
- if _, ok := account.Peers[peerID]; !ok {
- delete(s.PeerID2AccountID, peerID)
- log.WithContext(ctx).Warnf("removed stale peerID %s to accountID %s index", peerID, accountID)
- return nil, status.NewPeerNotFoundError(peerID)
- }
-
- return account.Copy(), nil
-}
-
-// GetAccountByPeerPubKey returns an account for a given peer WireGuard public key
-func (s *FileStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) {
- s.mux.Lock()
- defer s.mux.Unlock()
-
- accountID, ok := s.PeerKeyID2AccountID[peerKey]
- if !ok {
- return nil, status.NewPeerNotFoundError(peerKey)
- }
-
- account, err := s.getAccount(accountID)
- if err != nil {
- return nil, err
- }
-
- // this protection is needed because when we delete a peer, we don't really remove index peerKey -> accountID.
- // check Account.Peers for a match
- stale := true
- for _, peer := range account.Peers {
- if peer.Key == peerKey {
- stale = false
- break
- }
- }
- if stale {
- delete(s.PeerKeyID2AccountID, peerKey)
- log.WithContext(ctx).Warnf("removed stale peerKey %s to accountID %s index", peerKey, accountID)
- return nil, status.NewPeerNotFoundError(peerKey)
- }
-
- return account.Copy(), nil
-}
-
-func (s *FileStore) GetAccountIDByPeerPubKey(_ context.Context, peerKey string) (string, error) {
- s.mux.Lock()
- defer s.mux.Unlock()
-
- accountID, ok := s.PeerKeyID2AccountID[peerKey]
- if !ok {
- return "", status.NewPeerNotFoundError(peerKey)
- }
-
- return accountID, nil
-}
-
-func (s *FileStore) GetAccountIDByUserID(userID string) (string, error) {
- s.mux.Lock()
- defer s.mux.Unlock()
-
- accountID, ok := s.UserID2AccountID[userID]
- if !ok {
- return "", status.NewUserNotFoundError(userID)
- }
-
- return accountID, nil
-}
-
-func (s *FileStore) GetAccountIDBySetupKey(_ context.Context, setupKey string) (string, error) {
- s.mux.Lock()
- defer s.mux.Unlock()
-
- accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
- if !ok {
- return "", status.NewSetupKeyNotFoundError()
- }
-
- return accountID, nil
-}
-
-func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, _ LockingStrength, peerKey string) (*nbpeer.Peer, error) {
- s.mux.Lock()
- defer s.mux.Unlock()
-
- accountID, ok := s.PeerKeyID2AccountID[peerKey]
- if !ok {
- return nil, status.NewPeerNotFoundError(peerKey)
- }
-
- account, err := s.getAccount(accountID)
- if err != nil {
- return nil, err
- }
-
- for _, peer := range account.Peers {
- if peer.Key == peerKey {
- return peer.Copy(), nil
- }
- }
-
- return nil, status.NewPeerNotFoundError(peerKey)
-}
-
-func (s *FileStore) GetAccountSettings(_ context.Context, _ LockingStrength, accountID string) (*Settings, error) {
- s.mux.Lock()
- defer s.mux.Unlock()
-
- account, err := s.getAccount(accountID)
- if err != nil {
- return nil, err
- }
-
- return account.Settings.Copy(), nil
-}
-
-// GetInstallationID returns the installation ID from the store
-func (s *FileStore) GetInstallationID() string {
- return s.InstallationID
-}
-
-// SaveInstallationID saves the installation ID
-func (s *FileStore) SaveInstallationID(ctx context.Context, ID string) error {
- s.mux.Lock()
- defer s.mux.Unlock()
-
- s.InstallationID = ID
-
- return s.persist(ctx, s.storeFile)
-}
-
-// SavePeer saves the peer in the account
-func (s *FileStore) SavePeer(_ context.Context, accountID string, peer *nbpeer.Peer) error {
- s.mux.Lock()
- defer s.mux.Unlock()
-
- account, err := s.getAccount(accountID)
- if err != nil {
- return err
- }
-
- newPeer := peer.Copy()
-
- account.Peers[peer.ID] = newPeer
-
- s.PeerKeyID2AccountID[peer.Key] = accountID
- s.PeerID2AccountID[peer.ID] = accountID
-
- return nil
-}
-
-// SavePeerStatus stores the PeerStatus in memory. It doesn't attempt to persist data to speed up things.
-// PeerStatus will be saved eventually when some other changes occur.
-func (s *FileStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
- s.mux.Lock()
- defer s.mux.Unlock()
-
- account, err := s.getAccount(accountID)
- if err != nil {
- return err
- }
-
- peer := account.Peers[peerID]
- if peer == nil {
- return status.Errorf(status.NotFound, "peer %s not found", peerID)
- }
-
- peer.Status = &peerStatus
-
- return nil
-}
-
-// SavePeerLocation stores the PeerStatus in memory. It doesn't attempt to persist data to speed up things.
-// Peer.Location will be saved eventually when some other changes occur.
-func (s *FileStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.Peer) error {
- s.mux.Lock()
- defer s.mux.Unlock()
-
- account, err := s.getAccount(accountID)
- if err != nil {
- return err
- }
-
- peer := account.Peers[peerWithLocation.ID]
- if peer == nil {
- return status.Errorf(status.NotFound, "peer %s not found", peerWithLocation.ID)
- }
-
- peer.Location = peerWithLocation.Location
-
- return nil
-}
-
-// SaveUserLastLogin stores the last login time for a user in memory. It doesn't attempt to persist data to speed up things.
-func (s *FileStore) SaveUserLastLogin(_ context.Context, accountID, userID string, lastLogin time.Time) error {
- s.mux.Lock()
- defer s.mux.Unlock()
-
- account, err := s.getAccount(accountID)
- if err != nil {
- return err
- }
-
- peer := account.Users[userID]
- if peer == nil {
- return status.Errorf(status.NotFound, "user %s not found", userID)
- }
-
- peer.LastLogin = lastLogin
-
- return nil
-}
-
-func (s *FileStore) GetPostureCheckByChecksDefinition(_ string, _ *posture.ChecksDefinition) (*posture.Checks, error) {
- return nil, status.Errorf(status.Internal, "GetPostureCheckByChecksDefinition is not implemented")
-}
-
// Close the FileStore persisting data to disk
func (s *FileStore) Close(ctx context.Context) error {
s.mux.Lock()
@@ -959,86 +257,3 @@ func (s *FileStore) Close(ctx context.Context) error {
func (s *FileStore) GetStoreEngine() StoreEngine {
return FileStoreEngine
}
-
-func (s *FileStore) SaveUsers(_ string, _ map[string]*User) error {
- return status.Errorf(status.Internal, "SaveUsers is not implemented")
-}
-
-func (s *FileStore) SaveGroups(_ string, _ map[string]*nbgroup.Group) error {
- return status.Errorf(status.Internal, "SaveGroups is not implemented")
-}
-
-func (s *FileStore) GetAccountIDByPrivateDomain(_ context.Context, _ LockingStrength, _ string) (string, error) {
- return "", status.Errorf(status.Internal, "GetAccountIDByPrivateDomain is not implemented")
-}
-
-func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, _ LockingStrength, accountID string) (string, string, error) {
- s.mux.Lock()
- defer s.mux.Unlock()
-
- account, err := s.getAccount(accountID)
- if err != nil {
- return "", "", err
- }
-
- return account.Domain, account.DomainCategory, nil
-}
-
-// AccountExists checks whether an account exists by the given ID.
-func (s *FileStore) AccountExists(_ context.Context, _ LockingStrength, id string) (bool, error) {
- _, exists := s.Accounts[id]
- return exists, nil
-}
-
-func (s *FileStore) GetAccountDNSSettings(_ context.Context, _ LockingStrength, _ string) (*DNSSettings, error) {
- return nil, status.Errorf(status.Internal, "GetAccountDNSSettings is not implemented")
-}
-
-func (s *FileStore) GetGroupByID(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) {
- return nil, status.Errorf(status.Internal, "GetGroupByID is not implemented")
-}
-
-func (s *FileStore) GetGroupByName(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) {
- return nil, status.Errorf(status.Internal, "GetGroupByName is not implemented")
-}
-
-func (s *FileStore) GetAccountPolicies(_ context.Context, _ LockingStrength, _ string) ([]*Policy, error) {
- return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented")
-}
-
-func (s *FileStore) GetPolicyByID(_ context.Context, _ LockingStrength, _ string, _ string) (*Policy, error) {
- return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented")
-
-}
-
-func (s *FileStore) GetAccountPostureChecks(_ context.Context, _ LockingStrength, _ string) ([]*posture.Checks, error) {
- return nil, status.Errorf(status.Internal, "GetAccountPostureChecks is not implemented")
-}
-
-func (s *FileStore) GetPostureChecksByID(_ context.Context, _ LockingStrength, _ string, _ string) (*posture.Checks, error) {
- return nil, status.Errorf(status.Internal, "GetPostureChecksByID is not implemented")
-}
-
-func (s *FileStore) GetAccountRoutes(_ context.Context, _ LockingStrength, _ string) ([]*route.Route, error) {
- return nil, status.Errorf(status.Internal, "GetAccountRoutes is not implemented")
-}
-
-func (s *FileStore) GetRouteByID(_ context.Context, _ LockingStrength, _ string, _ string) (*route.Route, error) {
- return nil, status.Errorf(status.Internal, "GetRouteByID is not implemented")
-}
-
-func (s *FileStore) GetAccountSetupKeys(_ context.Context, _ LockingStrength, _ string) ([]*SetupKey, error) {
- return nil, status.Errorf(status.Internal, "GetAccountSetupKeys is not implemented")
-}
-
-func (s *FileStore) GetSetupKeyByID(_ context.Context, _ LockingStrength, _ string, _ string) (*SetupKey, error) {
- return nil, status.Errorf(status.Internal, "GetSetupKeyByID is not implemented")
-}
-
-func (s *FileStore) GetAccountNameServerGroups(_ context.Context, _ LockingStrength, _ string) ([]*dns.NameServerGroup, error) {
- return nil, status.Errorf(status.Internal, "GetAccountNameServerGroups is not implemented")
-}
-
-func (s *FileStore) GetNameServerGroupByID(_ context.Context, _ LockingStrength, _ string, _ string) (*dns.NameServerGroup, error) {
- return nil, status.Errorf(status.Internal, "GetNameServerGroupByID is not implemented")
-}
diff --git a/management/server/file_store_test.go b/management/server/file_store_test.go
deleted file mode 100644
index 56e46b696..000000000
--- a/management/server/file_store_test.go
+++ /dev/null
@@ -1,655 +0,0 @@
-package server
-
-import (
- "context"
- "crypto/sha256"
- "net"
- "path/filepath"
- "testing"
- "time"
-
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-
- "github.com/netbirdio/netbird/management/server/group"
- nbpeer "github.com/netbirdio/netbird/management/server/peer"
- "github.com/netbirdio/netbird/util"
-)
-
-type accounts struct {
- Accounts map[string]*Account
-}
-
-func TestStalePeerIndices(t *testing.T) {
- storeDir := t.TempDir()
-
- err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json"))
- if err != nil {
- t.Fatal(err)
- }
-
- store, err := NewFileStore(context.Background(), storeDir, nil)
- if err != nil {
- return
- }
-
- account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
- require.NoError(t, err)
-
- peerID := "some_peer"
- peerKey := "some_peer_key"
- account.Peers[peerID] = &nbpeer.Peer{
- ID: peerID,
- Key: peerKey,
- }
-
- err = store.SaveAccount(context.Background(), account)
- require.NoError(t, err)
-
- account.DeletePeer(peerID)
-
- err = store.SaveAccount(context.Background(), account)
- require.NoError(t, err)
-
- _, err = store.GetAccountByPeerID(context.Background(), peerID)
- require.Error(t, err, "expecting to get an error when found stale index")
-
- _, err = store.GetAccountByPeerPubKey(context.Background(), peerKey)
- require.Error(t, err, "expecting to get an error when found stale index")
-}
-
-func TestNewStore(t *testing.T) {
- store := newStore(t)
- defer store.Close(context.Background())
-
- if store.Accounts == nil || len(store.Accounts) != 0 {
- t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
- }
-
- if store.SetupKeyID2AccountID == nil || len(store.SetupKeyID2AccountID) != 0 {
- t.Errorf("expected to create a new empty SetupKeyID2AccountID map when creating a new FileStore")
- }
-
- if store.PeerKeyID2AccountID == nil || len(store.PeerKeyID2AccountID) != 0 {
- t.Errorf("expected to create a new empty PeerKeyID2AccountID map when creating a new FileStore")
- }
-
- if store.UserID2AccountID == nil || len(store.UserID2AccountID) != 0 {
- t.Errorf("expected to create a new empty UserID2AccountID map when creating a new FileStore")
- }
-
- if store.HashedPAT2TokenID == nil || len(store.HashedPAT2TokenID) != 0 {
- t.Errorf("expected to create a new empty HashedPAT2TokenID map when creating a new FileStore")
- }
-
- if store.TokenID2UserID == nil || len(store.TokenID2UserID) != 0 {
- t.Errorf("expected to create a new empty TokenID2UserID map when creating a new FileStore")
- }
-}
-
-func TestSaveAccount(t *testing.T) {
- store := newStore(t)
- defer store.Close(context.Background())
-
- account := newAccountWithId(context.Background(), "account_id", "testuser", "")
- setupKey := GenerateDefaultSetupKey()
- account.SetupKeys[setupKey.Key] = setupKey
- account.Peers["testpeer"] = &nbpeer.Peer{
- Key: "peerkey",
- SetupKey: "peerkeysetupkey",
- IP: net.IP{127, 0, 0, 1},
- Meta: nbpeer.PeerSystemMeta{},
- Name: "peer name",
- Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
- }
-
- // SaveAccount should trigger persist
- err := store.SaveAccount(context.Background(), account)
- if err != nil {
- return
- }
-
- if store.Accounts[account.Id] == nil {
- t.Errorf("expecting Account to be stored after SaveAccount()")
- }
-
- if store.PeerKeyID2AccountID["peerkey"] == "" {
- t.Errorf("expecting PeerKeyID2AccountID index updated after SaveAccount()")
- }
-
- if store.UserID2AccountID["testuser"] == "" {
- t.Errorf("expecting UserID2AccountID index updated after SaveAccount()")
- }
-
- if store.SetupKeyID2AccountID[setupKey.Key] == "" {
- t.Errorf("expecting SetupKeyID2AccountID index updated after SaveAccount()")
- }
-}
-
-func TestDeleteAccount(t *testing.T) {
- storeDir := t.TempDir()
- storeFile := filepath.Join(storeDir, "store.json")
- err := util.CopyFileContents("testdata/store.json", storeFile)
- if err != nil {
- t.Fatal(err)
- }
-
- store, err := NewFileStore(context.Background(), storeDir, nil)
- if err != nil {
- t.Fatal(err)
- }
- defer store.Close(context.Background())
-
- var account *Account
- for _, a := range store.Accounts {
- account = a
- break
- }
-
- require.NotNil(t, account, "failed to restore a FileStore file and get at least one account")
-
- err = store.DeleteAccount(context.Background(), account)
- require.NoError(t, err, "failed to delete account, error: %v", err)
-
- _, ok := store.Accounts[account.Id]
- require.False(t, ok, "failed to delete account")
-
- for id := range account.Users {
- _, ok := store.UserID2AccountID[id]
- assert.False(t, ok, "failed to delete UserID2AccountID index")
- for _, pat := range account.Users[id].PATs {
- _, ok := store.HashedPAT2TokenID[pat.HashedToken]
- assert.False(t, ok, "failed to delete HashedPAT2TokenID index")
- _, ok = store.TokenID2UserID[pat.ID]
- assert.False(t, ok, "failed to delete TokenID2UserID index")
- }
- }
-
- for _, p := range account.Peers {
- _, ok := store.PeerKeyID2AccountID[p.Key]
- assert.False(t, ok, "failed to delete PeerKeyID2AccountID index")
- _, ok = store.PeerID2AccountID[p.ID]
- assert.False(t, ok, "failed to delete PeerID2AccountID index")
- }
-
- for id := range account.SetupKeys {
- _, ok := store.SetupKeyID2AccountID[id]
- assert.False(t, ok, "failed to delete SetupKeyID2AccountID index")
- }
-
- _, ok = store.PrivateDomain2AccountID[account.Domain]
- assert.False(t, ok, "failed to delete PrivateDomain2AccountID index")
-
-}
-
-func TestStore(t *testing.T) {
- store := newStore(t)
- defer store.Close(context.Background())
-
- account := newAccountWithId(context.Background(), "account_id", "testuser", "")
- account.Peers["testpeer"] = &nbpeer.Peer{
- Key: "peerkey",
- SetupKey: "peerkeysetupkey",
- IP: net.IP{127, 0, 0, 1},
- Meta: nbpeer.PeerSystemMeta{},
- Name: "peer name",
- Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
- }
- account.Groups["all"] = &group.Group{
- ID: "all",
- Name: "all",
- Peers: []string{"testpeer"},
- }
- account.Policies = append(account.Policies, &Policy{
- ID: "all",
- Name: "all",
- Enabled: true,
- Rules: []*PolicyRule{
- {
- ID: "all",
- Name: "all",
- Sources: []string{"all"},
- Destinations: []string{"all"},
- },
- },
- })
- account.Policies = append(account.Policies, &Policy{
- ID: "dmz",
- Name: "dmz",
- Enabled: true,
- Rules: []*PolicyRule{
- {
- ID: "dmz",
- Name: "dmz",
- Enabled: true,
- Sources: []string{"all"},
- Destinations: []string{"all"},
- },
- },
- })
-
- // SaveAccount should trigger persist
- err := store.SaveAccount(context.Background(), account)
- if err != nil {
- return
- }
-
- restored, err := NewFileStore(context.Background(), store.storeFile, nil)
- if err != nil {
- return
- }
-
- restoredAccount := restored.Accounts[account.Id]
- if restoredAccount == nil {
- t.Errorf("failed to restore a FileStore file - missing Account %s", account.Id)
- return
- }
-
- if restoredAccount.Peers["testpeer"] == nil {
- t.Errorf("failed to restore a FileStore file - missing Peer testpeer")
- }
-
- if restoredAccount.CreatedBy != "testuser" {
- t.Errorf("failed to restore a FileStore file - missing Account CreatedBy")
- }
-
- if restoredAccount.Users["testuser"] == nil {
- t.Errorf("failed to restore a FileStore file - missing User testuser")
- }
-
- if restoredAccount.Network == nil {
- t.Errorf("failed to restore a FileStore file - missing Network")
- }
-
- if restoredAccount.Groups["all"] == nil {
- t.Errorf("failed to restore a FileStore file - missing Group all")
- }
-
- if len(restoredAccount.Policies) != 2 {
- t.Errorf("failed to restore a FileStore file - missing Policies")
- return
- }
-
- assert.Equal(t, account.Policies[0], restoredAccount.Policies[0], "failed to restore a FileStore file - missing Policy all")
- assert.Equal(t, account.Policies[1], restoredAccount.Policies[1], "failed to restore a FileStore file - missing Policy dmz")
-}
-
-func TestRestore(t *testing.T) {
- storeDir := t.TempDir()
-
- err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json"))
- if err != nil {
- t.Fatal(err)
- }
-
- store, err := NewFileStore(context.Background(), storeDir, nil)
- if err != nil {
- return
- }
-
- account := store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"]
-
- require.NotNil(t, account, "failed to restore a FileStore file - missing account bf1c8084-ba50-4ce7-9439-34653001fc3b")
-
- require.NotNil(t, account.Users["edafee4e-63fb-11ec-90d6-0242ac120003"], "failed to restore a FileStore file - missing Account User edafee4e-63fb-11ec-90d6-0242ac120003")
-
- require.NotNil(t, account.Users["f4f6d672-63fb-11ec-90d6-0242ac120003"], "failed to restore a FileStore file - missing Account User f4f6d672-63fb-11ec-90d6-0242ac120003")
-
- require.NotNil(t, account.Network, "failed to restore a FileStore file - missing Account Network")
-
- require.NotNil(t, account.SetupKeys["A2C8E62B-38F5-4553-B31E-DD66C696CEBB"], "failed to restore a FileStore file - missing Account SetupKey A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
-
- require.NotNil(t, account.Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"], "failed to restore a FileStore wrong PATs length")
-
- require.Len(t, store.UserID2AccountID, 2, "failed to restore a FileStore wrong UserID2AccountID mapping length")
-
- require.Len(t, store.SetupKeyID2AccountID, 1, "failed to restore a FileStore wrong SetupKeyID2AccountID mapping length")
-
- require.Len(t, store.PrivateDomain2AccountID, 1, "failed to restore a FileStore wrong PrivateDomain2AccountID mapping length")
-
- require.Len(t, store.HashedPAT2TokenID, 1, "failed to restore a FileStore wrong HashedPAT2TokenID mapping length")
-
- require.Len(t, store.TokenID2UserID, 1, "failed to restore a FileStore wrong TokenID2UserID mapping length")
-}
-
-func TestRestoreGroups_Migration(t *testing.T) {
- storeDir := t.TempDir()
-
- err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json"))
- if err != nil {
- t.Fatal(err)
- }
-
- store, err := NewFileStore(context.Background(), storeDir, nil)
- if err != nil {
- return
- }
-
- // create default group
- account := store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"]
- account.Groups = map[string]*group.Group{
- "cfefqs706sqkneg59g3g": {
- ID: "cfefqs706sqkneg59g3g",
- Name: "All",
- },
- }
- err = store.SaveAccount(context.Background(), account)
- require.NoError(t, err, "failed to save account")
-
- // restore account with default group with empty Issue field
- if store, err = NewFileStore(context.Background(), storeDir, nil); err != nil {
- return
- }
- account = store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"]
-
- require.Contains(t, account.Groups, "cfefqs706sqkneg59g3g", "failed to restore a FileStore file - missing Account Groups")
- require.Equal(t, group.GroupIssuedAPI, account.Groups["cfefqs706sqkneg59g3g"].Issued, "default group should has API issued mark")
-}
-
-func TestGetAccountByPrivateDomain(t *testing.T) {
- storeDir := t.TempDir()
-
- err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json"))
- if err != nil {
- t.Fatal(err)
- }
-
- store, err := NewFileStore(context.Background(), storeDir, nil)
- if err != nil {
- return
- }
-
- existingDomain := "test.com"
-
- account, err := store.GetAccountByPrivateDomain(context.Background(), existingDomain)
- require.NoError(t, err, "should found account")
- require.Equal(t, existingDomain, account.Domain, "domains should match")
-
- _, err = store.GetAccountByPrivateDomain(context.Background(), "missing-domain.com")
- require.Error(t, err, "should return error on domain lookup")
-}
-
-func TestFileStore_GetAccount(t *testing.T) {
- storeDir := t.TempDir()
- storeFile := filepath.Join(storeDir, "store.json")
- err := util.CopyFileContents("testdata/store.json", storeFile)
- if err != nil {
- t.Fatal(err)
- }
-
- accounts := &accounts{}
- _, err = util.ReadJson(storeFile, accounts)
- if err != nil {
- t.Fatal(err)
- }
-
- store, err := NewFileStore(context.Background(), storeDir, nil)
- if err != nil {
- t.Fatal(err)
- }
-
- expected := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"]
- if expected == nil {
- t.Fatalf("expected account doesn't exist")
- return
- }
-
- account, err := store.GetAccount(context.Background(), expected.Id)
- if err != nil {
- t.Fatal(err)
- }
-
- assert.Equal(t, expected.IsDomainPrimaryAccount, account.IsDomainPrimaryAccount)
- assert.Equal(t, expected.DomainCategory, account.DomainCategory)
- assert.Equal(t, expected.Domain, account.Domain)
- assert.Equal(t, expected.CreatedBy, account.CreatedBy)
- assert.Equal(t, expected.Network.Identifier, account.Network.Identifier)
- assert.Len(t, account.Peers, len(expected.Peers))
- assert.Len(t, account.Users, len(expected.Users))
- assert.Len(t, account.SetupKeys, len(expected.SetupKeys))
- assert.Len(t, account.Routes, len(expected.Routes))
- assert.Len(t, account.NameServerGroups, len(expected.NameServerGroups))
-}
-
-func TestFileStore_GetTokenIDByHashedToken(t *testing.T) {
- storeDir := t.TempDir()
- storeFile := filepath.Join(storeDir, "store.json")
- err := util.CopyFileContents("testdata/store.json", storeFile)
- if err != nil {
- t.Fatal(err)
- }
-
- accounts := &accounts{}
- _, err = util.ReadJson(storeFile, accounts)
- if err != nil {
- t.Fatal(err)
- }
-
- store, err := NewFileStore(context.Background(), storeDir, nil)
- if err != nil {
- t.Fatal(err)
- }
-
- hashedToken := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].HashedToken
- tokenID, err := store.GetTokenIDByHashedToken(context.Background(), hashedToken)
- if err != nil {
- t.Fatal(err)
- }
-
- expectedTokenID := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].ID
- assert.Equal(t, expectedTokenID, tokenID)
-}
-
-func TestFileStore_DeleteHashedPAT2TokenIDIndex(t *testing.T) {
- store := newStore(t)
- defer store.Close(context.Background())
- store.HashedPAT2TokenID["someHashedToken"] = "someTokenId"
-
- err := store.DeleteHashedPAT2TokenIDIndex("someHashedToken")
- if err != nil {
- t.Fatal(err)
- }
-
- assert.Empty(t, store.HashedPAT2TokenID["someHashedToken"])
-}
-
-func TestFileStore_DeleteTokenID2UserIDIndex(t *testing.T) {
- store := newStore(t)
- store.TokenID2UserID["someTokenId"] = "someUserId"
-
- err := store.DeleteTokenID2UserIDIndex("someTokenId")
- if err != nil {
- t.Fatal(err)
- }
-
- assert.Empty(t, store.TokenID2UserID["someTokenId"])
-}
-
-func TestFileStore_GetTokenIDByHashedToken_Failure(t *testing.T) {
- storeDir := t.TempDir()
- storeFile := filepath.Join(storeDir, "store.json")
- err := util.CopyFileContents("testdata/store.json", storeFile)
- if err != nil {
- t.Fatal(err)
- }
-
- accounts := &accounts{}
- _, err = util.ReadJson(storeFile, accounts)
- if err != nil {
- t.Fatal(err)
- }
-
- store, err := NewFileStore(context.Background(), storeDir, nil)
- if err != nil {
- t.Fatal(err)
- }
-
- wrongToken := sha256.Sum256([]byte("someNotValidTokenThatFails1234"))
- _, err = store.GetTokenIDByHashedToken(context.Background(), string(wrongToken[:]))
-
- assert.Error(t, err, "GetTokenIDByHashedToken should throw error if token invalid")
-}
-
-func TestFileStore_GetUserByTokenID(t *testing.T) {
- storeDir := t.TempDir()
- storeFile := filepath.Join(storeDir, "store.json")
- err := util.CopyFileContents("testdata/store.json", storeFile)
- if err != nil {
- t.Fatal(err)
- }
-
- accounts := &accounts{}
- _, err = util.ReadJson(storeFile, accounts)
- if err != nil {
- t.Fatal(err)
- }
-
- store, err := NewFileStore(context.Background(), storeDir, nil)
- if err != nil {
- t.Fatal(err)
- }
-
- tokenID := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].ID
- user, err := store.GetUserByTokenID(context.Background(), tokenID)
- if err != nil {
- t.Fatal(err)
- }
-
- assert.Equal(t, "f4f6d672-63fb-11ec-90d6-0242ac120003", user.Id)
-}
-
-func TestFileStore_GetUserByTokenID_Failure(t *testing.T) {
- storeDir := t.TempDir()
- storeFile := filepath.Join(storeDir, "store.json")
- err := util.CopyFileContents("testdata/store.json", storeFile)
- if err != nil {
- t.Fatal(err)
- }
-
- accounts := &accounts{}
- _, err = util.ReadJson(storeFile, accounts)
- if err != nil {
- t.Fatal(err)
- }
-
- store, err := NewFileStore(context.Background(), storeDir, nil)
- if err != nil {
- t.Fatal(err)
- }
-
- wrongTokenID := "someNonExistingTokenID"
- _, err = store.GetUserByTokenID(context.Background(), wrongTokenID)
-
- assert.Error(t, err, "GetUserByTokenID should throw error if tokenID invalid")
-}
-
-func TestFileStore_SavePeerStatus(t *testing.T) {
- storeDir := t.TempDir()
-
- err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json"))
- if err != nil {
- t.Fatal(err)
- }
-
- store, err := NewFileStore(context.Background(), storeDir, nil)
- if err != nil {
- return
- }
-
- account, err := store.getAccount("bf1c8084-ba50-4ce7-9439-34653001fc3b")
- if err != nil {
- t.Fatal(err)
- }
-
- // save status of non-existing peer
- newStatus := nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}
- err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus)
- assert.Error(t, err)
-
- // save new status of existing peer
- account.Peers["testpeer"] = &nbpeer.Peer{
- Key: "peerkey",
- ID: "testpeer",
- SetupKey: "peerkeysetupkey",
- IP: net.IP{127, 0, 0, 1},
- Meta: nbpeer.PeerSystemMeta{},
- Name: "peer name",
- Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()},
- }
-
- err = store.SaveAccount(context.Background(), account)
- if err != nil {
- t.Fatal(err)
- }
-
- err = store.SavePeerStatus(account.Id, "testpeer", newStatus)
- if err != nil {
- t.Fatal(err)
- }
- account, err = store.getAccount(account.Id)
- if err != nil {
- t.Fatal(err)
- }
-
- actual := account.Peers["testpeer"].Status
- assert.Equal(t, newStatus, *actual)
-}
-
-func TestFileStore_SavePeerLocation(t *testing.T) {
- storeDir := t.TempDir()
-
- err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json"))
- if err != nil {
- t.Fatal(err)
- }
-
- store, err := NewFileStore(context.Background(), storeDir, nil)
- if err != nil {
- return
- }
- account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
- require.NoError(t, err)
-
- peer := &nbpeer.Peer{
- AccountID: account.Id,
- ID: "testpeer",
- Location: nbpeer.Location{
- ConnectionIP: net.ParseIP("10.0.0.0"),
- CountryCode: "YY",
- CityName: "City",
- GeoNameID: 1,
- },
- Meta: nbpeer.PeerSystemMeta{},
- }
- // error is expected as peer is not in store yet
- err = store.SavePeerLocation(account.Id, peer)
- assert.Error(t, err)
-
- account.Peers[peer.ID] = peer
- err = store.SaveAccount(context.Background(), account)
- require.NoError(t, err)
-
- peer.Location.ConnectionIP = net.ParseIP("35.1.1.1")
- peer.Location.CountryCode = "DE"
- peer.Location.CityName = "Berlin"
- peer.Location.GeoNameID = 2950159
-
- err = store.SavePeerLocation(account.Id, account.Peers[peer.ID])
- assert.NoError(t, err)
-
- account, err = store.GetAccount(context.Background(), account.Id)
- require.NoError(t, err)
-
- actual := account.Peers[peer.ID].Location
- assert.Equal(t, peer.Location, actual)
-}
-
-func newStore(t *testing.T) *FileStore {
- t.Helper()
- store, err := NewFileStore(context.Background(), t.TempDir(), nil)
- if err != nil {
- t.Errorf("failed creating a new store")
- }
-
- return store
-}
diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go
index ff09129bd..f8ab46d81 100644
--- a/management/server/management_proto_test.go
+++ b/management/server/management_proto_test.go
@@ -6,7 +6,6 @@ import (
"io"
"net"
"os"
- "path/filepath"
"runtime"
"sync"
"sync/atomic"
@@ -89,14 +88,7 @@ func getServerKey(client mgmtProto.ManagementServiceClient) (*wgtypes.Key, error
func Test_SyncProtocol(t *testing.T) {
dir := t.TempDir()
- err := util.CopyFileContents("testdata/store_with_expired_peers.json", filepath.Join(dir, "store.json"))
- if err != nil {
- t.Fatal(err)
- }
- defer func() {
- os.Remove(filepath.Join(dir, "store.json")) //nolint
- }()
- mgmtServer, _, mgmtAddr, err := startManagementForTest(t, &Config{
+ mgmtServer, _, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sqlite", &Config{
Stuns: []*Host{{
Proto: "udp",
URI: "stun:stun.wiretrustee.com:3468",
@@ -117,6 +109,7 @@ func Test_SyncProtocol(t *testing.T) {
Datadir: dir,
HttpConfig: nil,
})
+ defer cleanup()
if err != nil {
t.Fatal(err)
return
@@ -412,18 +405,18 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
}
}
-func startManagementForTest(t TestingT, config *Config) (*grpc.Server, *DefaultAccountManager, string, error) {
+func startManagementForTest(t *testing.T, testFile string, config *Config) (*grpc.Server, *DefaultAccountManager, string, func(), error) {
t.Helper()
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
- return nil, nil, "", err
+ return nil, nil, "", nil, err
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
- store, cleanUp, err := NewTestStoreFromJson(context.Background(), config.Datadir)
+
+ store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), testFile)
if err != nil {
- return nil, nil, "", err
+ t.Fatal(err)
}
- t.Cleanup(cleanUp)
peersUpdateManager := NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
@@ -437,7 +430,8 @@ func startManagementForTest(t TestingT, config *Config) (*grpc.Server, *DefaultA
eventStore, nil, false, MocIntegratedValidator{}, metrics)
if err != nil {
- return nil, nil, "", err
+ cleanup()
+ return nil, nil, "", cleanup, err
}
secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
@@ -445,7 +439,7 @@ func startManagementForTest(t TestingT, config *Config) (*grpc.Server, *DefaultA
ephemeralMgr := NewEphemeralManager(store, accountManager)
mgmtServer, err := NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, ephemeralMgr)
if err != nil {
- return nil, nil, "", err
+ return nil, nil, "", cleanup, err
}
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
@@ -455,7 +449,7 @@ func startManagementForTest(t TestingT, config *Config) (*grpc.Server, *DefaultA
}
}()
- return s, accountManager, lis.Addr().String(), nil
+ return s, accountManager, lis.Addr().String(), cleanup, nil
}
func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.ClientConn, error) {
@@ -475,6 +469,7 @@ func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.Clie
return mgmtProto.NewManagementServiceClient(conn), conn, nil
}
+
func Test_SyncStatusRace(t *testing.T) {
if os.Getenv("CI") == "true" && os.Getenv("NETBIRD_STORE_ENGINE") == "postgres" {
t.Skip("Skipping on CI and Postgres store")
@@ -488,15 +483,8 @@ func Test_SyncStatusRace(t *testing.T) {
func testSyncStatusRace(t *testing.T) {
t.Helper()
dir := t.TempDir()
- err := util.CopyFileContents("testdata/store_with_expired_peers.json", filepath.Join(dir, "store.json"))
- if err != nil {
- t.Fatal(err)
- }
- defer func() {
- os.Remove(filepath.Join(dir, "store.json")) //nolint
- }()
- mgmtServer, am, mgmtAddr, err := startManagementForTest(t, &Config{
+ mgmtServer, am, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sqlite", &Config{
Stuns: []*Host{{
Proto: "udp",
URI: "stun:stun.wiretrustee.com:3468",
@@ -517,6 +505,7 @@ func testSyncStatusRace(t *testing.T) {
Datadir: dir,
HttpConfig: nil,
})
+ defer cleanup()
if err != nil {
t.Fatal(err)
return
@@ -665,15 +654,8 @@ func Test_LoginPerformance(t *testing.T) {
t.Run(bc.name, func(t *testing.T) {
t.Helper()
dir := t.TempDir()
- err := util.CopyFileContents("testdata/store_with_expired_peers.json", filepath.Join(dir, "store.json"))
- if err != nil {
- t.Fatal(err)
- }
- defer func() {
- os.Remove(filepath.Join(dir, "store.json")) //nolint
- }()
- mgmtServer, am, _, err := startManagementForTest(t, &Config{
+ mgmtServer, am, _, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sqlite", &Config{
Stuns: []*Host{{
Proto: "udp",
URI: "stun:stun.wiretrustee.com:3468",
@@ -694,6 +676,7 @@ func Test_LoginPerformance(t *testing.T) {
Datadir: dir,
HttpConfig: nil,
})
+ defer cleanup()
if err != nil {
t.Fatal(err)
return
diff --git a/management/server/management_test.go b/management/server/management_test.go
index 3956d96b1..ba27dc5e8 100644
--- a/management/server/management_test.go
+++ b/management/server/management_test.go
@@ -5,7 +5,6 @@ import (
"math/rand"
"net"
"os"
- "path/filepath"
"runtime"
sync2 "sync"
"time"
@@ -52,8 +51,6 @@ var _ = Describe("Management service", func() {
dataDir, err = os.MkdirTemp("", "wiretrustee_mgmt_test_tmp_*")
Expect(err).NotTo(HaveOccurred())
- err = util.CopyFileContents("testdata/store.json", filepath.Join(dataDir, "store.json"))
- Expect(err).NotTo(HaveOccurred())
var listener net.Listener
config := &server.Config{}
@@ -61,7 +58,7 @@ var _ = Describe("Management service", func() {
Expect(err).NotTo(HaveOccurred())
config.Datadir = dataDir
- s, listener = startServer(config)
+ s, listener = startServer(config, dataDir, "testdata/store.sqlite")
addr = listener.Addr().String()
client, conn = createRawClient(addr)
@@ -530,12 +527,12 @@ func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.Clie
return mgmtProto.NewManagementServiceClient(conn), conn
}
-func startServer(config *server.Config) (*grpc.Server, net.Listener) {
+func startServer(config *server.Config, dataDir string, testFile string) (*grpc.Server, net.Listener) {
lis, err := net.Listen("tcp", ":0")
Expect(err).NotTo(HaveOccurred())
s := grpc.NewServer()
- store, _, err := server.NewTestStoreFromJson(context.Background(), config.Datadir)
+ store, _, err := server.NewTestStoreFromSqlite(context.Background(), testFile, dataDir)
if err != nil {
log.Fatalf("failed creating a store: %s: %v", config.Datadir, err)
}
diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go
index 8dc8f6a4f..681bf533a 100644
--- a/management/server/mock_server/account_mock.go
+++ b/management/server/mock_server/account_mock.go
@@ -27,7 +27,8 @@ type MockAccountManager struct {
CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType,
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error)
- GetAccountIDByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (string, error)
+ AccountExistsFunc func(ctx context.Context, accountID string) (bool, error)
+ GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error)
GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
@@ -57,7 +58,7 @@ type MockAccountManager struct {
MarkPATUsedFunc func(ctx context.Context, pat string) error
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
- CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups,accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
+ CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error
DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error
@@ -193,14 +194,22 @@ func (am *MockAccountManager) CreateSetupKey(
return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented")
}
-// GetAccountIDByUserOrAccountID mock implementation of GetAccountIDByUserOrAccountID from server.AccountManager interface
-func (am *MockAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userId, accountId, domain string) (string, error) {
- if am.GetAccountIDByUserOrAccountIdFunc != nil {
- return am.GetAccountIDByUserOrAccountIdFunc(ctx, userId, accountId, domain)
+// AccountExists mock implementation of AccountExists from server.AccountManager interface
+func (am *MockAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) {
+ if am.AccountExistsFunc != nil {
+ return am.AccountExistsFunc(ctx, accountID)
+ }
+ return false, status.Errorf(codes.Unimplemented, "method AccountExists is not implemented")
+}
+
+// GetAccountIDByUserID mock implementation of GetAccountIDByUserID from server.AccountManager interface
+func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userId, domain string) (string, error) {
+ if am.GetAccountIDByUserIdFunc != nil {
+ return am.GetAccountIDByUserIdFunc(ctx, userId, domain)
}
return "", status.Errorf(
codes.Unimplemented,
- "method GetAccountIDByUserOrAccountID is not implemented",
+ "method GetAccountIDByUserID is not implemented",
)
}
@@ -435,7 +444,7 @@ func (am *MockAccountManager) UpdatePeer(ctx context.Context, accountID, userID
// CreateRoute mock implementation of CreateRoute from server.AccountManager interface
func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupID []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
if am.CreateRouteFunc != nil {
- return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups,accessControlGroupID, enabled, userID, keepRoute)
+ return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, accessControlGroupID, enabled, userID, keepRoute)
}
return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented")
}
diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go
index 300f79c42..d9c359fba 100644
--- a/management/server/nameserver_test.go
+++ b/management/server/nameserver_test.go
@@ -775,7 +775,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
func createNSStore(t *testing.T) (Store, error) {
t.Helper()
dataDir := t.TempDir()
- store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir)
+ store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir)
if err != nil {
return nil, err
}
diff --git a/management/server/peer.go b/management/server/peer.go
index ad4d2658a..a85e8c6b2 100644
--- a/management/server/peer.go
+++ b/management/server/peer.go
@@ -707,6 +707,11 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
updateRemotePeers := false
if login.UserID != "" {
+ if peer.UserID != login.UserID {
+ log.Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, login.UserID)
+ return nil, nil, nil, status.Errorf(status.Unauthenticated, "invalid user")
+ }
+
changed, err := am.handleUserPeer(ctx, peer, settings)
if err != nil {
return nil, nil, nil, err
diff --git a/management/server/peer_test.go b/management/server/peer_test.go
index 0eb782ed0..08592de2e 100644
--- a/management/server/peer_test.go
+++ b/management/server/peer_test.go
@@ -1004,7 +1004,11 @@ func Test_RegisterPeerByUser(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
- store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
+ store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
eventStore := &activity.InMemoryEventStore{}
@@ -1065,7 +1069,11 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
- store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
+ store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
eventStore := &activity.InMemoryEventStore{}
@@ -1127,7 +1135,11 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
- store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
+ store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
eventStore := &activity.InMemoryEventStore{}
diff --git a/management/server/route_test.go b/management/server/route_test.go
index 74bf9c3ec..ca2e99b8a 100644
--- a/management/server/route_test.go
+++ b/management/server/route_test.go
@@ -1258,7 +1258,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
func createRouterStore(t *testing.T) (Store, error) {
t.Helper()
dataDir := t.TempDir()
- store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir)
+ store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir)
if err != nil {
return nil, err
}
@@ -1738,7 +1738,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
}
assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules)
- //peerD is also the routing peer for route1, should contain same routes firewall rules as peerA
+ // peerD is also the routing peer for route1, should contain same routes firewall rules as peerA
routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers)
assert.Len(t, routesFirewallRules, 2)
assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules)
diff --git a/management/server/sql_store.go b/management/server/sql_store.go
index 85c68ef44..fe4dcafdb 100644
--- a/management/server/sql_store.go
+++ b/management/server/sql_store.go
@@ -10,6 +10,7 @@ import (
"path/filepath"
"runtime"
"runtime/debug"
+ "strconv"
"strings"
"sync"
"time"
@@ -63,8 +64,14 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine StoreEngine, metr
if err != nil {
return nil, err
}
- conns := runtime.NumCPU()
- sql.SetMaxOpenConns(conns) // TODO: make it configurable
+
+ conns, err := strconv.Atoi(os.Getenv("NB_SQL_MAX_OPEN_CONNS"))
+ if err != nil {
+ conns = runtime.NumCPU()
+ }
+ sql.SetMaxOpenConns(conns)
+
+ log.Infof("Set max open db connections to %d", conns)
if err := migrate(ctx, db); err != nil {
return nil, fmt.Errorf("migrate: %w", err)
@@ -378,15 +385,26 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error {
Create(&usersToSave).Error
}
-// SaveGroups saves the given list of groups to the database.
-// It updates existing groups if a conflict occurs.
-func (s *SqlStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error {
- groupsToSave := make([]nbgroup.Group, 0, len(groups))
- for _, group := range groups {
- group.AccountID = accountID
- groupsToSave = append(groupsToSave, *group)
+// SaveUser saves the given user to the database.
+func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error {
+ result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user)
+ if result.Error != nil {
+ return status.Errorf(status.Internal, "failed to save user to store: %v", result.Error)
}
- return s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&groupsToSave).Error
+ return nil
+}
+
+// SaveGroups saves the given list of groups to the database.
+func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error {
+ if len(groups) == 0 {
+ return nil
+ }
+
+ result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups)
+ if result.Error != nil {
+ return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error)
+ }
+ return nil
}
// DeleteHashedPAT2TokenIDIndex is noop in SqlStore
@@ -420,7 +438,7 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength
return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
}
log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error)
- return "", status.Errorf(status.Internal, "issue getting account from store")
+ return "", status.NewGetAccountFromStoreError(result.Error)
}
return accountID, nil
@@ -433,7 +451,7 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
- return nil, status.NewSetupKeyNotFoundError()
+ return nil, status.NewSetupKeyNotFoundError(result.Error)
}
if key.AccountID == "" {
@@ -451,7 +469,7 @@ func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken stri
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error)
- return "", status.Errorf(status.Internal, "issue getting account from store")
+ return "", status.NewGetAccountFromStoreError(result.Error)
}
return token.ID, nil
@@ -465,7 +483,7 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error)
- return nil, status.Errorf(status.Internal, "issue getting account from store")
+ return nil, status.NewGetAccountFromStoreError(result.Error)
}
if token.UserID == "" {
@@ -549,7 +567,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID)
}
- return nil, status.Errorf(status.Internal, "issue getting account from store")
+ return nil, status.NewGetAccountFromStoreError(result.Error)
}
// we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us
@@ -612,7 +630,7 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Accoun
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
- return nil, status.Errorf(status.Internal, "issue getting account from store")
+ return nil, status.NewGetAccountFromStoreError(result.Error)
}
if user.AccountID == "" {
@@ -629,7 +647,7 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acco
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
- return nil, status.Errorf(status.Internal, "issue getting account from store")
+ return nil, status.NewGetAccountFromStoreError(result.Error)
}
if peer.AccountID == "" {
@@ -647,7 +665,7 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
- return nil, status.Errorf(status.Internal, "issue getting account from store")
+ return nil, status.NewGetAccountFromStoreError(result.Error)
}
if peer.AccountID == "" {
@@ -665,7 +683,7 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
- return "", status.Errorf(status.Internal, "issue getting account from store")
+ return "", status.NewGetAccountFromStoreError(result.Error)
}
return accountID, nil
@@ -678,7 +696,7 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
- return "", status.Errorf(status.Internal, "issue getting account from store")
+ return "", status.NewGetAccountFromStoreError(result.Error)
}
return accountID, nil
@@ -691,7 +709,7 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
- return "", status.NewSetupKeyNotFoundError()
+ return "", status.NewSetupKeyNotFoundError(result.Error)
}
if accountID == "" {
@@ -712,7 +730,7 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "no peers found for the account")
}
- return nil, status.Errorf(status.Internal, "issue getting IPs from store")
+ return nil, status.Errorf(status.Internal, "issue getting IPs from store: %s", result.Error)
}
// Convert the JSON strings to net.IP objects
@@ -740,7 +758,7 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
return nil, status.Errorf(status.NotFound, "no peers found for the account")
}
log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error)
- return nil, status.Errorf(status.Internal, "issue getting dns labels from store")
+ return nil, status.Errorf(status.Internal, "issue getting dns labels from store: %s", result.Error)
}
return labels, nil
@@ -753,7 +771,7 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID)
}
- return nil, status.Errorf(status.Internal, "issue getting network from store")
+ return nil, status.Errorf(status.Internal, "issue getting network from store: %s", err)
}
return accountNetwork.Network, nil
}
@@ -765,7 +783,7 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "peer not found")
}
- return nil, status.Errorf(status.Internal, "issue getting peer from store")
+ return nil, status.Errorf(status.Internal, "issue getting peer from store: %s", result.Error)
}
return &peer, nil
@@ -777,7 +795,7 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "settings not found")
}
- return nil, status.Errorf(status.Internal, "issue getting settings from store")
+ return nil, status.Errorf(status.Internal, "issue getting settings from store: %s", err)
}
return accountSettings.Settings, nil
}
@@ -915,6 +933,28 @@ func NewPostgresqlStoreFromFileStore(ctx context.Context, fileStore *FileStore,
return store, nil
}
+// NewPostgresqlStoreFromSqlStore restores a store from SqlStore and stores Postgres DB.
+func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) {
+ store, err := NewPostgresqlStore(ctx, dsn, metrics)
+ if err != nil {
+ return nil, err
+ }
+
+ err = store.SaveInstallationID(ctx, sqliteStore.GetInstallationID())
+ if err != nil {
+ return nil, err
+ }
+
+ for _, account := range sqliteStore.GetAllAccounts(ctx) {
+ err := store.SaveAccount(ctx, account)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return store, nil
+}
+
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) {
var setupKey SetupKey
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
@@ -923,7 +963,7 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "setup key not found")
}
- return nil, status.NewSetupKeyNotFoundError()
+ return nil, status.NewSetupKeyNotFoundError(result.Error)
}
return &setupKey, nil
}
@@ -955,7 +995,7 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "group 'All' not found for account")
}
- return status.Errorf(status.Internal, "issue finding group 'All'")
+ return status.Errorf(status.Internal, "issue finding group 'All': %s", result.Error)
}
for _, existingPeerID := range group.Peers {
@@ -967,7 +1007,7 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
group.Peers = append(group.Peers, peerID)
if err := s.db.Save(&group).Error; err != nil {
- return status.Errorf(status.Internal, "issue updating group 'All'")
+ return status.Errorf(status.Internal, "issue updating group 'All': %s", err)
}
return nil
@@ -981,7 +1021,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "group not found for account")
}
- return status.Errorf(status.Internal, "issue finding group")
+ return status.Errorf(status.Internal, "issue finding group: %s", result.Error)
}
for _, existingPeerID := range group.Peers {
@@ -993,15 +1033,20 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId
group.Peers = append(group.Peers, peerId)
if err := s.db.Save(&group).Error; err != nil {
- return status.Errorf(status.Internal, "issue updating group")
+ return status.Errorf(status.Internal, "issue updating group: %s", err)
}
return nil
}
+// GetUserPeers retrieves peers for a user.
+func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) {
+ return getRecords[*nbpeer.Peer](s.db.WithContext(ctx).Where("user_id = ?", userID), lockStrength, accountID)
+}
+
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
- return status.Errorf(status.Internal, "issue adding peer to account")
+ return status.Errorf(status.Internal, "issue adding peer to account: %s", err)
}
return nil
@@ -1010,7 +1055,7 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) erro
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
if result.Error != nil {
- return status.Errorf(status.Internal, "issue incrementing network serial count")
+ return status.Errorf(status.Internal, "issue incrementing network serial count: %s", result.Error)
}
return nil
}
@@ -1105,6 +1150,15 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
return &group, nil
}
+// SaveGroup saves a group to the store.
+func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error {
+ result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group)
+ if result.Error != nil {
+ return status.Errorf(status.Internal, "failed to save group to store: %v", result.Error)
+ }
+ return nil
+}
+
// GetAccountPolicies retrieves policies for an account.
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID)
diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go
index 64ef36831..4eed09c69 100644
--- a/management/server/sql_store_test.go
+++ b/management/server/sql_store_test.go
@@ -7,7 +7,6 @@ import (
"net"
"net/netip"
"os"
- "path/filepath"
"runtime"
"testing"
"time"
@@ -25,7 +24,6 @@ import (
"github.com/netbirdio/netbird/management/server/status"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
- "github.com/netbirdio/netbird/util"
)
func TestSqlite_NewStore(t *testing.T) {
@@ -347,7 +345,11 @@ func TestSqlite_GetAccount(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
- store := newSqliteStoreFromFile(t, "testdata/store.json")
+ store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
id := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -367,7 +369,11 @@ func TestSqlite_SavePeer(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
- store := newSqliteStoreFromFile(t, "testdata/store.json")
+ store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
require.NoError(t, err)
@@ -415,7 +421,11 @@ func TestSqlite_SavePeerStatus(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
- store := newSqliteStoreFromFile(t, "testdata/store.json")
+ store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite")
+ defer cleanup()
+ if err != nil {
+ t.Fatal(err)
+ }
account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
require.NoError(t, err)
@@ -468,8 +478,11 @@ func TestSqlite_SavePeerLocation(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
- store := newSqliteStoreFromFile(t, "testdata/store.json")
-
+ store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite")
+ defer cleanup()
+ if err != nil {
+ t.Fatal(err)
+ }
account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
require.NoError(t, err)
@@ -519,8 +532,11 @@ func TestSqlite_TestGetAccountByPrivateDomain(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
- store := newSqliteStoreFromFile(t, "testdata/store.json")
-
+ store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite")
+ defer cleanup()
+ if err != nil {
+ t.Fatal(err)
+ }
existingDomain := "test.com"
account, err := store.GetAccountByPrivateDomain(context.Background(), existingDomain)
@@ -539,8 +555,11 @@ func TestSqlite_GetTokenIDByHashedToken(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
- store := newSqliteStoreFromFile(t, "testdata/store.json")
-
+ store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite")
+ defer cleanup()
+ if err != nil {
+ t.Fatal(err)
+ }
hashed := "SoMeHaShEdToKeN"
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
@@ -560,8 +579,11 @@ func TestSqlite_GetUserByTokenID(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
- store := newSqliteStoreFromFile(t, "testdata/store.json")
-
+ store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite")
+ defer cleanup()
+ if err != nil {
+ t.Fatal(err)
+ }
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
user, err := store.GetUserByTokenID(context.Background(), id)
@@ -668,24 +690,9 @@ func newSqliteStore(t *testing.T) *SqlStore {
t.Helper()
store, err := NewSqliteStore(context.Background(), t.TempDir(), nil)
- require.NoError(t, err)
- require.NotNil(t, store)
-
- return store
-}
-
-func newSqliteStoreFromFile(t *testing.T, filename string) *SqlStore {
- t.Helper()
-
- storeDir := t.TempDir()
-
- err := util.CopyFileContents(filename, filepath.Join(storeDir, "store.json"))
- require.NoError(t, err)
-
- fStore, err := NewFileStore(context.Background(), storeDir, nil)
- require.NoError(t, err)
-
- store, err := NewSqliteStoreFromFileStore(context.Background(), fStore, storeDir, nil)
+ t.Cleanup(func() {
+ store.Close(context.Background())
+ })
require.NoError(t, err)
require.NotNil(t, store)
@@ -733,32 +740,31 @@ func newPostgresqlStore(t *testing.T) *SqlStore {
return store
}
-func newPostgresqlStoreFromFile(t *testing.T, filename string) *SqlStore {
+func newPostgresqlStoreFromSqlite(t *testing.T, filename string) *SqlStore {
t.Helper()
- storeDir := t.TempDir()
- err := util.CopyFileContents(filename, filepath.Join(storeDir, "store.json"))
- require.NoError(t, err)
+ store, cleanUpQ, err := NewSqliteTestStore(context.Background(), t.TempDir(), filename)
+ t.Cleanup(cleanUpQ)
+ if err != nil {
+ return nil
+ }
- fStore, err := NewFileStore(context.Background(), storeDir, nil)
- require.NoError(t, err)
-
- cleanUp, err := testutil.CreatePGDB()
+ cleanUpP, err := testutil.CreatePGDB()
if err != nil {
t.Fatal(err)
}
- t.Cleanup(cleanUp)
+ t.Cleanup(cleanUpP)
postgresDsn, ok := os.LookupEnv(postgresDsnEnv)
if !ok {
t.Fatalf("could not initialize postgresql store: %s is not set", postgresDsnEnv)
}
- store, err := NewPostgresqlStoreFromFileStore(context.Background(), fStore, postgresDsn, nil)
+ pstore, err := NewPostgresqlStoreFromSqlStore(context.Background(), store, postgresDsn, nil)
require.NoError(t, err)
require.NotNil(t, store)
- return store
+ return pstore
}
func TestPostgresql_NewStore(t *testing.T) {
@@ -924,7 +930,7 @@ func TestPostgresql_SavePeerStatus(t *testing.T) {
t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
}
- store := newPostgresqlStoreFromFile(t, "testdata/store.json")
+ store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite")
account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
require.NoError(t, err)
@@ -963,7 +969,7 @@ func TestPostgresql_TestGetAccountByPrivateDomain(t *testing.T) {
t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
}
- store := newPostgresqlStoreFromFile(t, "testdata/store.json")
+ store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite")
existingDomain := "test.com"
@@ -980,7 +986,7 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) {
t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
}
- store := newPostgresqlStoreFromFile(t, "testdata/store.json")
+ store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite")
hashed := "SoMeHaShEdToKeN"
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
@@ -995,7 +1001,7 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) {
t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
}
- store := newPostgresqlStoreFromFile(t, "testdata/store.json")
+ store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite")
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
@@ -1009,12 +1015,15 @@ func TestSqlite_GetTakenIPs(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
- store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
- defer store.Close(context.Background())
+ store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
+ defer cleanup()
+ if err != nil {
+ t.Fatal(err)
+ }
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
- _, err := store.GetAccount(context.Background(), existingAccountID)
+ _, err = store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
takenIPs, err := store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID)
@@ -1054,12 +1063,15 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
- store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
- defer store.Close(context.Background())
+ store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
+ if err != nil {
+ return
+ }
+ t.Cleanup(cleanup)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
- _, err := store.GetAccount(context.Background(), existingAccountID)
+ _, err = store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
@@ -1096,12 +1108,15 @@ func TestSqlite_GetAccountNetwork(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
- store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
- defer store.Close(context.Background())
+ store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
+ t.Cleanup(cleanup)
+ if err != nil {
+ t.Fatal(err)
+ }
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
- _, err := store.GetAccount(context.Background(), existingAccountID)
+ _, err = store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
network, err := store.GetAccountNetwork(context.Background(), LockingStrengthShare, existingAccountID)
@@ -1118,12 +1133,15 @@ func TestSqlite_GetSetupKeyBySecret(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
- store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
- defer store.Close(context.Background())
+ store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
+ t.Cleanup(cleanup)
+ if err != nil {
+ t.Fatal(err)
+ }
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
- _, err := store.GetAccount(context.Background(), existingAccountID)
+ _, err = store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
@@ -1137,12 +1155,16 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
- store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
- defer store.Close(context.Background())
+
+ store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
+ t.Cleanup(cleanup)
+ if err != nil {
+ t.Fatal(err)
+ }
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
- _, err := store.GetAccount(context.Background(), existingAccountID)
+ _, err = store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
@@ -1163,3 +1185,33 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, 2, setupKey.UsedTimes)
}
+
+func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
+ store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
+ t.Cleanup(cleanup)
+ if err != nil {
+ t.Fatal(err)
+ }
+ group := &nbgroup.Group{
+ ID: "group-id",
+ AccountID: "account-id",
+ Name: "group-name",
+ Issued: "api",
+ Peers: nil,
+ }
+ err = store.ExecuteInTransaction(context.Background(), func(transaction Store) error {
+ err := transaction.SaveGroup(context.Background(), LockingStrengthUpdate, group)
+ if err != nil {
+ t.Fatal("failed to save group")
+ return err
+ }
+ group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.ID, group.AccountID)
+ if err != nil {
+ t.Fatal("failed to get group")
+ return err
+ }
+ t.Logf("group: %v", group)
+ return nil
+ })
+ assert.NoError(t, err)
+}
diff --git a/management/server/status/error.go b/management/server/status/error.go
index d7fde35b9..29d185216 100644
--- a/management/server/status/error.go
+++ b/management/server/status/error.go
@@ -102,8 +102,12 @@ func NewPeerLoginExpiredError() error {
}
// NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key
-func NewSetupKeyNotFoundError() error {
- return Errorf(NotFound, "setup key not found")
+func NewSetupKeyNotFoundError(err error) error {
+ return Errorf(NotFound, "setup key not found: %s", err)
+}
+
+func NewGetAccountFromStoreError(err error) error {
+ return Errorf(Internal, "issue getting account from store: %s", err)
}
// NewGetUserFromStoreError creates a new Error with Internal type for an issue getting user from store
diff --git a/management/server/store.go b/management/server/store.go
index f34a73c2d..50bc6afdf 100644
--- a/management/server/store.go
+++ b/management/server/store.go
@@ -12,10 +12,11 @@ import (
"strings"
"time"
- "github.com/netbirdio/netbird/dns"
log "github.com/sirupsen/logrus"
"gorm.io/gorm"
+ "github.com/netbirdio/netbird/dns"
+
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/telemetry"
@@ -59,6 +60,7 @@ type Store interface {
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
SaveUsers(accountID string, users map[string]*User) error
+ SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
@@ -67,7 +69,8 @@ type Store interface {
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error)
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
- SaveGroups(accountID string, groups map[string]*nbgroup.Group) error
+ SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error
+ SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
@@ -81,6 +84,7 @@ type Store interface {
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
+ GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
@@ -236,23 +240,29 @@ func getMigrations(ctx context.Context) []migrationFunc {
}
}
-// NewTestStoreFromJson is only used in tests
-func NewTestStoreFromJson(ctx context.Context, dataDir string) (Store, func(), error) {
- fstore, err := NewFileStore(ctx, dataDir, nil)
- if err != nil {
- return nil, nil, err
- }
-
+// NewTestStoreFromSqlite is only used in tests
+func NewTestStoreFromSqlite(ctx context.Context, filename string, dataDir string) (Store, func(), error) {
// if store engine is not set in the config we first try to evaluate NETBIRD_STORE_ENGINE
kind := getStoreEngineFromEnv()
if kind == "" {
kind = SqliteStoreEngine
}
- var (
- store Store
- cleanUp func()
- )
+ var store *SqlStore
+ var err error
+ var cleanUp func()
+
+ if filename == "" {
+ store, err = NewSqliteStore(ctx, dataDir, nil)
+ cleanUp = func() {
+ store.Close(ctx)
+ }
+ } else {
+ store, cleanUp, err = NewSqliteTestStore(ctx, dataDir, filename)
+ }
+ if err != nil {
+ return nil, nil, err
+ }
if kind == PostgresStoreEngine {
cleanUp, err = testutil.CreatePGDB()
@@ -265,21 +275,32 @@ func NewTestStoreFromJson(ctx context.Context, dataDir string) (Store, func(), e
return nil, nil, fmt.Errorf("%s is not set", postgresDsnEnv)
}
- store, err = NewPostgresqlStoreFromFileStore(ctx, fstore, dsn, nil)
+ store, err = NewPostgresqlStoreFromSqlStore(ctx, store, dsn, nil)
if err != nil {
return nil, nil, err
}
- } else {
- store, err = NewSqliteStoreFromFileStore(ctx, fstore, dataDir, nil)
- if err != nil {
- return nil, nil, err
- }
- cleanUp = func() { store.Close(ctx) }
}
return store, cleanUp, nil
}
+func NewSqliteTestStore(ctx context.Context, dataDir string, testFile string) (*SqlStore, func(), error) {
+ err := util.CopyFileContents(testFile, filepath.Join(dataDir, "store.db"))
+ if err != nil {
+ return nil, nil, err
+ }
+
+ store, err := NewSqliteStore(ctx, dataDir, nil)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ return store, func() {
+ store.Close(ctx)
+ os.Remove(filepath.Join(dataDir, "store.db"))
+ }, nil
+}
+
// MigrateFileStoreToSqlite migrates the file store to the SQLite store.
func MigrateFileStoreToSqlite(ctx context.Context, dataDir string) error {
fileStorePath := path.Join(dataDir, storeFileName)
diff --git a/management/server/store_test.go b/management/server/store_test.go
index 40c36c9e0..fc821670d 100644
--- a/management/server/store_test.go
+++ b/management/server/store_test.go
@@ -14,12 +14,6 @@ type benchCase struct {
size int
}
-var newFs = func(b *testing.B) Store {
- b.Helper()
- store, _ := NewFileStore(context.Background(), b.TempDir(), nil)
- return store
-}
-
var newSqlite = func(b *testing.B) Store {
b.Helper()
store, _ := NewSqliteStore(context.Background(), b.TempDir(), nil)
@@ -28,13 +22,9 @@ var newSqlite = func(b *testing.B) Store {
func BenchmarkTest_StoreWrite(b *testing.B) {
cases := []benchCase{
- {name: "FileStore_Write", storeFn: newFs, size: 100},
{name: "SqliteStore_Write", storeFn: newSqlite, size: 100},
- {name: "FileStore_Write", storeFn: newFs, size: 500},
{name: "SqliteStore_Write", storeFn: newSqlite, size: 500},
- {name: "FileStore_Write", storeFn: newFs, size: 1000},
{name: "SqliteStore_Write", storeFn: newSqlite, size: 1000},
- {name: "FileStore_Write", storeFn: newFs, size: 2000},
{name: "SqliteStore_Write", storeFn: newSqlite, size: 2000},
}
@@ -61,11 +51,8 @@ func BenchmarkTest_StoreWrite(b *testing.B) {
func BenchmarkTest_StoreRead(b *testing.B) {
cases := []benchCase{
- {name: "FileStore_Read", storeFn: newFs, size: 100},
{name: "SqliteStore_Read", storeFn: newSqlite, size: 100},
- {name: "FileStore_Read", storeFn: newFs, size: 500},
{name: "SqliteStore_Read", storeFn: newSqlite, size: 500},
- {name: "FileStore_Read", storeFn: newFs, size: 1000},
{name: "SqliteStore_Read", storeFn: newSqlite, size: 1000},
}
@@ -89,3 +76,11 @@ func BenchmarkTest_StoreRead(b *testing.B) {
})
}
}
+
+func newStore(t *testing.T) Store {
+ t.Helper()
+
+ store := newSqliteStore(t)
+
+ return store
+}
diff --git a/management/server/testdata/extended-store.json b/management/server/testdata/extended-store.json
deleted file mode 100644
index 7f96e57a8..000000000
--- a/management/server/testdata/extended-store.json
+++ /dev/null
@@ -1,120 +0,0 @@
-{
- "Accounts": {
- "bf1c8084-ba50-4ce7-9439-34653001fc3b": {
- "Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b",
- "CreatedBy": "",
- "Domain": "test.com",
- "DomainCategory": "private",
- "IsDomainPrimaryAccount": true,
- "SetupKeys": {
- "A2C8E62B-38F5-4553-B31E-DD66C696CEBB": {
- "Id": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB",
- "AccountID": "",
- "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB",
- "Name": "Default key",
- "Type": "reusable",
- "CreatedAt": "2021-08-19T20:46:20.005936822+02:00",
- "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00",
- "UpdatedAt": "0001-01-01T00:00:00Z",
- "Revoked": false,
- "UsedTimes": 0,
- "LastUsed": "0001-01-01T00:00:00Z",
- "AutoGroups": ["cfefqs706sqkneg59g2g"],
- "UsageLimit": 0,
- "Ephemeral": false
- },
- "A2C8E62B-38F5-4553-B31E-DD66C696CEBC": {
- "Id": "A2C8E62B-38F5-4553-B31E-DD66C696CEBC",
- "AccountID": "",
- "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBC",
- "Name": "Faulty key with non existing group",
- "Type": "reusable",
- "CreatedAt": "2021-08-19T20:46:20.005936822+02:00",
- "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00",
- "UpdatedAt": "0001-01-01T00:00:00Z",
- "Revoked": false,
- "UsedTimes": 0,
- "LastUsed": "0001-01-01T00:00:00Z",
- "AutoGroups": ["abcd"],
- "UsageLimit": 0,
- "Ephemeral": false
- }
- },
- "Network": {
- "id": "af1c8024-ha40-4ce2-9418-34653101fc3c",
- "Net": {
- "IP": "100.64.0.0",
- "Mask": "//8AAA=="
- },
- "Dns": "",
- "Serial": 0
- },
- "Peers": {},
- "Users": {
- "edafee4e-63fb-11ec-90d6-0242ac120003": {
- "Id": "edafee4e-63fb-11ec-90d6-0242ac120003",
- "AccountID": "",
- "Role": "admin",
- "IsServiceUser": false,
- "ServiceUserName": "",
- "AutoGroups": ["cfefqs706sqkneg59g3g"],
- "PATs": {},
- "Blocked": false,
- "LastLogin": "0001-01-01T00:00:00Z"
- },
- "f4f6d672-63fb-11ec-90d6-0242ac120003": {
- "Id": "f4f6d672-63fb-11ec-90d6-0242ac120003",
- "AccountID": "",
- "Role": "user",
- "IsServiceUser": false,
- "ServiceUserName": "",
- "AutoGroups": null,
- "PATs": {
- "9dj38s35-63fb-11ec-90d6-0242ac120003": {
- "ID": "9dj38s35-63fb-11ec-90d6-0242ac120003",
- "UserID": "",
- "Name": "",
- "HashedToken": "SoMeHaShEdToKeN",
- "ExpirationDate": "2023-02-27T00:00:00Z",
- "CreatedBy": "user",
- "CreatedAt": "2023-01-01T00:00:00Z",
- "LastUsed": "2023-02-01T00:00:00Z"
- }
- },
- "Blocked": false,
- "LastLogin": "0001-01-01T00:00:00Z"
- }
- },
- "Groups": {
- "cfefqs706sqkneg59g4g": {
- "ID": "cfefqs706sqkneg59g4g",
- "Name": "All",
- "Peers": []
- },
- "cfefqs706sqkneg59g3g": {
- "ID": "cfefqs706sqkneg59g3g",
- "Name": "AwesomeGroup1",
- "Peers": []
- },
- "cfefqs706sqkneg59g2g": {
- "ID": "cfefqs706sqkneg59g2g",
- "Name": "AwesomeGroup2",
- "Peers": []
- }
- },
- "Rules": null,
- "Policies": [],
- "Routes": null,
- "NameServerGroups": null,
- "DNSSettings": null,
- "Settings": {
- "PeerLoginExpirationEnabled": false,
- "PeerLoginExpiration": 86400000000000,
- "GroupsPropagationEnabled": false,
- "JWTGroupsEnabled": false,
- "JWTGroupsClaimName": ""
- }
- }
- },
- "InstallationID": ""
-}
diff --git a/management/server/testdata/extended-store.sqlite b/management/server/testdata/extended-store.sqlite
new file mode 100644
index 000000000..81aea8118
Binary files /dev/null and b/management/server/testdata/extended-store.sqlite differ
diff --git a/management/server/testdata/store.json b/management/server/testdata/store.json
deleted file mode 100644
index 6a8fc0712..000000000
--- a/management/server/testdata/store.json
+++ /dev/null
@@ -1,120 +0,0 @@
-{
- "Accounts": {
- "bf1c8084-ba50-4ce7-9439-34653001fc3b": {
- "Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b",
- "CreatedBy": "",
- "Domain": "test.com",
- "DomainCategory": "private",
- "IsDomainPrimaryAccount": true,
- "SetupKeys": {
- "A2C8E62B-38F5-4553-B31E-DD66C696CEBB": {
- "Id": "",
- "AccountID": "",
- "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB",
- "Name": "Default key",
- "Type": "reusable",
- "CreatedAt": "2021-08-19T20:46:20.005936822+02:00",
- "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00",
- "UpdatedAt": "0001-01-01T00:00:00Z",
- "Revoked": false,
- "UsedTimes": 0,
- "LastUsed": "0001-01-01T00:00:00Z",
- "AutoGroups": ["cq9bbkjjuspi5gd38epg"],
- "UsageLimit": 0,
- "Ephemeral": false
- }
- },
- "Network": {
- "id": "af1c8024-ha40-4ce2-9418-34653101fc3c",
- "Net": {
- "IP": "100.64.0.0",
- "Mask": "//8AAA=="
- },
- "Dns": "",
- "Serial": 0
- },
- "Peers": {},
- "Users": {
- "edafee4e-63fb-11ec-90d6-0242ac120003": {
- "Id": "edafee4e-63fb-11ec-90d6-0242ac120003",
- "AccountID": "",
- "Role": "admin",
- "IsServiceUser": false,
- "ServiceUserName": "",
- "AutoGroups": null,
- "PATs": {},
- "Blocked": false,
- "LastLogin": "0001-01-01T00:00:00Z"
- },
- "f4f6d672-63fb-11ec-90d6-0242ac120003": {
- "Id": "f4f6d672-63fb-11ec-90d6-0242ac120003",
- "AccountID": "",
- "Role": "user",
- "IsServiceUser": false,
- "ServiceUserName": "",
- "AutoGroups": null,
- "PATs": {
- "9dj38s35-63fb-11ec-90d6-0242ac120003": {
- "ID": "9dj38s35-63fb-11ec-90d6-0242ac120003",
- "UserID": "",
- "Name": "",
- "HashedToken": "SoMeHaShEdToKeN",
- "ExpirationDate": "2023-02-27T00:00:00Z",
- "CreatedBy": "user",
- "CreatedAt": "2023-01-01T00:00:00Z",
- "LastUsed": "2023-02-01T00:00:00Z"
- }
- },
- "Blocked": false,
- "LastLogin": "0001-01-01T00:00:00Z"
- }
- },
- "Groups": {
- "cq9bbkjjuspi5gd38epg": {
- "ID": "cq9bbkjjuspi5gd38epg",
- "Name": "All",
- "Peers": []
- }
- },
- "Rules": null,
- "Policies": [
- {
- "ID": "cq9bbkjjuspi5gd38eq0",
- "Name": "Default",
- "Description": "This is a default rule that allows connections between all the resources",
- "Enabled": true,
- "Rules": [
- {
- "ID": "cq9bbkjjuspi5gd38eq0",
- "Name": "Default",
- "Description": "This is a default rule that allows connections between all the resources",
- "Enabled": true,
- "Action": "accept",
- "Destinations": [
- "cq9bbkjjuspi5gd38epg"
- ],
- "Sources": [
- "cq9bbkjjuspi5gd38epg"
- ],
- "Bidirectional": true,
- "Protocol": "all",
- "Ports": null
- }
- ],
- "SourcePostureChecks": null
- }
- ],
- "Routes": null,
- "NameServerGroups": null,
- "DNSSettings": null,
- "Settings": {
- "PeerLoginExpirationEnabled": false,
- "PeerLoginExpiration": 86400000000000,
- "GroupsPropagationEnabled": false,
- "JWTGroupsEnabled": false,
- "JWTGroupsClaimName": ""
- }
- }
- },
- "InstallationID": ""
-}
\ No newline at end of file
diff --git a/management/server/testdata/store.sqlite b/management/server/testdata/store.sqlite
new file mode 100644
index 000000000..5fc746285
Binary files /dev/null and b/management/server/testdata/store.sqlite differ
diff --git a/management/server/testdata/store_policy_migrate.json b/management/server/testdata/store_policy_migrate.json
deleted file mode 100644
index 1b046e632..000000000
--- a/management/server/testdata/store_policy_migrate.json
+++ /dev/null
@@ -1,116 +0,0 @@
-{
- "Accounts": {
- "bf1c8084-ba50-4ce7-9439-34653001fc3b": {
- "Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b",
- "Domain": "test.com",
- "DomainCategory": "private",
- "IsDomainPrimaryAccount": true,
- "SetupKeys": {
- "A2C8E62B-38F5-4553-B31E-DD66C696CEBB": {
- "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB",
- "Name": "Default key",
- "Type": "reusable",
- "CreatedAt": "2021-08-19T20:46:20.005936822+02:00",
- "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00",
- "Revoked": false,
- "UsedTimes": 0
- }
- },
- "Network": {
- "Id": "af1c8024-ha40-4ce2-9418-34653101fc3c",
- "Net": {
- "IP": "100.64.0.0",
- "Mask": "//8AAA=="
- },
- "Dns": null
- },
- "Peers": {
- "cfefqs706sqkneg59g4g": {
- "ID": "cfefqs706sqkneg59g4g",
- "Key": "MI5mHfJhbggPfD3FqEIsXm8X5bSWeUI2LhO9MpEEtWA=",
- "SetupKey": "",
- "IP": "100.103.179.238",
- "Meta": {
- "Hostname": "Ubuntu-2204-jammy-amd64-base",
- "GoOS": "linux",
- "Kernel": "Linux",
- "Core": "22.04",
- "Platform": "x86_64",
- "OS": "Ubuntu",
- "WtVersion": "development",
- "UIVersion": ""
- },
- "Name": "crocodile",
- "DNSLabel": "crocodile",
- "Status": {
- "LastSeen": "2023-02-13T12:37:12.635454796Z",
- "Connected": true
- },
- "UserID": "edafee4e-63fb-11ec-90d6-0242ac120003",
- "SSHKey": "AAAAC3NzaC1lZDI1NTE5AAAAIJN1NM4bpB9K",
- "SSHEnabled": false
- },
- "cfeg6sf06sqkneg59g50": {
- "ID": "cfeg6sf06sqkneg59g50",
- "Key": "zMAOKUeIYIuun4n0xPR1b3IdYZPmsyjYmB2jWCuloC4=",
- "SetupKey": "",
- "IP": "100.103.26.180",
- "Meta": {
- "Hostname": "borg",
- "GoOS": "linux",
- "Kernel": "Linux",
- "Core": "22.04",
- "Platform": "x86_64",
- "OS": "Ubuntu",
- "WtVersion": "development",
- "UIVersion": ""
- },
- "Name": "dingo",
- "DNSLabel": "dingo",
- "Status": {
- "LastSeen": "2023-02-21T09:37:42.565899199Z",
- "Connected": false
- },
- "UserID": "f4f6d672-63fb-11ec-90d6-0242ac120003",
- "SSHKey": "AAAAC3NzaC1lZDI1NTE5AAAAILHW",
- "SSHEnabled": true
- }
- },
- "Groups": {
- "cfefqs706sqkneg59g3g": {
- "ID": "cfefqs706sqkneg59g3g",
- "Name": "All",
- "Peers": [
- "cfefqs706sqkneg59g4g",
- "cfeg6sf06sqkneg59g50"
- ]
- }
- },
- "Rules": {
- "cfefqs706sqkneg59g40": {
- "ID": "cfefqs706sqkneg59g40",
- "Name": "Default",
- "Description": "This is a default rule that allows connections between all the resources",
- "Disabled": false,
- "Source": [
- "cfefqs706sqkneg59g3g"
- ],
- "Destination": [
- "cfefqs706sqkneg59g3g"
- ],
- "Flow": 0
- }
- },
- "Users": {
- "edafee4e-63fb-11ec-90d6-0242ac120003": {
- "Id": "edafee4e-63fb-11ec-90d6-0242ac120003",
- "Role": "admin"
- },
- "f4f6d672-63fb-11ec-90d6-0242ac120003": {
- "Id": "f4f6d672-63fb-11ec-90d6-0242ac120003",
- "Role": "user"
- }
- }
- }
- }
-}
diff --git a/management/server/testdata/store_policy_migrate.sqlite b/management/server/testdata/store_policy_migrate.sqlite
new file mode 100644
index 000000000..0c1a491a6
Binary files /dev/null and b/management/server/testdata/store_policy_migrate.sqlite differ
diff --git a/management/server/testdata/store_with_expired_peers.json b/management/server/testdata/store_with_expired_peers.json
deleted file mode 100644
index 44c225682..000000000
--- a/management/server/testdata/store_with_expired_peers.json
+++ /dev/null
@@ -1,130 +0,0 @@
-{
- "Accounts": {
- "bf1c8084-ba50-4ce7-9439-34653001fc3b": {
- "Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b",
- "Domain": "test.com",
- "DomainCategory": "private",
- "IsDomainPrimaryAccount": true,
- "Settings": {
- "PeerLoginExpirationEnabled": true,
- "PeerLoginExpiration": 3600000000000
- },
- "SetupKeys": {
- "A2C8E62B-38F5-4553-B31E-DD66C696CEBB": {
- "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB",
- "Name": "Default key",
- "Type": "reusable",
- "CreatedAt": "2021-08-19T20:46:20.005936822+02:00",
- "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00",
- "Revoked": false,
- "UsedTimes": 0
-
- }
- },
- "Network": {
- "Id": "af1c8024-ha40-4ce2-9418-34653101fc3c",
- "Net": {
- "IP": "100.64.0.0",
- "Mask": "//8AAA=="
- },
- "Dns": null
- },
- "Peers": {
- "cfvprsrlo1hqoo49ohog": {
- "ID": "cfvprsrlo1hqoo49ohog",
- "Key": "5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=",
- "SetupKey": "72546A29-6BC8-4311-BCFC-9CDBF33F1A48",
- "IP": "100.64.114.31",
- "Meta": {
- "Hostname": "f2a34f6a4731",
- "GoOS": "linux",
- "Kernel": "Linux",
- "Core": "11",
- "Platform": "unknown",
- "OS": "Debian GNU/Linux",
- "WtVersion": "0.12.0",
- "UIVersion": ""
- },
- "Name": "f2a34f6a4731",
- "DNSLabel": "f2a34f6a4731",
- "Status": {
- "LastSeen": "2023-03-02T09:21:02.189035775+01:00",
- "Connected": false,
- "LoginExpired": false
- },
- "UserID": "",
- "SSHKey": "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk",
- "SSHEnabled": false,
- "LoginExpirationEnabled": true,
- "LastLogin": "2023-03-01T19:48:19.817799698+01:00"
- },
- "cg05lnblo1hkg2j514p0": {
- "ID": "cg05lnblo1hkg2j514p0",
- "Key": "RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=",
- "SetupKey": "",
- "IP": "100.64.39.54",
- "Meta": {
- "Hostname": "expiredhost",
- "GoOS": "linux",
- "Kernel": "Linux",
- "Core": "22.04",
- "Platform": "x86_64",
- "OS": "Ubuntu",
- "WtVersion": "development",
- "UIVersion": ""
- },
- "Name": "expiredhost",
- "DNSLabel": "expiredhost",
- "Status": {
- "LastSeen": "2023-03-02T09:19:57.276717255+01:00",
- "Connected": false,
- "LoginExpired": true
- },
- "UserID": "edafee4e-63fb-11ec-90d6-0242ac120003",
- "SSHKey": "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK",
- "SSHEnabled": false,
- "LoginExpirationEnabled": true,
- "LastLogin": "2023-03-02T09:14:21.791679181+01:00"
- },
- "cg3161rlo1hs9cq94gdg": {
- "ID": "cg3161rlo1hs9cq94gdg",
- "Key": "mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=",
- "SetupKey": "",
- "IP": "100.64.117.96",
- "Meta": {
- "Hostname": "testhost",
- "GoOS": "linux",
- "Kernel": "Linux",
- "Core": "22.04",
- "Platform": "x86_64",
- "OS": "Ubuntu",
- "WtVersion": "development",
- "UIVersion": ""
- },
- "Name": "testhost",
- "DNSLabel": "testhost",
- "Status": {
- "LastSeen": "2023-03-06T18:21:27.252010027+01:00",
- "Connected": false,
- "LoginExpired": false
- },
- "UserID": "edafee4e-63fb-11ec-90d6-0242ac120003",
- "SSHKey": "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM",
- "SSHEnabled": false,
- "LoginExpirationEnabled": false,
- "LastLogin": "2023-03-07T09:02:47.442857106+01:00"
- }
- },
- "Users": {
- "edafee4e-63fb-11ec-90d6-0242ac120003": {
- "Id": "edafee4e-63fb-11ec-90d6-0242ac120003",
- "Role": "admin"
- },
- "f4f6d672-63fb-11ec-90d6-0242ac120003": {
- "Id": "f4f6d672-63fb-11ec-90d6-0242ac120003",
- "Role": "user"
- }
- }
- }
- }
-}
\ No newline at end of file
diff --git a/management/server/testdata/store_with_expired_peers.sqlite b/management/server/testdata/store_with_expired_peers.sqlite
new file mode 100644
index 000000000..ed1133211
Binary files /dev/null and b/management/server/testdata/store_with_expired_peers.sqlite differ
diff --git a/management/server/testdata/storev1.json b/management/server/testdata/storev1.json
deleted file mode 100644
index 674b2b87a..000000000
--- a/management/server/testdata/storev1.json
+++ /dev/null
@@ -1,154 +0,0 @@
-{
- "Accounts": {
- "auth0|61bf82ddeab084006aa1bccd": {
- "Id": "auth0|61bf82ddeab084006aa1bccd",
- "SetupKeys": {
- "1B2B50B0-B3E8-4B0C-A426-525EDB8481BD": {
- "Id": "831727121",
- "Key": "1B2B50B0-B3E8-4B0C-A426-525EDB8481BD",
- "Name": "One-off key",
- "Type": "one-off",
- "CreatedAt": "2021-12-24T16:09:45.926075752+01:00",
- "ExpiresAt": "2022-01-23T16:09:45.926075752+01:00",
- "Revoked": false,
- "UsedTimes": 1,
- "LastUsed": "2021-12-24T16:12:45.763424077+01:00"
- },
- "EB51E9EB-A11F-4F6E-8E49-C982891B405A": {
- "Id": "1769568301",
- "Key": "EB51E9EB-A11F-4F6E-8E49-C982891B405A",
- "Name": "Default key",
- "Type": "reusable",
- "CreatedAt": "2021-12-24T16:09:45.926073628+01:00",
- "ExpiresAt": "2022-01-23T16:09:45.926073628+01:00",
- "Revoked": false,
- "UsedTimes": 1,
- "LastUsed": "2021-12-24T16:13:06.236748538+01:00"
- }
- },
- "Network": {
- "Id": "a443c07a-5765-4a78-97fc-390d9c1d0e49",
- "Net": {
- "IP": "100.64.0.0",
- "Mask": "/8AAAA=="
- },
- "Dns": ""
- },
- "Peers": {
- "oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=": {
- "Key": "oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=",
- "SetupKey": "EB51E9EB-A11F-4F6E-8E49-C982891B405A",
- "IP": "100.64.0.2",
- "Meta": {
- "Hostname": "braginini",
- "GoOS": "linux",
- "Kernel": "Linux",
- "Core": "21.04",
- "Platform": "x86_64",
- "OS": "Ubuntu",
- "WtVersion": ""
- },
- "Name": "braginini",
- "Status": {
- "LastSeen": "2021-12-24T16:13:11.244342541+01:00",
- "Connected": false
- }
- },
- "xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=": {
- "Key": "xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=",
- "SetupKey": "1B2B50B0-B3E8-4B0C-A426-525EDB8481BD",
- "IP": "100.64.0.1",
- "Meta": {
- "Hostname": "braginini",
- "GoOS": "linux",
- "Kernel": "Linux",
- "Core": "21.04",
- "Platform": "x86_64",
- "OS": "Ubuntu",
- "WtVersion": ""
- },
- "Name": "braginini",
- "Status": {
- "LastSeen": "2021-12-24T16:12:49.089339333+01:00",
- "Connected": false
- }
- }
- }
- },
- "google-oauth2|103201118415301331038": {
- "Id": "google-oauth2|103201118415301331038",
- "SetupKeys": {
- "5AFB60DB-61F2-4251-8E11-494847EE88E9": {
- "Id": "2485964613",
- "Key": "5AFB60DB-61F2-4251-8E11-494847EE88E9",
- "Name": "Default key",
- "Type": "reusable",
- "CreatedAt": "2021-12-24T16:10:02.238476+01:00",
- "ExpiresAt": "2022-01-23T16:10:02.238476+01:00",
- "Revoked": false,
- "UsedTimes": 1,
- "LastUsed": "2021-12-24T16:12:05.994307717+01:00"
- },
- "A72E4DC2-00DE-4542-8A24-62945438104E": {
- "Id": "3504804807",
- "Key": "A72E4DC2-00DE-4542-8A24-62945438104E",
- "Name": "One-off key",
- "Type": "one-off",
- "CreatedAt": "2021-12-24T16:10:02.238478209+01:00",
- "ExpiresAt": "2022-01-23T16:10:02.238478209+01:00",
- "Revoked": false,
- "UsedTimes": 1,
- "LastUsed": "2021-12-24T16:11:27.015741738+01:00"
- }
- },
- "Network": {
- "Id": "b6d0b152-364e-40c1-a8a1-fa7bcac2267f",
- "Net": {
- "IP": "100.64.0.0",
- "Mask": "/8AAAA=="
- },
- "Dns": ""
- },
- "Peers": {
- "6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=": {
- "Key": "6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=",
- "SetupKey": "5AFB60DB-61F2-4251-8E11-494847EE88E9",
- "IP": "100.64.0.2",
- "Meta": {
- "Hostname": "braginini",
- "GoOS": "linux",
- "Kernel": "Linux",
- "Core": "21.04",
- "Platform": "x86_64",
- "OS": "Ubuntu",
- "WtVersion": ""
- },
- "Name": "braginini",
- "Status": {
- "LastSeen": "2021-12-24T16:12:05.994305438+01:00",
- "Connected": false
- }
- },
- "Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=": {
- "Key": "Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=",
- "SetupKey": "A72E4DC2-00DE-4542-8A24-62945438104E",
- "IP": "100.64.0.1",
- "Meta": {
- "Hostname": "braginini",
- "GoOS": "linux",
- "Kernel": "Linux",
- "Core": "21.04",
- "Platform": "x86_64",
- "OS": "Ubuntu",
- "WtVersion": ""
- },
- "Name": "braginini",
- "Status": {
- "LastSeen": "2021-12-24T16:11:27.015739803+01:00",
- "Connected": false
- }
- }
- }
- }
- }
-}
\ No newline at end of file
diff --git a/management/server/testdata/storev1.sqlite b/management/server/testdata/storev1.sqlite
new file mode 100644
index 000000000..9a376698e
Binary files /dev/null and b/management/server/testdata/storev1.sqlite differ
diff --git a/management/server/user.go b/management/server/user.go
index 7acb0b487..e40fc67eb 100644
--- a/management/server/user.go
+++ b/management/server/user.go
@@ -9,14 +9,14 @@ import (
"time"
"github.com/google/uuid"
- log "github.com/sirupsen/logrus"
-
"github.com/netbirdio/netbird/management/server/activity"
+ nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/integration_reference"
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
+ log "github.com/sirupsen/logrus"
)
const (
@@ -1274,6 +1274,74 @@ func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, accoun
return map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}, hadPeers, nil
}
+// updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them.
+func (am *DefaultAccountManager) updateUserPeersInGroups(accountGroups map[string]*nbgroup.Group, peers []*nbpeer.Peer, groupsToAdd,
+ groupsToRemove []string) (groupsToUpdate []*nbgroup.Group, err error) {
+
+ if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 {
+ return
+ }
+
+ userPeerIDMap := make(map[string]struct{}, len(peers))
+ for _, peer := range peers {
+ userPeerIDMap[peer.ID] = struct{}{}
+ }
+
+ for _, gid := range groupsToAdd {
+ group, ok := accountGroups[gid]
+ if !ok {
+ return nil, errors.New("group not found")
+ }
+ addUserPeersToGroup(userPeerIDMap, group)
+ groupsToUpdate = append(groupsToUpdate, group)
+ }
+
+ for _, gid := range groupsToRemove {
+ group, ok := accountGroups[gid]
+ if !ok {
+ return nil, errors.New("group not found")
+ }
+ removeUserPeersFromGroup(userPeerIDMap, group)
+ groupsToUpdate = append(groupsToUpdate, group)
+ }
+
+ return groupsToUpdate, nil
+}
+
+// addUserPeersToGroup adds the user's peers to the group.
+func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) {
+ groupPeers := make(map[string]struct{}, len(group.Peers))
+ for _, pid := range group.Peers {
+ groupPeers[pid] = struct{}{}
+ }
+
+ for pid := range userPeerIDs {
+ groupPeers[pid] = struct{}{}
+ }
+
+ group.Peers = make([]string, 0, len(groupPeers))
+ for pid := range groupPeers {
+ group.Peers = append(group.Peers, pid)
+ }
+}
+
+// removeUserPeersFromGroup removes user's peers from the group.
+func removeUserPeersFromGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) {
+ // skip removing peers from group All
+ if group.Name == "All" {
+ return
+ }
+
+ updatedPeers := make([]string, 0, len(group.Peers))
+ for _, pid := range group.Peers {
+ if _, found := userPeerIDs[pid]; !found {
+ updatedPeers = append(updatedPeers, pid)
+ }
+ }
+
+ group.Peers = updatedPeers
+}
+
func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) {
for _, user := range userData {
if user.ID == userID {
diff --git a/management/server/user_test.go b/management/server/user_test.go
index c836ac98f..3f7f814a0 100644
--- a/management/server/user_test.go
+++ b/management/server/user_test.go
@@ -62,8 +62,10 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
assert.Equal(t, pat.CreatedBy, mockUserID)
- fileStore := am.Store.(*FileStore)
- tokenID := fileStore.HashedPAT2TokenID[pat.HashedToken]
+ tokenID, err := am.Store.GetTokenIDByHashedToken(context.Background(), pat.HashedToken)
+ if err != nil {
+ t.Fatalf("Error when getting token ID by hashed token: %s", err)
+ }
if tokenID == "" {
t.Fatal("GetTokenIDByHashedToken failed after adding PAT")
@@ -71,11 +73,12 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
assert.Equal(t, pat.ID, tokenID)
- userID := fileStore.TokenID2UserID[tokenID]
- if userID == "" {
- t.Fatal("GetUserByTokenId failed after adding PAT")
+ user, err := am.Store.GetUserByTokenID(context.Background(), tokenID)
+ if err != nil {
+ t.Fatalf("Error when getting user by token ID: %s", err)
}
- assert.Equal(t, mockUserID, userID)
+
+ assert.Equal(t, mockUserID, user.Id)
}
func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
@@ -192,9 +195,12 @@ func TestUser_DeletePAT(t *testing.T) {
t.Fatalf("Error when adding PAT to user: %s", err)
}
- assert.Nil(t, store.Accounts[mockAccountID].Users[mockUserID].PATs[mockTokenID1])
- assert.Empty(t, store.HashedPAT2TokenID[mockToken1])
- assert.Empty(t, store.TokenID2UserID[mockTokenID1])
+ account, err = store.GetAccount(context.Background(), mockAccountID)
+ if err != nil {
+ t.Fatalf("Error when getting account: %s", err)
+ }
+
+ assert.Nil(t, account.Users[mockUserID].PATs[mockTokenID1])
}
func TestUser_GetPAT(t *testing.T) {
@@ -353,13 +359,16 @@ func TestUser_CreateServiceUser(t *testing.T) {
t.Fatalf("Error when creating service user: %s", err)
}
- assert.Equal(t, 2, len(store.Accounts[mockAccountID].Users))
- assert.NotNil(t, store.Accounts[mockAccountID].Users[user.ID])
- assert.True(t, store.Accounts[mockAccountID].Users[user.ID].IsServiceUser)
- assert.Equal(t, mockServiceUserName, store.Accounts[mockAccountID].Users[user.ID].ServiceUserName)
- assert.Equal(t, UserRole(mockRole), store.Accounts[mockAccountID].Users[user.ID].Role)
- assert.Equal(t, []string{"group1", "group2"}, store.Accounts[mockAccountID].Users[user.ID].AutoGroups)
- assert.Equal(t, map[string]*PersonalAccessToken{}, store.Accounts[mockAccountID].Users[user.ID].PATs)
+ account, err = store.GetAccount(context.Background(), mockAccountID)
+ assert.NoError(t, err)
+
+ assert.Equal(t, 2, len(account.Users))
+ assert.NotNil(t, account.Users[user.ID])
+ assert.True(t, account.Users[user.ID].IsServiceUser)
+ assert.Equal(t, mockServiceUserName, account.Users[user.ID].ServiceUserName)
+ assert.Equal(t, UserRole(mockRole), account.Users[user.ID].Role)
+ assert.Equal(t, []string{"group1", "group2"}, account.Users[user.ID].AutoGroups)
+ assert.Equal(t, map[string]*PersonalAccessToken{}, account.Users[user.ID].PATs)
assert.Zero(t, user.Email)
assert.True(t, user.IsServiceUser)
@@ -397,12 +406,15 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) {
t.Fatalf("Error when creating user: %s", err)
}
+ account, err = store.GetAccount(context.Background(), mockAccountID)
+ assert.NoError(t, err)
+
assert.True(t, user.IsServiceUser)
- assert.Equal(t, 2, len(store.Accounts[mockAccountID].Users))
- assert.True(t, store.Accounts[mockAccountID].Users[user.ID].IsServiceUser)
- assert.Equal(t, mockServiceUserName, store.Accounts[mockAccountID].Users[user.ID].ServiceUserName)
- assert.Equal(t, UserRole(mockRole), store.Accounts[mockAccountID].Users[user.ID].Role)
- assert.Equal(t, []string{"group1", "group2"}, store.Accounts[mockAccountID].Users[user.ID].AutoGroups)
+ assert.Equal(t, 2, len(account.Users))
+ assert.True(t, account.Users[user.ID].IsServiceUser)
+ assert.Equal(t, mockServiceUserName, account.Users[user.ID].ServiceUserName)
+ assert.Equal(t, UserRole(mockRole), account.Users[user.ID].Role)
+ assert.Equal(t, []string{"group1", "group2"}, account.Users[user.ID].AutoGroups)
assert.Equal(t, mockServiceUserName, user.Name)
assert.Equal(t, mockRole, user.Role)
@@ -553,12 +565,15 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) {
err = am.DeleteUser(context.Background(), mockAccountID, mockUserID, mockServiceUserID)
tt.assertErrFunc(t, err, tt.assertErrMessage)
+ account, err2 := store.GetAccount(context.Background(), mockAccountID)
+ assert.NoError(t, err2)
+
if err != nil {
- assert.Equal(t, 2, len(store.Accounts[mockAccountID].Users))
- assert.NotNil(t, store.Accounts[mockAccountID].Users[mockServiceUserID])
+ assert.Equal(t, 2, len(account.Users))
+ assert.NotNil(t, account.Users[mockServiceUserID])
} else {
- assert.Equal(t, 1, len(store.Accounts[mockAccountID].Users))
- assert.Nil(t, store.Accounts[mockAccountID].Users[mockServiceUserID])
+ assert.Equal(t, 1, len(account.Users))
+ assert.Nil(t, account.Users[mockServiceUserID])
}
})
}
@@ -801,10 +816,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
assert.NoError(t, err)
}
- accID, err := am.GetAccountIDByUserOrAccountID(context.Background(), "", account.Id, "")
- assert.NoError(t, err)
-
- acc, err := am.Store.GetAccount(context.Background(), accID)
+ acc, err := am.Store.GetAccount(context.Background(), account.Id)
assert.NoError(t, err)
for _, id := range tc.expectedDeleted {
diff --git a/util/file.go b/util/file.go
index 8355488c9..ecaecd222 100644
--- a/util/file.go
+++ b/util/file.go
@@ -1,11 +1,15 @@
package util
import (
+ "bytes"
"context"
"encoding/json"
+ "fmt"
"io"
"os"
"path/filepath"
+ "strings"
+ "text/template"
log "github.com/sirupsen/logrus"
)
@@ -160,6 +164,55 @@ func ReadJson(file string, res interface{}) (interface{}, error) {
return res, nil
}
+// ReadJsonWithEnvSub reads JSON config file and maps to a provided interface with environment variable substitution
+func ReadJsonWithEnvSub(file string, res interface{}) (interface{}, error) {
+ envVars := getEnvMap()
+
+ f, err := os.Open(file)
+ if err != nil {
+ return nil, err
+ }
+ defer f.Close()
+
+ bs, err := io.ReadAll(f)
+ if err != nil {
+ return nil, err
+ }
+
+ t, err := template.New("").Parse(string(bs))
+ if err != nil {
+ return nil, fmt.Errorf("error parsing template: %v", err)
+ }
+
+ var output bytes.Buffer
+ // Execute the template, substituting environment variables
+ err = t.Execute(&output, envVars)
+ if err != nil {
+ return nil, fmt.Errorf("error executing template: %v", err)
+ }
+
+ err = json.Unmarshal(output.Bytes(), &res)
+ if err != nil {
+ return nil, fmt.Errorf("failed parsing Json file after template was executed, err: %v", err)
+ }
+
+ return res, nil
+}
+
+// getEnvMap Convert the output of os.Environ() to a map
+func getEnvMap() map[string]string {
+ envMap := make(map[string]string)
+
+ for _, env := range os.Environ() {
+ parts := strings.SplitN(env, "=", 2)
+ if len(parts) == 2 {
+ envMap[parts[0]] = parts[1]
+ }
+ }
+
+ return envMap
+}
+
// CopyFileContents copies contents of the given src file to the dst file
func CopyFileContents(src, dst string) (err error) {
in, err := os.Open(src)
diff --git a/util/file_suite_test.go b/util/file_suite_test.go
new file mode 100644
index 000000000..3de7db49b
--- /dev/null
+++ b/util/file_suite_test.go
@@ -0,0 +1,126 @@
+package util_test
+
+import (
+ "crypto/md5"
+ "encoding/hex"
+ "io"
+ "os"
+
+ . "github.com/onsi/ginkgo"
+ . "github.com/onsi/gomega"
+
+ "github.com/netbirdio/netbird/util"
+)
+
+var _ = Describe("Client", func() {
+
+ var (
+ tmpDir string
+ )
+
+ type TestConfig struct {
+ SomeMap map[string]string
+ SomeArray []string
+ SomeField int
+ }
+
+ BeforeEach(func() {
+ var err error
+ tmpDir, err = os.MkdirTemp("", "wiretrustee_util_test_tmp_*")
+ Expect(err).NotTo(HaveOccurred())
+ })
+
+ AfterEach(func() {
+ err := os.RemoveAll(tmpDir)
+ Expect(err).NotTo(HaveOccurred())
+ })
+
+ Describe("Config", func() {
+ Context("in JSON format", func() {
+ It("should be written and read successfully", func() {
+
+ m := make(map[string]string)
+ m["key1"] = "value1"
+ m["key2"] = "value2"
+
+ arr := []string{"value1", "value2"}
+
+ written := &TestConfig{
+ SomeMap: m,
+ SomeArray: arr,
+ SomeField: 99,
+ }
+
+ err := util.WriteJson(tmpDir+"/testconfig.json", written)
+ Expect(err).NotTo(HaveOccurred())
+
+ read, err := util.ReadJson(tmpDir+"/testconfig.json", &TestConfig{})
+ Expect(err).NotTo(HaveOccurred())
+ Expect(read).NotTo(BeNil())
+ Expect(read.(*TestConfig).SomeMap["key1"]).To(BeEquivalentTo(written.SomeMap["key1"]))
+ Expect(read.(*TestConfig).SomeMap["key2"]).To(BeEquivalentTo(written.SomeMap["key2"]))
+ Expect(read.(*TestConfig).SomeArray).To(ContainElements(arr))
+ Expect(read.(*TestConfig).SomeField).To(BeEquivalentTo(written.SomeField))
+
+ })
+ })
+ })
+
+ Describe("Copying file contents", func() {
+ Context("from one file to another", func() {
+ It("should be successful", func() {
+
+ src := tmpDir + "/copytest_src"
+ dst := tmpDir + "/copytest_dst"
+
+ err := util.WriteJson(src, []string{"1", "2", "3"})
+ Expect(err).NotTo(HaveOccurred())
+
+ err = util.CopyFileContents(src, dst)
+ Expect(err).NotTo(HaveOccurred())
+
+ hashSrc := md5.New()
+ hashDst := md5.New()
+
+ srcFile, err := os.Open(src)
+ Expect(err).NotTo(HaveOccurred())
+
+ dstFile, err := os.Open(dst)
+ Expect(err).NotTo(HaveOccurred())
+
+ _, err = io.Copy(hashSrc, srcFile)
+ Expect(err).NotTo(HaveOccurred())
+
+ _, err = io.Copy(hashDst, dstFile)
+ Expect(err).NotTo(HaveOccurred())
+
+ err = srcFile.Close()
+ Expect(err).NotTo(HaveOccurred())
+
+ err = dstFile.Close()
+ Expect(err).NotTo(HaveOccurred())
+
+ Expect(hex.EncodeToString(hashSrc.Sum(nil)[:16])).To(BeEquivalentTo(hex.EncodeToString(hashDst.Sum(nil)[:16])))
+ })
+ })
+ })
+
+ Describe("Handle config file without full path", func() {
+ Context("config file handling", func() {
+ It("should be successful", func() {
+ written := &TestConfig{
+ SomeField: 123,
+ }
+ cfgFile := "test_cfg.json"
+ defer os.Remove(cfgFile)
+
+ err := util.WriteJson(cfgFile, written)
+ Expect(err).NotTo(HaveOccurred())
+
+ read, err := util.ReadJson(cfgFile, &TestConfig{})
+ Expect(err).NotTo(HaveOccurred())
+ Expect(read).NotTo(BeNil())
+ })
+ })
+ })
+})
diff --git a/util/file_test.go b/util/file_test.go
index 3de7db49b..1330e738e 100644
--- a/util/file_test.go
+++ b/util/file_test.go
@@ -1,126 +1,198 @@
-package util_test
+package util
import (
- "crypto/md5"
- "encoding/hex"
- "io"
"os"
-
- . "github.com/onsi/ginkgo"
- . "github.com/onsi/gomega"
-
- "github.com/netbirdio/netbird/util"
+ "reflect"
+ "strings"
+ "testing"
)
-var _ = Describe("Client", func() {
-
- var (
- tmpDir string
- )
-
- type TestConfig struct {
- SomeMap map[string]string
- SomeArray []string
- SomeField int
+func TestReadJsonWithEnvSub(t *testing.T) {
+ type Config struct {
+ CertFile string `json:"CertFile"`
+ Credentials string `json:"Credentials"`
+ NestedOption struct {
+ URL string `json:"URL"`
+ } `json:"NestedOption"`
}
- BeforeEach(func() {
- var err error
- tmpDir, err = os.MkdirTemp("", "wiretrustee_util_test_tmp_*")
- Expect(err).NotTo(HaveOccurred())
- })
+ type testCase struct {
+ name string
+ envVars map[string]string
+ jsonTemplate string
+ expectedResult Config
+ expectError bool
+ errorContains string
+ }
- AfterEach(func() {
- err := os.RemoveAll(tmpDir)
- Expect(err).NotTo(HaveOccurred())
- })
+ tests := []testCase{
+ {
+ name: "All environment variables set",
+ envVars: map[string]string{
+ "CERT_FILE": "/etc/certs/env_cert.crt",
+ "CREDENTIALS": "env_credentials",
+ "URL": "https://env.testing.com",
+ },
+ jsonTemplate: `{
+ "CertFile": "{{ .CERT_FILE }}",
+ "Credentials": "{{ .CREDENTIALS }}",
+ "NestedOption": {
+ "URL": "{{ .URL }}"
+ }
+ }`,
+ expectedResult: Config{
+ CertFile: "/etc/certs/env_cert.crt",
+ Credentials: "env_credentials",
+ NestedOption: struct {
+ URL string `json:"URL"`
+ }{
+ URL: "https://env.testing.com",
+ },
+ },
+ expectError: false,
+ },
+ {
+ name: "Missing environment variable",
+ envVars: map[string]string{
+ "CERT_FILE": "/etc/certs/env_cert.crt",
+ "CREDENTIALS": "env_credentials",
+ // "URL" is intentionally missing
+ },
+ jsonTemplate: `{
+ "CertFile": "{{ .CERT_FILE }}",
+ "Credentials": "{{ .CREDENTIALS }}",
+ "NestedOption": {
+ "URL": "{{ .URL }}"
+ }
+ }`,
+ expectedResult: Config{
+ CertFile: "/etc/certs/env_cert.crt",
+ Credentials: "env_credentials",
+ NestedOption: struct {
+ URL string `json:"URL"`
+ }{
+ URL: "",
+ },
+ },
+ expectError: false,
+ },
+ {
+ name: "Invalid JSON template",
+ envVars: map[string]string{
+ "CERT_FILE": "/etc/certs/env_cert.crt",
+ "CREDENTIALS": "env_credentials",
+ "URL": "https://env.testing.com",
+ },
+ jsonTemplate: `{
+ "CertFile": "{{ .CERT_FILE }}",
+ "Credentials": "{{ .CREDENTIALS }",
+ "NestedOption": {
+ "URL": "{{ .URL }}"
+ }
+ }`, // Note the missing closing brace in "{{ .CREDENTIALS }"
+ expectedResult: Config{},
+ expectError: true,
+ errorContains: "unexpected \"}\" in operand",
+ },
+ {
+ name: "No substitutions",
+ envVars: map[string]string{
+ "CERT_FILE": "/etc/certs/env_cert.crt",
+ "CREDENTIALS": "env_credentials",
+ "URL": "https://env.testing.com",
+ },
+ jsonTemplate: `{
+ "CertFile": "/etc/certs/cert.crt",
+ "Credentials": "admnlknflkdasdf",
+ "NestedOption" : {
+ "URL": "https://testing.com"
+ }
+ }`,
+ expectedResult: Config{
+ CertFile: "/etc/certs/cert.crt",
+ Credentials: "admnlknflkdasdf",
+ NestedOption: struct {
+ URL string `json:"URL"`
+ }{
+ URL: "https://testing.com",
+ },
+ },
+ expectError: false,
+ },
+ {
+ name: "Should fail when Invalid characters in variables",
+ envVars: map[string]string{
+ "CERT_FILE": `"/etc/certs/"cert".crt"`,
+ "CREDENTIALS": `env_credentia{ls}`,
+ "URL": `https://env.testing.com?param={{value}}`,
+ },
+ jsonTemplate: `{
+ "CertFile": "{{ .CERT_FILE }}",
+ "Credentials": "{{ .CREDENTIALS }}",
+ "NestedOption": {
+ "URL": "{{ .URL }}"
+ }
+ }`,
+ expectedResult: Config{
+ CertFile: `"/etc/certs/"cert".crt"`,
+ Credentials: `env_credentia{ls}`,
+ NestedOption: struct {
+ URL string `json:"URL"`
+ }{
+ URL: `https://env.testing.com?param={{value}}`,
+ },
+ },
+ expectError: true,
+ },
+ }
- Describe("Config", func() {
- Context("in JSON format", func() {
- It("should be written and read successfully", func() {
+ for _, tc := range tests {
+ tc := tc
+ t.Run(tc.name, func(t *testing.T) {
+ for key, value := range tc.envVars {
+ t.Setenv(key, value)
+ }
- m := make(map[string]string)
- m["key1"] = "value1"
- m["key2"] = "value2"
+ tempFile, err := os.CreateTemp("", "config*.json")
+ if err != nil {
+ t.Fatalf("Failed to create temp file: %v", err)
+ }
- arr := []string{"value1", "value2"}
-
- written := &TestConfig{
- SomeMap: m,
- SomeArray: arr,
- SomeField: 99,
+ defer func() {
+ err = os.Remove(tempFile.Name())
+ if err != nil {
+ t.Logf("Failed to remove temp file: %v", err)
}
+ }()
- err := util.WriteJson(tmpDir+"/testconfig.json", written)
- Expect(err).NotTo(HaveOccurred())
+ _, err = tempFile.WriteString(tc.jsonTemplate)
+ if err != nil {
+ t.Fatalf("Failed to write to temp file: %v", err)
+ }
+ err = tempFile.Close()
+ if err != nil {
+ t.Fatalf("Failed to close temp file: %v", err)
+ }
- read, err := util.ReadJson(tmpDir+"/testconfig.json", &TestConfig{})
- Expect(err).NotTo(HaveOccurred())
- Expect(read).NotTo(BeNil())
- Expect(read.(*TestConfig).SomeMap["key1"]).To(BeEquivalentTo(written.SomeMap["key1"]))
- Expect(read.(*TestConfig).SomeMap["key2"]).To(BeEquivalentTo(written.SomeMap["key2"]))
- Expect(read.(*TestConfig).SomeArray).To(ContainElements(arr))
- Expect(read.(*TestConfig).SomeField).To(BeEquivalentTo(written.SomeField))
+ var result Config
- })
- })
- })
+ _, err = ReadJsonWithEnvSub(tempFile.Name(), &result)
- Describe("Copying file contents", func() {
- Context("from one file to another", func() {
- It("should be successful", func() {
-
- src := tmpDir + "/copytest_src"
- dst := tmpDir + "/copytest_dst"
-
- err := util.WriteJson(src, []string{"1", "2", "3"})
- Expect(err).NotTo(HaveOccurred())
-
- err = util.CopyFileContents(src, dst)
- Expect(err).NotTo(HaveOccurred())
-
- hashSrc := md5.New()
- hashDst := md5.New()
-
- srcFile, err := os.Open(src)
- Expect(err).NotTo(HaveOccurred())
-
- dstFile, err := os.Open(dst)
- Expect(err).NotTo(HaveOccurred())
-
- _, err = io.Copy(hashSrc, srcFile)
- Expect(err).NotTo(HaveOccurred())
-
- _, err = io.Copy(hashDst, dstFile)
- Expect(err).NotTo(HaveOccurred())
-
- err = srcFile.Close()
- Expect(err).NotTo(HaveOccurred())
-
- err = dstFile.Close()
- Expect(err).NotTo(HaveOccurred())
-
- Expect(hex.EncodeToString(hashSrc.Sum(nil)[:16])).To(BeEquivalentTo(hex.EncodeToString(hashDst.Sum(nil)[:16])))
- })
- })
- })
-
- Describe("Handle config file without full path", func() {
- Context("config file handling", func() {
- It("should be successful", func() {
- written := &TestConfig{
- SomeField: 123,
+ if tc.expectError {
+ if err == nil {
+ t.Fatalf("Expected error but got none")
}
- cfgFile := "test_cfg.json"
- defer os.Remove(cfgFile)
-
- err := util.WriteJson(cfgFile, written)
- Expect(err).NotTo(HaveOccurred())
-
- read, err := util.ReadJson(cfgFile, &TestConfig{})
- Expect(err).NotTo(HaveOccurred())
- Expect(read).NotTo(BeNil())
- })
+ if !strings.Contains(err.Error(), tc.errorContains) {
+ t.Errorf("Expected error containing '%s', but got '%v'", tc.errorContains, err)
+ }
+ } else {
+ if err != nil {
+ t.Fatalf("ReadJsonWithEnvSub failed: %v", err)
+ }
+ if !reflect.DeepEqual(result, tc.expectedResult) {
+ t.Errorf("Result does not match expected.\nGot: %+v\nExpected: %+v", result, tc.expectedResult)
+ }
+ }
})
- })
-})
+ }
+}