Mirror of https://github.com/netbirdio/netbird.git (synced 2024-12-13 18:31:18 +01:00)

Commit 57f7f43ecb
Merge branch 'feature/optimize-network-map-updates' into feature/validate-group-association

# Conflicts:
#   management/server/account.go
.github/workflows/golang-test-linux.yml (vendored, 2 changes)

@@ -16,7 +16,7 @@ jobs:
       matrix:
         arch: [ '386','amd64' ]
         store: [ 'sqlite', 'postgres']
-    runs-on: ubuntu-latest
+    runs-on: ubuntu-22.04
     steps:
       - name: Install Go
         uses: actions/setup-go@v5
.github/workflows/release.yml (vendored, 2 changes)

@@ -20,7 +20,7 @@ concurrency:
 
 jobs:
   release:
-    runs-on: ubuntu-latest
+    runs-on: ubuntu-22.04
     env:
       flags: ""
     steps:
@@ -49,6 +49,8 @@
 
 ![netbird_2](https://github.com/netbirdio/netbird/assets/700848/46bc3b73-508d-4a0e-bb9a-f465d68646ab)
 
+### NetBird on Lawrence Systems (Video)
+[![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw)
 
 ### Key features
 
@@ -62,6 +64,7 @@
 | | | <ul><li> - \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> |
 | | | <ui><li> - \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ul></li> | | <ul><li> - \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) </ul></li> |
 | | | | | <ul><li> - \[x] Docker </ul></li> |
 
 ### Quickstart with NetBird Cloud
 
 - Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install)
@@ -3,7 +3,6 @@ package cmd
 import (
     "context"
     "net"
-    "path/filepath"
     "testing"
     "time"
 
@@ -34,18 +33,12 @@ func startTestingServices(t *testing.T) string {
     if err != nil {
         t.Fatal(err)
     }
-    testDir := t.TempDir()
-    config.Datadir = testDir
-    err = util.CopyFileContents("../testdata/store.json", filepath.Join(testDir, "store.json"))
-    if err != nil {
-        t.Fatal(err)
-    }
 
     _, signalLis := startSignal(t)
     signalAddr := signalLis.Addr().String()
     config.Signal.URI = signalAddr
 
-    _, mgmLis := startManagement(t, config)
+    _, mgmLis := startManagement(t, config, "../testdata/store.sqlite")
     mgmAddr := mgmLis.Addr().String()
     return mgmAddr
 }
@@ -70,7 +63,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
     return s, lis
 }
 
-func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Listener) {
+func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.Server, net.Listener) {
     t.Helper()
 
     lis, err := net.Listen("tcp", ":0")
@@ -78,7 +71,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
         t.Fatal(err)
     }
     s := grpc.NewServer()
-    store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir)
+    store, cleanUp, err := mgmt.NewTestStoreFromSqlite(context.Background(), testFile, t.TempDir())
     if err != nil {
         t.Fatal(err)
     }
@@ -269,12 +269,6 @@ func (c *ConnectClient) run(
         checks := loginResp.GetChecks()
 
         c.engineMutex.Lock()
-        if c.engine != nil && c.engine.ctx.Err() != nil {
-            log.Info("Stopping Netbird Engine")
-            if err := c.engine.Stop(); err != nil {
-                log.Errorf("Failed to stop engine: %v", err)
-            }
-        }
         c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks)
 
         c.engineMutex.Unlock()
@@ -294,6 +288,15 @@ func (c *ConnectClient) run(
         }
 
         <-engineCtx.Done()
+        c.engineMutex.Lock()
+        if c.engine != nil && c.engine.wgInterface != nil {
+            log.Infof("ensuring %s is removed, Netbird engine context cancelled", c.engine.wgInterface.Name())
+            if err := c.engine.Stop(); err != nil {
+                log.Errorf("Failed to stop engine: %v", err)
+            }
+            c.engine = nil
+        }
+        c.engineMutex.Unlock()
         c.statusRecorder.ClientTeardown()
 
         backOff.Reset()
@@ -251,6 +251,13 @@ func (e *Engine) Stop() error {
     }
     log.Info("Network monitor: stopped")
 
+    // stop/restore DNS first so dbus and friends don't complain because of a missing interface
+    e.stopDNSServer()
+
+    if e.routeManager != nil {
+        e.routeManager.Stop()
+    }
+
     err := e.removeAllPeers()
     if err != nil {
         return fmt.Errorf("failed to remove all peers: %s", err)
@@ -1116,18 +1123,12 @@ func (e *Engine) close() {
         }
     }
 
-    // stop/restore DNS first so dbus and friends don't complain because of a missing interface
-    e.stopDNSServer()
-
-    if e.routeManager != nil {
-        e.routeManager.Stop()
-    }
-
     log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
     if e.wgInterface != nil {
         if err := e.wgInterface.Close(); err != nil {
             log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err)
         }
+        e.wgInterface = nil
     }
 
     if !isNil(e.sshServer) {
@@ -1395,7 +1396,7 @@ func (e *Engine) startNetworkMonitor() {
         }
 
         // Set a new timer to debounce rapid network changes
-        debounceTimer = time.AfterFunc(1*time.Second, func() {
+        debounceTimer = time.AfterFunc(2*time.Second, func() {
             // This function is called after the debounce period
             mu.Lock()
             defer mu.Unlock()
@@ -1426,6 +1427,11 @@ func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
 }
 
 func (e *Engine) stopDNSServer() {
+    if e.dnsServer == nil {
+        return
+    }
+    e.dnsServer.Stop()
+    e.dnsServer = nil
     err := fmt.Errorf("DNS server stopped")
     nsGroupStates := e.statusRecorder.GetDNSStates()
     for i := range nsGroupStates {
@@ -1433,10 +1439,6 @@ func (e *Engine) stopDNSServer() {
         nsGroupStates[i].Error = err
     }
     e.statusRecorder.UpdateDNSStates(nsGroupStates)
-    if e.dnsServer != nil {
-        e.dnsServer.Stop()
-        e.dnsServer = nil
-    }
 }
 
 // isChecksEqual checks if two slices of checks are equal.
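The hunks above move DNS and route teardown from Engine.close() into Engine.Stop() and make stopDNSServer a no-op once the server is gone. A minimal self-contained sketch (not part of this commit) of that idempotent-stop pattern; the types and field names here are illustrative stand-ins, not NetBird's actual Engine:

package main

import "log"

// dnsServer is a stand-in for the engine's DNS server dependency.
type dnsServer interface{ Stop() }

// engine is a stand-in for the NetBird Engine.
type engine struct {
    dns dnsServer
}

// stopDNS mirrors the guarded pattern from the diff: bail out when the server
// is already gone, otherwise stop it and clear the reference so a second call
// (for example from both Stop() and close()) is a harmless no-op.
func (e *engine) stopDNS() {
    if e.dns == nil {
        return
    }
    e.dns.Stop()
    e.dns = nil
    log.Println("DNS server stopped")
}

func main() {
    e := &engine{}
    e.stopDNS() // nil server: returns immediately
    e.stopDNS() // still safe to call again
}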
@@ -6,7 +6,6 @@ import (
     "net"
     "net/netip"
     "os"
-    "path/filepath"
     "runtime"
     "strings"
     "sync"
@@ -824,20 +823,6 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
 func TestEngine_MultiplePeers(t *testing.T) {
     // log.SetLevel(log.DebugLevel)
 
-    dir := t.TempDir()
-
-    err := util.CopyFileContents("../testdata/store.json", filepath.Join(dir, "store.json"))
-    if err != nil {
-        t.Fatal(err)
-    }
-    defer func() {
-        err = os.Remove(filepath.Join(dir, "store.json")) //nolint
-        if err != nil {
-            t.Fatal(err)
-            return
-        }
-    }()
-
     ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
     defer cancel()
 
@@ -847,7 +832,7 @@ func TestEngine_MultiplePeers(t *testing.T) {
         return
     }
     defer sigServer.Stop()
-    mgmtServer, mgmtAddr, err := startManagement(t, dir)
+    mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sqlite")
     if err != nil {
         t.Fatal(err)
         return
@@ -1070,7 +1055,7 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) {
     return s, lis.Addr().String(), nil
 }
 
-func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error) {
+func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) {
     t.Helper()
 
     config := &server.Config{
@@ -1095,7 +1080,7 @@ func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error)
     }
     s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
 
-    store, cleanUp, err := server.NewTestStoreFromJson(context.Background(), config.Datadir)
+    store, cleanUp, err := server.NewTestStoreFromSqlite(context.Background(), testFile, config.Datadir)
     if err != nil {
         return nil, "", err
     }
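The tests above no longer copy a store.json fixture; they seed the management store from the new sqlite fixture instead. A minimal sketch (not part of this commit) of how another test in this file might use the updated startManagement signature; the test name is hypothetical, the fixture path is the one this commit adds:

// Hypothetical test relying on the updated helpers in this file.
func TestEngine_WithSqliteFixture(t *testing.T) {
    // startManagement now takes a data dir plus the sqlite fixture used to seed the store.
    mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sqlite")
    if err != nil {
        t.Fatal(err)
    }
    defer mgmtServer.Stop()

    _ = mgmtAddr // dial the management service here
}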
@@ -32,6 +32,8 @@ const (
     connPriorityRelay   ConnPriority = 1
     connPriorityICETurn ConnPriority = 1
     connPriorityICEP2P  ConnPriority = 2
+
+    reconnectMaxElapsedTime = 30 * time.Minute
 )
 
 type WgConfig struct {
@@ -83,6 +85,7 @@ type Conn struct {
     wgProxyICE    wgproxy.Proxy
     wgProxyRelay  wgproxy.Proxy
     signaler      *Signaler
+    iFaceDiscover stdnet.ExternalIFaceDiscover
     relayManager  *relayClient.Manager
     allowedIPsIP  string
     handshaker    *Handshaker
@@ -108,6 +111,8 @@ type Conn struct {
     // for reconnection operations
     iCEDisconnected   chan bool
     relayDisconnected chan bool
+    connMonitor       *ConnMonitor
+    reconnectCh       <-chan struct{}
 }
 
 // NewConn creates a new not opened Conn to the remote peer.
@@ -123,21 +128,31 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
     connLog := log.WithField("peer", config.Key)
 
     var conn = &Conn{
         log:            connLog,
         ctx:            ctx,
         ctxCancel:      ctxCancel,
         config:         config,
         statusRecorder: statusRecorder,
         wgProxyFactory: wgProxyFactory,
         signaler:       signaler,
+        iFaceDiscover:  iFaceDiscover,
         relayManager:   relayManager,
         allowedIPsIP:   allowedIPsIP.String(),
         statusRelay:    NewAtomicConnStatus(),
         statusICE:      NewAtomicConnStatus(),
+
         iCEDisconnected:   make(chan bool, 1),
         relayDisconnected: make(chan bool, 1),
     }
 
+    conn.connMonitor, conn.reconnectCh = NewConnMonitor(
+        signaler,
+        iFaceDiscover,
+        config,
+        conn.relayDisconnected,
+        conn.iCEDisconnected,
+    )
+
     rFns := WorkerRelayCallbacks{
         OnConnReady:    conn.relayConnectionIsReady,
         OnDisconnected: conn.onWorkerRelayStateDisconnected,
@@ -200,6 +215,8 @@ func (conn *Conn) startHandshakeAndReconnect() {
         conn.log.Errorf("failed to send initial offer: %v", err)
     }
 
+    go conn.connMonitor.Start(conn.ctx)
+
     if conn.workerRelay.IsController() {
         conn.reconnectLoopWithRetry()
     } else {
@@ -309,12 +326,14 @@ func (conn *Conn) reconnectLoopWithRetry() {
     // With it, we can decrease to send necessary offer
     select {
     case <-conn.ctx.Done():
+        return
     case <-time.After(3 * time.Second):
     }
 
     ticker := conn.prepareExponentTicker()
     defer ticker.Stop()
     time.Sleep(1 * time.Second)
+
     for {
         select {
         case t := <-ticker.C:
@@ -342,20 +361,11 @@ func (conn *Conn) reconnectLoopWithRetry() {
             if err != nil {
                 conn.log.Errorf("failed to do handshake: %v", err)
             }
-        case changed := <-conn.relayDisconnected:
-            if !changed {
-                continue
-            }
-            conn.log.Debugf("Relay state changed, reset reconnect timer")
-            ticker.Stop()
-            ticker = conn.prepareExponentTicker()
-        case changed := <-conn.iCEDisconnected:
-            if !changed {
-                continue
-            }
-            conn.log.Debugf("ICE state changed, reset reconnect timer")
+        case <-conn.reconnectCh:
             ticker.Stop()
             ticker = conn.prepareExponentTicker()
+
         case <-conn.ctx.Done():
             conn.log.Debugf("context is done, stop reconnect loop")
             return
@@ -366,10 +376,10 @@ func (conn *Conn) reconnectLoopWithRetry() {
 func (conn *Conn) prepareExponentTicker() *backoff.Ticker {
     bo := backoff.WithContext(&backoff.ExponentialBackOff{
         InitialInterval:     800 * time.Millisecond,
-        RandomizationFactor: 0.01,
+        RandomizationFactor: 0.1,
         Multiplier:          2,
         MaxInterval:         conn.config.Timeout,
-        MaxElapsedTime:      0,
+        MaxElapsedTime:      reconnectMaxElapsedTime,
         Stop:                backoff.Stop,
         Clock:               backoff.SystemClock,
     }, conn.ctx)
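prepareExponentTicker builds a backoff ticker whose interval now grows up to conn.config.Timeout and gives up after reconnectMaxElapsedTime (30 minutes). A minimal standalone sketch (not part of this commit) of that ticker pattern, assuming the github.com/cenkalti/backoff/v4 package that the backoff.Ticker API shown above comes from; the interval values are illustrative, not NetBird's config:

package main

import (
    "context"
    "fmt"
    "time"

    "github.com/cenkalti/backoff/v4"
)

func main() {
    ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
    defer cancel()

    bo := backoff.WithContext(&backoff.ExponentialBackOff{
        InitialInterval:     800 * time.Millisecond,
        RandomizationFactor: 0.1,
        Multiplier:          2,
        MaxInterval:         5 * time.Second,  // stand-in for conn.config.Timeout
        MaxElapsedTime:      30 * time.Minute, // stand-in for reconnectMaxElapsedTime
        Stop:                backoff.Stop,
        Clock:               backoff.SystemClock,
    }, ctx)

    // The ticker fires with exponentially growing delays until the context is
    // cancelled or MaxElapsedTime is exceeded, then its channel is closed.
    ticker := backoff.NewTicker(bo)
    defer ticker.Stop()

    for range ticker.C {
        fmt.Println("attempt a handshake / reconnect here")
    }
}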
client/internal/peer/conn_monitor.go (new file, 212 lines)

package peer

import (
    "context"
    "fmt"
    "sync"
    "time"

    "github.com/pion/ice/v3"
    log "github.com/sirupsen/logrus"

    "github.com/netbirdio/netbird/client/internal/stdnet"
)

const (
    signalerMonitorPeriod     = 5 * time.Second
    candidatesMonitorPeriod   = 5 * time.Minute
    candidateGatheringTimeout = 5 * time.Second
)

type ConnMonitor struct {
    signaler          *Signaler
    iFaceDiscover     stdnet.ExternalIFaceDiscover
    config            ConnConfig
    relayDisconnected chan bool
    iCEDisconnected   chan bool
    reconnectCh       chan struct{}
    currentCandidates []ice.Candidate
    candidatesMu      sync.Mutex
}

func NewConnMonitor(signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, config ConnConfig, relayDisconnected, iCEDisconnected chan bool) (*ConnMonitor, <-chan struct{}) {
    reconnectCh := make(chan struct{}, 1)
    cm := &ConnMonitor{
        signaler:          signaler,
        iFaceDiscover:     iFaceDiscover,
        config:            config,
        relayDisconnected: relayDisconnected,
        iCEDisconnected:   iCEDisconnected,
        reconnectCh:       reconnectCh,
    }
    return cm, reconnectCh
}

func (cm *ConnMonitor) Start(ctx context.Context) {
    signalerReady := make(chan struct{}, 1)
    go cm.monitorSignalerReady(ctx, signalerReady)

    localCandidatesChanged := make(chan struct{}, 1)
    go cm.monitorLocalCandidatesChanged(ctx, localCandidatesChanged)

    for {
        select {
        case changed := <-cm.relayDisconnected:
            if !changed {
                continue
            }
            log.Debugf("Relay state changed, triggering reconnect")
            cm.triggerReconnect()

        case changed := <-cm.iCEDisconnected:
            if !changed {
                continue
            }
            log.Debugf("ICE state changed, triggering reconnect")
            cm.triggerReconnect()

        case <-signalerReady:
            log.Debugf("Signaler became ready, triggering reconnect")
            cm.triggerReconnect()

        case <-localCandidatesChanged:
            log.Debugf("Local candidates changed, triggering reconnect")
            cm.triggerReconnect()

        case <-ctx.Done():
            return
        }
    }
}

func (cm *ConnMonitor) monitorSignalerReady(ctx context.Context, signalerReady chan<- struct{}) {
    if cm.signaler == nil {
        return
    }

    ticker := time.NewTicker(signalerMonitorPeriod)
    defer ticker.Stop()

    lastReady := true
    for {
        select {
        case <-ticker.C:
            currentReady := cm.signaler.Ready()
            if !lastReady && currentReady {
                select {
                case signalerReady <- struct{}{}:
                default:
                }
            }
            lastReady = currentReady
        case <-ctx.Done():
            return
        }
    }
}

func (cm *ConnMonitor) monitorLocalCandidatesChanged(ctx context.Context, localCandidatesChanged chan<- struct{}) {
    ufrag, pwd, err := generateICECredentials()
    if err != nil {
        log.Warnf("Failed to generate ICE credentials: %v", err)
        return
    }

    ticker := time.NewTicker(candidatesMonitorPeriod)
    defer ticker.Stop()

    for {
        select {
        case <-ticker.C:
            if err := cm.handleCandidateTick(ctx, localCandidatesChanged, ufrag, pwd); err != nil {
                log.Warnf("Failed to handle candidate tick: %v", err)
            }
        case <-ctx.Done():
            return
        }
    }
}

func (cm *ConnMonitor) handleCandidateTick(ctx context.Context, localCandidatesChanged chan<- struct{}, ufrag string, pwd string) error {
    log.Debugf("Gathering ICE candidates")

    transportNet, err := newStdNet(cm.iFaceDiscover, cm.config.ICEConfig.InterfaceBlackList)
    if err != nil {
        log.Errorf("failed to create pion's stdnet: %s", err)
    }

    agent, err := newAgent(cm.config, transportNet, candidateTypesP2P(), ufrag, pwd)
    if err != nil {
        return fmt.Errorf("create ICE agent: %w", err)
    }
    defer func() {
        if err := agent.Close(); err != nil {
            log.Warnf("Failed to close ICE agent: %v", err)
        }
    }()

    gatherDone := make(chan struct{})
    err = agent.OnCandidate(func(c ice.Candidate) {
        log.Tracef("Got candidate: %v", c)
        if c == nil {
            close(gatherDone)
        }
    })
    if err != nil {
        return fmt.Errorf("set ICE candidate handler: %w", err)
    }

    if err := agent.GatherCandidates(); err != nil {
        return fmt.Errorf("gather ICE candidates: %w", err)
    }

    ctx, cancel := context.WithTimeout(ctx, candidateGatheringTimeout)
    defer cancel()

    select {
    case <-ctx.Done():
        return fmt.Errorf("wait for gathering: %w", ctx.Err())
    case <-gatherDone:
    }

    candidates, err := agent.GetLocalCandidates()
    if err != nil {
        return fmt.Errorf("get local candidates: %w", err)
    }
    log.Tracef("Got candidates: %v", candidates)

    if changed := cm.updateCandidates(candidates); changed {
        select {
        case localCandidatesChanged <- struct{}{}:
        default:
        }
    }

    return nil
}

func (cm *ConnMonitor) updateCandidates(newCandidates []ice.Candidate) bool {
    cm.candidatesMu.Lock()
    defer cm.candidatesMu.Unlock()

    if len(cm.currentCandidates) != len(newCandidates) {
        cm.currentCandidates = newCandidates
        return true
    }

    for i, candidate := range cm.currentCandidates {
        if candidate.Address() != newCandidates[i].Address() {
            cm.currentCandidates = newCandidates
            return true
        }
    }

    return false
}

func (cm *ConnMonitor) triggerReconnect() {
    select {
    case cm.reconnectCh <- struct{}{}:
    default:
    }
}
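The new ConnMonitor owns the relay/ICE disconnect channels plus two background probes (signaler readiness and local candidate changes) and collapses them into a single buffered reconnect channel that the peer connection consumes. A simplified, self-contained sketch (not part of this commit) of that coalescing pattern; the types here are stand-ins, not the actual ConnMonitor:

package main

import (
    "context"
    "fmt"
    "time"
)

// monitor coalesces several "connectivity changed" triggers into one buffered
// reconnect channel, the way ConnMonitor does above.
type monitor struct {
    relayDown chan bool
    iceDown   chan bool
    reconnect chan struct{}
}

func newMonitor() *monitor {
    return &monitor{
        relayDown: make(chan bool, 1),
        iceDown:   make(chan bool, 1),
        reconnect: make(chan struct{}, 1), // buffered: extra triggers are dropped
    }
}

func (m *monitor) trigger() {
    select {
    case m.reconnect <- struct{}{}:
    default: // a reconnect is already pending, no need to queue another
    }
}

func (m *monitor) start(ctx context.Context) {
    for {
        select {
        case changed := <-m.relayDown:
            if changed {
                m.trigger()
            }
        case changed := <-m.iceDown:
            if changed {
                m.trigger()
            }
        case <-ctx.Done():
            return
        }
    }
}

func main() {
    ctx, cancel := context.WithTimeout(context.Background(), time.Second)
    defer cancel()

    m := newMonitor()
    go m.start(ctx)

    m.relayDown <- true
    select {
    case <-m.reconnect:
        fmt.Println("reset the backoff ticker and send a new offer")
    case <-ctx.Done():
    }
}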
@@ -6,6 +6,6 @@ import (
     "github.com/netbirdio/netbird/client/internal/stdnet"
 )
 
-func (w *WorkerICE) newStdNet() (*stdnet.Net, error) {
-    return stdnet.NewNet(w.config.ICEConfig.InterfaceBlackList)
+func newStdNet(_ stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
+    return stdnet.NewNet(ifaceBlacklist)
 }
@@ -2,6 +2,6 @@ package peer
 
 import "github.com/netbirdio/netbird/client/internal/stdnet"
 
-func (w *WorkerICE) newStdNet() (*stdnet.Net, error) {
-    return stdnet.NewNetWithDiscover(w.iFaceDiscover, w.config.ICEConfig.InterfaceBlackList)
+func newStdNet(iFaceDiscover stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
+    return stdnet.NewNetWithDiscover(iFaceDiscover, ifaceBlacklist)
 }
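The two hunks above are platform-specific variants of the same helper; newStdNet becomes a free function so ConnMonitor can build a transport net without a WorkerICE receiver. The variants are presumably selected with Go build constraints, which sit outside these hunks; a sketch of that pattern under that assumption, with hypothetical file names and tags:

// stdnet_android.go (hypothetical file name)
//go:build android

package peer

import "github.com/netbirdio/netbird/client/internal/stdnet"

func newStdNet(iFaceDiscover stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
    return stdnet.NewNetWithDiscover(iFaceDiscover, ifaceBlacklist)
}

// stdnet.go (hypothetical file name)
//go:build !android

package peer

import "github.com/netbirdio/netbird/client/internal/stdnet"

func newStdNet(_ stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
    return stdnet.NewNet(ifaceBlacklist)
}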
@@ -233,41 +233,16 @@ func (w *WorkerICE) Close() {
 }
 
 func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, relaySupport []ice.CandidateType) (*ice.Agent, error) {
-    transportNet, err := w.newStdNet()
+    transportNet, err := newStdNet(w.iFaceDiscover, w.config.ICEConfig.InterfaceBlackList)
     if err != nil {
         w.log.Errorf("failed to create pion's stdnet: %s", err)
     }
 
-    iceKeepAlive := iceKeepAlive()
-    iceDisconnectedTimeout := iceDisconnectedTimeout()
-    iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait()
-
-    agentConfig := &ice.AgentConfig{
-        MulticastDNSMode:       ice.MulticastDNSModeDisabled,
-        NetworkTypes:           []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
-        Urls:                   w.config.ICEConfig.StunTurn.Load().([]*stun.URI),
-        CandidateTypes:         relaySupport,
-        InterfaceFilter:        stdnet.InterfaceFilter(w.config.ICEConfig.InterfaceBlackList),
-        UDPMux:                 w.config.ICEConfig.UDPMux,
-        UDPMuxSrflx:            w.config.ICEConfig.UDPMuxSrflx,
-        NAT1To1IPs:             w.config.ICEConfig.NATExternalIPs,
-        Net:                    transportNet,
-        FailedTimeout:          &failedTimeout,
-        DisconnectedTimeout:    &iceDisconnectedTimeout,
-        KeepaliveInterval:      &iceKeepAlive,
-        RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait,
-        LocalUfrag:             w.localUfrag,
-        LocalPwd:               w.localPwd,
-    }
-
-    if w.config.ICEConfig.DisableIPv6Discovery {
-        agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4}
-    }
-
     w.sentExtraSrflx = false
-    agent, err := ice.NewAgent(agentConfig)
+
+    agent, err := newAgent(w.config, transportNet, relaySupport, w.localUfrag, w.localPwd)
     if err != nil {
-        return nil, err
+        return nil, fmt.Errorf("create agent: %w", err)
     }
 
     err = agent.OnCandidate(w.onICECandidate)
@@ -390,6 +365,36 @@ func (w *WorkerICE) turnAgentDial(ctx context.Context, remoteOfferAnswer *OfferA
     }
 }
 
+func newAgent(config ConnConfig, transportNet *stdnet.Net, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ice.Agent, error) {
+    iceKeepAlive := iceKeepAlive()
+    iceDisconnectedTimeout := iceDisconnectedTimeout()
+    iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait()
+
+    agentConfig := &ice.AgentConfig{
+        MulticastDNSMode:       ice.MulticastDNSModeDisabled,
+        NetworkTypes:           []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
+        Urls:                   config.ICEConfig.StunTurn.Load().([]*stun.URI),
+        CandidateTypes:         candidateTypes,
+        InterfaceFilter:        stdnet.InterfaceFilter(config.ICEConfig.InterfaceBlackList),
+        UDPMux:                 config.ICEConfig.UDPMux,
+        UDPMuxSrflx:            config.ICEConfig.UDPMuxSrflx,
+        NAT1To1IPs:             config.ICEConfig.NATExternalIPs,
+        Net:                    transportNet,
+        FailedTimeout:          &failedTimeout,
+        DisconnectedTimeout:    &iceDisconnectedTimeout,
+        KeepaliveInterval:      &iceKeepAlive,
+        RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait,
+        LocalUfrag:             ufrag,
+        LocalPwd:               pwd,
+    }
+
+    if config.ICEConfig.DisableIPv6Discovery {
+        agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4}
+    }
+
+    return ice.NewAgent(agentConfig)
+}
+
 func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) {
     relatedAdd := candidate.RelatedAddress()
     return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
@@ -110,7 +110,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
         return nil, "", err
     }
     s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
-    store, cleanUp, err := server.NewTestStoreFromJson(context.Background(), config.Datadir)
+    store, cleanUp, err := server.NewTestStoreFromSqlite(context.Background(), "", config.Datadir)
     if err != nil {
         return nil, "", err
     }
client/testdata/store.json (vendored, deleted, 38 lines)

{
    "Accounts": {
        "bf1c8084-ba50-4ce7-9439-34653001fc3b": {
            "Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b",
            "SetupKeys": {
                "A2C8E62B-38F5-4553-B31E-DD66C696CEBB": {
                    "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB",
                    "Name": "Default key",
                    "Type": "reusable",
                    "CreatedAt": "2021-08-19T20:46:20.005936822+02:00",
                    "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00",
                    "Revoked": false,
                    "UsedTimes": 0
                }
            },
            "Network": {
                "Id": "af1c8024-ha40-4ce2-9418-34653101fc3c",
                "Net": {
                    "IP": "100.64.0.0",
                    "Mask": "//8AAA=="
                },
                "Dns": null
            },
            "Peers": {},
            "Users": {
                "edafee4e-63fb-11ec-90d6-0242ac120003": {
                    "Id": "edafee4e-63fb-11ec-90d6-0242ac120003",
                    "Role": "admin"
                },
                "f4f6d672-63fb-11ec-90d6-0242ac120003": {
                    "Id": "f4f6d672-63fb-11ec-90d6-0242ac120003",
                    "Role": "user"
                }
            }
        }
    }
}
client/testdata/store.sqlite (vendored, new binary file; binary file not shown)
@@ -47,25 +47,18 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
     level, _ := log.ParseLevel("debug")
     log.SetLevel(level)
 
-    testDir := t.TempDir()
-
     config := &mgmt.Config{}
     _, err := util.ReadJson("../server/testdata/management.json", config)
     if err != nil {
         t.Fatal(err)
     }
-    config.Datadir = testDir
-    err = util.CopyFileContents("../server/testdata/store.json", filepath.Join(testDir, "store.json"))
-    if err != nil {
-        t.Fatal(err)
-    }
 
     lis, err := net.Listen("tcp", ":0")
     if err != nil {
         t.Fatal(err)
     }
     s := grpc.NewServer()
-    store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir)
+    store, cleanUp, err := NewSqliteTestStore(t, context.Background(), "../server/testdata/store.sqlite")
     if err != nil {
         t.Fatal(err)
     }
@@ -521,3 +514,22 @@ func Test_GetPKCEAuthorizationFlow(t *testing.T) {
     assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientID, flowInfo.ProviderConfig.ClientID, "provider configured client ID should match")
     assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientSecret, flowInfo.ProviderConfig.ClientSecret, "provider configured client secret should match")
 }
+
+func NewSqliteTestStore(t *testing.T, ctx context.Context, testFile string) (mgmt.Store, func(), error) {
+    t.Helper()
+    dataDir := t.TempDir()
+    err := util.CopyFileContents(testFile, filepath.Join(dataDir, "store.db"))
+    if err != nil {
+        t.Fatal(err)
+    }
+
+    store, err := mgmt.NewSqliteStore(ctx, dataDir, nil)
+    if err != nil {
+        return nil, nil, err
+    }
+
+    return store, func() {
+        store.Close(ctx)
+        os.Remove(filepath.Join(dataDir, "store.db"))
+    }, nil
+}
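The NewSqliteTestStore helper added above copies the sqlite fixture into a temp dir, opens a store on it, and returns a cleanup closure. A minimal sketch (not part of this commit) of calling it from another test in this package; the test name is hypothetical, the fixture path is the one used above:

func TestSomethingAgainstSqliteStore(t *testing.T) {
    store, cleanUp, err := NewSqliteTestStore(t, context.Background(), "../server/testdata/store.sqlite")
    if err != nil {
        t.Fatal(err)
    }
    defer cleanUp()

    // Exercise code that needs a management store seeded from the sqlite fixture.
    _ = store
}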
@@ -475,7 +475,7 @@ func handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler) http.Handle
 
 func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config, error) {
     loadedConfig := &server.Config{}
-    _, err := util.ReadJson(mgmtConfigPath, loadedConfig)
+    _, err := util.ReadJsonWithEnvSub(mgmtConfigPath, loadedConfig)
     if err != nil {
         return nil, err
     }
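loadMgmtConfig now reads management.json through util.ReadJsonWithEnvSub, which, as the name suggests, substitutes environment variables while decoding the JSON. The exact placeholder syntax and config keys are not shown in this hunk, so the fragment below is a hedged illustration only; the ${NB_STORE_ENGINE} placeholder and the StoreConfig field are assumptions:

// Hypothetical management.json fragment relying on env substitution:
//   { "StoreConfig": { "Engine": "${NB_STORE_ENGINE}" } }
func loadConfig(path string) (*server.Config, error) {
    cfg := &server.Config{}
    if _, err := util.ReadJsonWithEnvSub(path, cfg); err != nil {
        return nil, err
    }
    return cfg, nil
}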
@ -20,6 +20,11 @@ import (
|
|||||||
cacheStore "github.com/eko/gocache/v3/store"
|
cacheStore "github.com/eko/gocache/v3/store"
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
gocache "github.com/patrickmn/go-cache"
|
||||||
|
"github.com/rs/xid"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/base62"
|
"github.com/netbirdio/netbird/base62"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/domain"
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
@ -36,10 +41,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
gocache "github.com/patrickmn/go-cache"
|
|
||||||
"github.com/rs/xid"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.org/x/exp/maps"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -76,7 +77,8 @@ type AccountManager interface {
|
|||||||
SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error)
|
SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error)
|
||||||
GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error)
|
GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error)
|
||||||
GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error)
|
GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error)
|
||||||
GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error)
|
AccountExists(ctx context.Context, accountID string) (bool, error)
|
||||||
|
GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error)
|
||||||
GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
|
GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
|
||||||
CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
|
CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
|
||||||
GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error)
|
GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error)
|
||||||
@ -96,6 +98,7 @@ type AccountManager interface {
|
|||||||
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
|
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
|
||||||
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error)
|
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error)
|
||||||
GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error)
|
GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error)
|
||||||
|
UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error
|
||||||
GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error)
|
GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error)
|
||||||
GetGroup(ctx context.Context, accountId, groupID, userID string) (*nbgroup.Group, error)
|
GetGroup(ctx context.Context, accountId, groupID, userID string) (*nbgroup.Group, error)
|
||||||
GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error)
|
GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error)
|
||||||
@ -842,55 +845,54 @@ func (a *Account) GetPeer(peerID string) *nbpeer.Peer {
|
|||||||
return a.Peers[peerID]
|
return a.Peers[peerID]
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetJWTGroups updates the user's auto groups by synchronizing JWT groups.
|
// getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups.
|
||||||
// Returns true if there are changes in the JWT group membership.
|
// Returns a bool indicating if there are changes in the JWT group membership, the updated user AutoGroups,
|
||||||
func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool {
|
// newly groups to create and an error if any occurred.
|
||||||
user, ok := a.Users[userID]
|
func (am *DefaultAccountManager) getJWTGroupsChanges(user *User, groups []*nbgroup.Group, groupNames []string) (bool, []string, []*nbgroup.Group, error) {
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
existedGroupsByName := make(map[string]*nbgroup.Group)
|
existedGroupsByName := make(map[string]*nbgroup.Group)
|
||||||
for _, group := range a.Groups {
|
for _, group := range groups {
|
||||||
existedGroupsByName[group.Name] = group
|
existedGroupsByName[group.Name] = group
|
||||||
}
|
}
|
||||||
|
|
||||||
newAutoGroups, jwtGroupsMap := separateGroups(user.AutoGroups, a.Groups)
|
newUserAutoGroups, jwtGroupsMap := separateGroups(user.AutoGroups, groups)
|
||||||
groupsToAdd := difference(groupsNames, maps.Keys(jwtGroupsMap))
|
|
||||||
groupsToRemove := difference(maps.Keys(jwtGroupsMap), groupsNames)
|
groupsToAdd := difference(groupNames, maps.Keys(jwtGroupsMap))
|
||||||
|
groupsToRemove := difference(maps.Keys(jwtGroupsMap), groupNames)
|
||||||
|
|
||||||
// If no groups are added or removed, we should not sync account
|
// If no groups are added or removed, we should not sync account
|
||||||
if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 {
|
if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 {
|
||||||
return false
|
return false, nil, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
newGroupsToCreate := make([]*nbgroup.Group, 0)
|
||||||
|
|
||||||
var modified bool
|
var modified bool
|
||||||
for _, name := range groupsToAdd {
|
for _, name := range groupsToAdd {
|
||||||
group, exists := existedGroupsByName[name]
|
group, exists := existedGroupsByName[name]
|
||||||
if !exists {
|
if !exists {
|
||||||
group = &nbgroup.Group{
|
group = &nbgroup.Group{
|
||||||
ID: xid.New().String(),
|
ID: xid.New().String(),
|
||||||
Name: name,
|
AccountID: user.AccountID,
|
||||||
Issued: nbgroup.GroupIssuedJWT,
|
Name: name,
|
||||||
|
Issued: nbgroup.GroupIssuedJWT,
|
||||||
}
|
}
|
||||||
a.Groups[group.ID] = group
|
newGroupsToCreate = append(newGroupsToCreate, group)
|
||||||
}
|
}
|
||||||
if group.Issued == nbgroup.GroupIssuedJWT {
|
if group.Issued == nbgroup.GroupIssuedJWT {
|
||||||
newAutoGroups = append(newAutoGroups, group.ID)
|
newUserAutoGroups = append(newUserAutoGroups, group.ID)
|
||||||
modified = true
|
modified = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for name, id := range jwtGroupsMap {
|
for name, id := range jwtGroupsMap {
|
||||||
if !slices.Contains(groupsToRemove, name) {
|
if !slices.Contains(groupsToRemove, name) {
|
||||||
newAutoGroups = append(newAutoGroups, id)
|
newUserAutoGroups = append(newUserAutoGroups, id)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
modified = true
|
modified = true
|
||||||
}
|
}
|
||||||
user.AutoGroups = newAutoGroups
|
|
||||||
|
|
||||||
return modified
|
return modified, newUserAutoGroups, newGroupsToCreate, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UserGroupsAddToPeers adds groups to all peers of user
|
// UserGroupsAddToPeers adds groups to all peers of user
|
||||||
@ -1261,37 +1263,36 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountIDByUserOrAccountID retrieves the account ID based on either the userID or accountID provided.
|
// AccountExists checks if an account exists.
|
||||||
// If an accountID is provided, it checks if the account exists and returns it.
|
func (am *DefaultAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) {
|
||||||
// If no accountID is provided, but a userID is given, it tries to retrieve the account by userID.
|
return am.Store.AccountExists(ctx, LockingStrengthShare, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountIDByUserID retrieves the account ID based on the userID provided.
|
||||||
|
// If user does have an account, it returns the user's account ID.
|
||||||
// If the user doesn't have an account, it creates one using the provided domain.
|
// If the user doesn't have an account, it creates one using the provided domain.
|
||||||
// Returns the account ID or an error if none is found or created.
|
// Returns the account ID or an error if none is found or created.
|
||||||
func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) {
|
func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) {
|
||||||
if accountID != "" {
|
if userID == "" {
|
||||||
exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, accountID)
|
return "", status.Errorf(status.NotFound, "no valid userID provided")
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
if !exists {
|
|
||||||
return "", status.Errorf(status.NotFound, "account %s does not exist", accountID)
|
|
||||||
}
|
|
||||||
return accountID, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if userID != "" {
|
accountID, err := am.Store.GetAccountIDByUserID(userID)
|
||||||
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
|
if err != nil {
|
||||||
if err != nil {
|
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||||
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
|
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
|
||||||
}
|
if err != nil {
|
||||||
|
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
|
||||||
|
}
|
||||||
|
|
||||||
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil {
|
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
|
}
|
||||||
|
return account.Id, nil
|
||||||
}
|
}
|
||||||
|
return "", err
|
||||||
return account.Id, nil
|
|
||||||
}
|
}
|
||||||
|
return accountID, nil
|
||||||
return "", status.Errorf(status.NotFound, "no valid userID or accountID provided")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func isNil(i idp.Manager) bool {
|
func isNil(i idp.Manager) bool {
|
||||||
@ -1764,7 +1765,7 @@ func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID s
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) {
|
if user.AccountID != accountID {
|
||||||
return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data")
|
return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1795,6 +1796,10 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai
|
|||||||
return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId)
|
return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if user.AccountID != accountID {
|
||||||
|
return "", "", status.Errorf(status.PermissionDenied, "user %s is not part of the account %s", claims.UserId, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
if !user.IsServiceUser && claims.Invited {
|
if !user.IsServiceUser && claims.Invited {
|
||||||
err = am.redeemInvite(ctx, accountID, user.Id)
|
err = am.redeemInvite(ctx, accountID, user.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -1802,7 +1807,7 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = am.syncJWTGroups(ctx, accountID, user, claims); err != nil {
|
if err = am.syncJWTGroups(ctx, accountID, claims); err != nil {
|
||||||
return "", "", err
|
return "", "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1811,7 +1816,7 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai
|
|||||||
|
|
||||||
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
|
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
|
||||||
// and propagates changes to peers if group propagation is enabled.
|
// and propagates changes to peers if group propagation is enabled.
|
||||||
func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, user *User, claims jwtclaims.AuthorizationClaims) error {
|
func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims) error {
|
||||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -1822,69 +1827,134 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
|||||||
}
|
}
|
||||||
|
|
||||||
if settings.JWTGroupsClaimName == "" {
|
if settings.JWTGroupsClaimName == "" {
|
||||||
log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set")
|
log.WithContext(ctx).Debugf("JWT groups are enabled but no claim name is set")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Remove GetAccount after refactoring account peer's update
|
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
|
jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
|
||||||
|
|
||||||
oldGroups := make([]string, len(user.AutoGroups))
|
unlockPeer := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
copy(oldGroups, user.AutoGroups)
|
defer func() {
|
||||||
|
if unlockPeer != nil {
|
||||||
|
unlockPeer()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// Update the account if group membership changes
|
var addNewGroups []string
|
||||||
if account.SetJWTGroups(claims.UserId, jwtGroupsNames) {
|
var removeOldGroups []string
|
||||||
addNewGroups := difference(user.AutoGroups, oldGroups)
|
var hasChanges bool
|
||||||
removeOldGroups := difference(oldGroups, user.AutoGroups)
|
var user *User
|
||||||
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
if settings.GroupsPropagationEnabled {
|
user, err = am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId)
|
||||||
account.UserGroupsAddToPeers(claims.UserId, addNewGroups...)
|
if err != nil {
|
||||||
account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...)
|
return fmt.Errorf("error getting user: %w", err)
|
||||||
account.Network.IncSerial()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := am.Store.SaveAccount(ctx, account); err != nil {
|
groups, err := am.Store.GetAccountGroups(ctx, accountID)
|
||||||
log.WithContext(ctx).Errorf("failed to save account: %v", err)
|
if err != nil {
|
||||||
|
return fmt.Errorf("error getting account groups: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
changed, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(user, groups, jwtGroupsNames)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error getting JWT groups changes: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
hasChanges = changed
|
||||||
|
// skip update if no changes
|
||||||
|
if !changed {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil {
|
||||||
|
+			return fmt.Errorf("error saving groups: %w", err)
+		}
+
+		addNewGroups = difference(updatedAutoGroups, user.AutoGroups)
+		removeOldGroups = difference(user.AutoGroups, updatedAutoGroups)
+
+		user.AutoGroups = updatedAutoGroups
+		if err = transaction.SaveUser(ctx, LockingStrengthUpdate, user); err != nil {
+			return fmt.Errorf("error saving user: %w", err)
+		}
+
 		// Propagate changes to peers if group propagation is enabled
 		if settings.GroupsPropagationEnabled {
-			log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
-			if areGroupChangesAffectPeers(account, addNewGroups) || areGroupChangesAffectPeers(account, removeOldGroups) {
-				am.updateAccountPeers(ctx, account)
+			groups, err = transaction.GetAccountGroups(ctx, accountID)
+			if err != nil {
+				return fmt.Errorf("error getting account groups: %w", err)
 			}
+
+			groupsMap := make(map[string]*nbgroup.Group, len(groups))
+			for _, group := range groups {
+				groupsMap[group.ID] = group
+			}
+
+			peers, err := transaction.GetUserPeers(ctx, LockingStrengthShare, accountID, claims.UserId)
+			if err != nil {
+				return fmt.Errorf("error getting user peers: %w", err)
+			}
+
+			updatedGroups, err := am.updateUserPeersInGroups(groupsMap, peers, addNewGroups, removeOldGroups)
+			if err != nil {
+				return fmt.Errorf("error modifying user peers in groups: %w", err)
+			}
+
+			if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, updatedGroups); err != nil {
+				return fmt.Errorf("error saving groups: %w", err)
+			}
+
+			if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
+				return fmt.Errorf("error incrementing network serial: %w", err)
+			}
+		}
+
+		unlockPeer()
+		unlockPeer = nil
+
+		return nil
+	})
+	if err != nil {
+		return err
+	}
+
+	if !hasChanges {
+		return nil
+	}
+
+	for _, g := range addNewGroups {
+		group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
+		if err != nil {
+			log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
+		} else {
+			meta := map[string]any{
+				"group": group.Name, "group_id": group.ID,
+				"is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName,
+			}
+			am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupAddedToUser, meta)
+		}
+	}
+
+	for _, g := range removeOldGroups {
+		group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
+		if err != nil {
+			log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
+		} else {
+			meta := map[string]any{
+				"group": group.Name, "group_id": group.ID,
+				"is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName,
+			}
+			am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupRemovedFromUser, meta)
+		}
+	}
+
+	if settings.GroupsPropagationEnabled {
+		account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
+		if err != nil {
+			return fmt.Errorf("error getting account: %w", err)
 		}
-
-		for _, g := range addNewGroups {
-			if group := account.GetGroup(g); group != nil {
-				am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser,
-					map[string]any{
-						"group":           group.Name,
-						"group_id":        group.ID,
-						"is_service_user": user.IsServiceUser,
-						"user_name":       user.ServiceUserName})
-			}
-		}
-
-		for _, g := range removeOldGroups {
-			if group := account.GetGroup(g); group != nil {
-				am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser,
-					map[string]any{
-						"group":           group.Name,
-						"group_id":        group.ID,
-						"is_service_user": user.IsServiceUser,
-						"user_name":       user.ServiceUserName})
-			}
-		}
+
+		log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
+		am.updateAccountPeers(ctx, account)
 	}
 
 	return nil
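The hunk above derives the groups to add and remove with a `difference` helper before saving the user and groups inside one transaction. As a minimal, runnable sketch of that set-difference step (the helper name matches the call sites above, but the project's actual implementation may differ):

```go
package main

import "fmt"

// difference returns the elements of a that are not present in b.
func difference(a, b []string) []string {
	seen := make(map[string]struct{}, len(b))
	for _, x := range b {
		seen[x] = struct{}{}
	}
	var out []string
	for _, x := range a {
		if _, ok := seen[x]; !ok {
			out = append(out, x)
		}
	}
	return out
}

func main() {
	updated := []string{"group1", "group2"} // groups coming from the JWT
	current := []string{"group1", "group3"} // user's current auto groups
	fmt.Println(difference(updated, current)) // groups to add: [group2]
	fmt.Println(difference(current, updated)) // groups to remove: [group3]
}
```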
@@ -1917,7 +1987,17 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
 	// if Account ID is part of the claims
 	// it means that we've already classified the domain and user has an account
 	if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
-		return am.GetAccountIDByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain)
+		if claims.AccountId != "" {
+			exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, claims.AccountId)
+			if err != nil {
+				return "", err
+			}
+			if !exists {
+				return "", status.Errorf(status.NotFound, "account %s does not exist", claims.AccountId)
+			}
+			return claims.AccountId, nil
+		}
+		return am.GetAccountIDByUserID(ctx, claims.UserId, claims.Domain)
 	} else if claims.AccountId != "" {
 		userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
 		if err != nil {
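The change above prefers the account ID carried in the claims, verifying it with an existence check before falling back to a lookup by user ID. A small, self-contained sketch of that branching, with hypothetical stand-in callbacks rather than the real store API:

```go
package main

import (
	"errors"
	"fmt"
)

// resolveAccountID prefers the account ID from the claims (after an
// existence check) and otherwise resolves the account by user ID.
// Both callbacks are stand-ins, not the real store methods.
func resolveAccountID(claimAccountID, userID string, exists func(string) (bool, error), byUser func(string) (string, error)) (string, error) {
	if claimAccountID != "" {
		ok, err := exists(claimAccountID)
		if err != nil {
			return "", err
		}
		if !ok {
			return "", fmt.Errorf("account %s does not exist", claimAccountID)
		}
		return claimAccountID, nil
	}
	if userID == "" {
		return "", errors.New("user ID is empty")
	}
	return byUser(userID)
}

func main() {
	exists := func(id string) (bool, error) { return id == "acc-1", nil }
	byUser := func(string) (string, error) { return "acc-1", nil }

	id, err := resolveAccountID("acc-1", "", exists, byUser)
	fmt.Println(id, err) // acc-1 <nil>
}
```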
@@ -2230,7 +2310,11 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac
 	routes := make(map[route.ID]*route.Route)
 	setupKeys := map[string]*SetupKey{}
 	nameServersGroups := make(map[string]*nbdns.NameServerGroup)
-	users[userID] = NewOwnerUser(userID)
+
+	owner := NewOwnerUser(userID)
+	owner.AccountID = accountID
+	users[userID] = owner
+
 	dnsSettings := DNSSettings{
 		DisabledManagementGroups: make([]string, 0),
 	}
@@ -2298,12 +2382,17 @@ func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool {
 // separateGroups separates user's auto groups into non-JWT and JWT groups.
 // Returns the list of standard auto groups and a map of JWT auto groups,
 // where the keys are the group names and the values are the group IDs.
-func separateGroups(autoGroups []string, allGroups map[string]*nbgroup.Group) ([]string, map[string]string) {
+func separateGroups(autoGroups []string, allGroups []*nbgroup.Group) ([]string, map[string]string) {
 	newAutoGroups := make([]string, 0)
 	jwtAutoGroups := make(map[string]string) // map of group name to group ID
+
+	allGroupsMap := make(map[string]*nbgroup.Group, len(allGroups))
+	for _, group := range allGroups {
+		allGroupsMap[group.ID] = group
+	}
+
 	for _, id := range autoGroups {
-		if group, ok := allGroups[id]; ok {
+		if group, ok := allGroupsMap[id]; ok {
 			if group.Issued == nbgroup.GroupIssuedJWT {
 				jwtAutoGroups[group.Name] = id
 			} else {
@@ -2311,5 +2400,6 @@ func separateGroups(autoGroups []string, allGroups map[string]*nbgroup.Group) ([
 			}
 		}
 	}
 
 	return newAutoGroups, jwtAutoGroups
 }
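With the signature change above, `separateGroups` now receives the store's group slice and indexes it by ID before the membership check. A self-contained sketch of that slice-to-map indexing pattern, using a simplified group type rather than the real `nbgroup.Group`:

```go
package main

import "fmt"

// Group is a simplified stand-in for the management server's group type.
type Group struct {
	ID     string
	Name   string
	Issued string
}

// indexByID builds the lookup map the updated function constructs from the
// slice it now receives, so individual IDs can be resolved in O(1).
func indexByID(all []*Group) map[string]*Group {
	idx := make(map[string]*Group, len(all))
	for _, g := range all {
		idx[g.ID] = g
	}
	return idx
}

func main() {
	groups := []*Group{
		{ID: "group1", Name: "group1", Issued: "api"},
		{ID: "group2", Name: "devs", Issued: "jwt"},
	}
	idx := indexByID(groups)
	fmt.Println(idx["group2"].Name) // devs
}
```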
@@ -633,7 +633,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
 			manager, err := createManager(t)
 			require.NoError(t, err, "unable to create account manager")
 
-			accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
+			accountID, err := manager.GetAccountIDByUserID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.Domain)
 			require.NoError(t, err, "create init user failed")
 
 			initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
@@ -671,17 +671,16 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
 	userId := "user-id"
 	domain := "test.domain"
 
-	initAccount := newAccountWithId(context.Background(), "", userId, domain)
+	_ = newAccountWithId(context.Background(), "", userId, domain)
 	manager, err := createManager(t)
 	require.NoError(t, err, "unable to create account manager")
 
-	accountID := initAccount.Id
-	accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userId, accountID, domain)
+	accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain)
 	require.NoError(t, err, "create init user failed")
 	// as initAccount was created without account id we have to take the id after account initialization
-	// that happens inside the GetAccountIDByUserOrAccountID where the id is getting generated
+	// that happens inside the GetAccountIDByUserID where the id is getting generated
 	// it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it
-	initAccount, err = manager.Store.GetAccount(context.Background(), accountID)
+	initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
 	require.NoError(t, err, "get init account failed")
 
 	claims := jwtclaims.AuthorizationClaims{
@@ -885,7 +884,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
 	}
 }
 
-func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
+func TestAccountManager_GetAccountByUserID(t *testing.T) {
 	manager, err := createManager(t)
 	if err != nil {
 		t.Fatal(err)
@@ -894,7 +893,7 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
 
 	userId := "test_user"
 
-	accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userId, "", "")
+	accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, "")
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -903,14 +902,13 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
 		return
 	}
 
-	_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "")
-	if err != nil {
-		t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountID)
-	}
+	exists, err := manager.Store.AccountExists(context.Background(), LockingStrengthShare, accountID)
+	assert.NoError(t, err)
+	assert.True(t, exists, "expected to get existing account after creation using userid")
 
-	_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", "", "")
+	_, err = manager.GetAccountIDByUserID(context.Background(), "", "")
 	if err == nil {
-		t.Errorf("expected an error when user and account IDs are empty")
+		t.Errorf("expected an error when user ID is empty")
 	}
 }
 
@@ -1731,7 +1729,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
 	manager, err := createManager(t)
 	require.NoError(t, err, "unable to create account manager")
 
-	accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
+	accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
 	require.NoError(t, err, "unable to create an account")
 
 	settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
@@ -1746,7 +1744,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
 	manager, err := createManager(t)
 	require.NoError(t, err, "unable to create account manager")
 
-	_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
+	_, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
 	require.NoError(t, err, "unable to create an account")
 
 	key, err := wgtypes.GenerateKey()
@@ -1758,7 +1756,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
 	})
 	require.NoError(t, err, "unable to add peer")
 
-	accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
+	accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
 	require.NoError(t, err, "unable to get the account")
 
 	account, err := manager.Store.GetAccount(context.Background(), accountID)
@@ -1804,7 +1802,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
 	manager, err := createManager(t)
 	require.NoError(t, err, "unable to create account manager")
 
-	accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
+	accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
 	require.NoError(t, err, "unable to create an account")
 
 	key, err := wgtypes.GenerateKey()
@@ -1832,7 +1830,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
 		},
 	}
 
-	accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
+	accountID, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
 	require.NoError(t, err, "unable to get the account")
 
 	account, err := manager.Store.GetAccount(context.Background(), accountID)
@@ -1852,7 +1850,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
 	manager, err := createManager(t)
 	require.NoError(t, err, "unable to create account manager")
 
-	_, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
+	_, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
 	require.NoError(t, err, "unable to create an account")
 
 	key, err := wgtypes.GenerateKey()
@@ -1864,7 +1862,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
 	})
 	require.NoError(t, err, "unable to add peer")
 
-	accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
+	accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
 	require.NoError(t, err, "unable to get the account")
 
 	account, err := manager.Store.GetAccount(context.Background(), accountID)
@@ -1912,7 +1910,7 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
 	manager, err := createManager(t)
 	require.NoError(t, err, "unable to create account manager")
 
-	accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
+	accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
 	require.NoError(t, err, "unable to create an account")
 
 	updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
@@ -1923,9 +1921,6 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
 	assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
 	assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)
 
-	accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "")
-	require.NoError(t, err, "unable to get account by ID")
-
 	settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
 	require.NoError(t, err, "unable to get account settings")
 
@@ -2261,8 +2256,12 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
 }
 
 func TestAccount_SetJWTGroups(t *testing.T) {
+	manager, err := createManager(t)
+	require.NoError(t, err, "unable to create account manager")
+
 	// create a new account
 	account := &Account{
+		Id: "accountID",
 		Peers: map[string]*nbpeer.Peer{
 			"peer1": {ID: "peer1", Key: "key1", UserID: "user1"},
 			"peer2": {ID: "peer2", Key: "key2", UserID: "user1"},
@@ -2273,62 +2272,120 @@ func TestAccount_SetJWTGroups(t *testing.T) {
 		Groups: map[string]*group.Group{
 			"group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{}},
 		},
-		Settings: &Settings{GroupsPropagationEnabled: true},
+		Settings: &Settings{GroupsPropagationEnabled: true, JWTGroupsEnabled: true, JWTGroupsClaimName: "groups"},
 		Users: map[string]*User{
-			"user1": {Id: "user1"},
-			"user2": {Id: "user2"},
+			"user1": {Id: "user1", AccountID: "accountID"},
+			"user2": {Id: "user2", AccountID: "accountID"},
 		},
 	}
 
+	assert.NoError(t, manager.Store.SaveAccount(context.Background(), account), "unable to save account")
+
 	t.Run("empty jwt groups", func(t *testing.T) {
-		updated := account.SetJWTGroups("user1", []string{})
-		assert.False(t, updated, "account should not be updated")
-		assert.Empty(t, account.Users["user1"].AutoGroups, "auto groups must be empty")
+		claims := jwtclaims.AuthorizationClaims{
+			UserId: "user1",
+			Raw:    jwt.MapClaims{"groups": []interface{}{}},
+		}
+		err := manager.syncJWTGroups(context.Background(), "accountID", claims)
+		assert.NoError(t, err, "unable to sync jwt groups")
+
+		user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
+		assert.NoError(t, err, "unable to get user")
+		assert.Empty(t, user.AutoGroups, "auto groups must be empty")
 	})
 
 	t.Run("jwt match existing api group", func(t *testing.T) {
-		updated := account.SetJWTGroups("user1", []string{"group1"})
-		assert.False(t, updated, "account should not be updated")
-		assert.Equal(t, 0, len(account.Users["user1"].AutoGroups))
-		assert.Equal(t, account.Groups["group1"].Issued, group.GroupIssuedAPI, "group should be api issued")
+		claims := jwtclaims.AuthorizationClaims{
+			UserId: "user1",
+			Raw:    jwt.MapClaims{"groups": []interface{}{"group1"}},
+		}
+		err := manager.syncJWTGroups(context.Background(), "accountID", claims)
+		assert.NoError(t, err, "unable to sync jwt groups")
+
+		user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
+		assert.NoError(t, err, "unable to get user")
+		assert.Len(t, user.AutoGroups, 0)
+
+		group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
+		assert.NoError(t, err, "unable to get group")
+		assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
 	})
 
 	t.Run("jwt match existing api group in user auto groups", func(t *testing.T) {
 		account.Users["user1"].AutoGroups = []string{"group1"}
+		assert.NoError(t, manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, account.Users["user1"]))
 
-		updated := account.SetJWTGroups("user1", []string{"group1"})
-		assert.False(t, updated, "account should not be updated")
-		assert.Equal(t, 1, len(account.Users["user1"].AutoGroups))
-		assert.Equal(t, account.Groups["group1"].Issued, group.GroupIssuedAPI, "group should be api issued")
+		claims := jwtclaims.AuthorizationClaims{
+			UserId: "user1",
+			Raw:    jwt.MapClaims{"groups": []interface{}{"group1"}},
+		}
+		err = manager.syncJWTGroups(context.Background(), "accountID", claims)
+		assert.NoError(t, err, "unable to sync jwt groups")
+
+		user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
+		assert.NoError(t, err, "unable to get user")
+		assert.Len(t, user.AutoGroups, 1)
+
+		group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
+		assert.NoError(t, err, "unable to get group")
+		assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
 	})
 
 	t.Run("add jwt group", func(t *testing.T) {
-		updated := account.SetJWTGroups("user1", []string{"group1", "group2"})
-		assert.True(t, updated, "account should be updated")
-		assert.Len(t, account.Groups, 2, "new group should be added")
-		assert.Len(t, account.Users["user1"].AutoGroups, 2, "new group should be added")
-		assert.Contains(t, account.Groups, account.Users["user1"].AutoGroups[0], "groups must contain group2 from user groups")
+		claims := jwtclaims.AuthorizationClaims{
+			UserId: "user1",
+			Raw:    jwt.MapClaims{"groups": []interface{}{"group1", "group2"}},
+		}
+		err = manager.syncJWTGroups(context.Background(), "accountID", claims)
+		assert.NoError(t, err, "unable to sync jwt groups")
+
+		user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
+		assert.NoError(t, err, "unable to get user")
+		assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
 	})
 
 	t.Run("existed group not update", func(t *testing.T) {
-		updated := account.SetJWTGroups("user1", []string{"group2"})
-		assert.False(t, updated, "account should not be updated")
-		assert.Len(t, account.Groups, 2, "groups count should not be changed")
+		claims := jwtclaims.AuthorizationClaims{
+			UserId: "user1",
+			Raw:    jwt.MapClaims{"groups": []interface{}{"group2"}},
+		}
+		err = manager.syncJWTGroups(context.Background(), "accountID", claims)
+		assert.NoError(t, err, "unable to sync jwt groups")
+
+		user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
+		assert.NoError(t, err, "unable to get user")
+		assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
 	})
 
 	t.Run("add new group", func(t *testing.T) {
-		updated := account.SetJWTGroups("user2", []string{"group1", "group3"})
-		assert.True(t, updated, "account should be updated")
-		assert.Len(t, account.Groups, 3, "new group should be added")
-		assert.Len(t, account.Users["user2"].AutoGroups, 1, "new group should be added")
-		assert.Contains(t, account.Groups, account.Users["user2"].AutoGroups[0], "groups must contain group3 from user groups")
+		claims := jwtclaims.AuthorizationClaims{
+			UserId: "user2",
+			Raw:    jwt.MapClaims{"groups": []interface{}{"group1", "group3"}},
+		}
+		err = manager.syncJWTGroups(context.Background(), "accountID", claims)
+		assert.NoError(t, err, "unable to sync jwt groups")
+
+		groups, err := manager.Store.GetAccountGroups(context.Background(), "accountID")
+		assert.NoError(t, err)
+		assert.Len(t, groups, 3, "new group3 should be added")
+
+		user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user2")
+		assert.NoError(t, err, "unable to get user")
+		assert.Len(t, user.AutoGroups, 1, "new group should be added")
 	})
 
 	t.Run("remove all JWT groups", func(t *testing.T) {
-		updated := account.SetJWTGroups("user1", []string{})
-		assert.True(t, updated, "account should be updated")
-		assert.Len(t, account.Users["user1"].AutoGroups, 1, "only non-JWT groups should remain")
-		assert.Contains(t, account.Users["user1"].AutoGroups, "group1", " group1 should still be present")
+		claims := jwtclaims.AuthorizationClaims{
+			UserId: "user1",
+			Raw:    jwt.MapClaims{"groups": []interface{}{}},
+		}
+		err = manager.syncJWTGroups(context.Background(), "accountID", claims)
+		assert.NoError(t, err, "unable to sync jwt groups")
+
+		user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
+		assert.NoError(t, err, "unable to get user")
+		assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain")
+		assert.Contains(t, user.AutoGroups, "group1", " group1 should still be present")
 	})
 }
 
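The rewritten tests above no longer call `SetJWTGroups` on the account directly; they drive the sync through authorization claims whose raw token payload carries the group list under the configured claim name. A tiny sketch of that payload shape, with a plain map standing in for the JWT claims type used in the tests:

```go
package main

import "fmt"

func main() {
	// The group list travels in the raw token payload under the claim name
	// configured on the account ("groups" in the diff above).
	raw := map[string]interface{}{
		"groups": []interface{}{"group1", "group2"},
	}
	fmt.Println(raw["groups"]) // [group1 group2]
}
```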
@@ -2428,7 +2485,7 @@ func createManager(t TB) (*DefaultAccountManager, error) {
 func createStore(t TB) (Store, error) {
 	t.Helper()
 	dataDir := t.TempDir()
-	store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir)
+	store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir)
 	if err != nil {
 		return nil, err
 	}
@@ -212,7 +212,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
 func createDNSStore(t *testing.T) (Store, error) {
 	t.Helper()
 	dataDir := t.TempDir()
-	store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir)
+	store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir)
 	if err != nil {
 		return nil, err
 	}
@@ -2,24 +2,18 @@ package server
 
 import (
 	"context"
-	"errors"
-	"net"
 	"os"
 	"path/filepath"
 	"strings"
 	"sync"
 	"time"
 
-	"github.com/netbirdio/netbird/dns"
-	nbgroup "github.com/netbirdio/netbird/management/server/group"
-	nbpeer "github.com/netbirdio/netbird/management/server/peer"
-	"github.com/netbirdio/netbird/management/server/posture"
-	"github.com/netbirdio/netbird/management/server/status"
-	"github.com/netbirdio/netbird/management/server/telemetry"
-	"github.com/netbirdio/netbird/route"
 	"github.com/rs/xid"
 	log "github.com/sirupsen/logrus"
 
+	nbgroup "github.com/netbirdio/netbird/management/server/group"
+	nbpeer "github.com/netbirdio/netbird/management/server/peer"
+	"github.com/netbirdio/netbird/management/server/telemetry"
 	"github.com/netbirdio/netbird/util"
 )
 
@ -42,167 +36,9 @@ type FileStore struct {
|
|||||||
mux sync.Mutex `json:"-"`
|
mux sync.Mutex `json:"-"`
|
||||||
storeFile string `json:"-"`
|
storeFile string `json:"-"`
|
||||||
|
|
||||||
// sync.Mutex indexed by resource ID
|
|
||||||
resourceLocks sync.Map `json:"-"`
|
|
||||||
globalAccountLock sync.Mutex `json:"-"`
|
|
||||||
|
|
||||||
metrics telemetry.AppMetrics `json:"-"`
|
metrics telemetry.AppMetrics `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *FileStore) ExecuteInTransaction(ctx context.Context, f func(store Store) error) error {
|
|
||||||
return f(s)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *FileStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKeyID)]
|
|
||||||
if !ok {
|
|
||||||
return status.NewSetupKeyNotFoundError()
|
|
||||||
}
|
|
||||||
|
|
||||||
account, err := s.getAccount(accountID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
account.SetupKeys[setupKeyID].UsedTimes++
|
|
||||||
|
|
||||||
return s.SaveAccount(ctx, account)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *FileStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
account, err := s.getAccount(accountID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
allGroup, err := account.GetGroupAll()
|
|
||||||
if err != nil || allGroup == nil {
|
|
||||||
return errors.New("all group not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
allGroup.Peers = append(allGroup.Peers, peerID)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *FileStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
account, err := s.getAccount(accountId)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
account.Groups[groupID].Peers = append(account.Groups[groupID].Peers, peerId)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *FileStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
account, ok := s.Accounts[peer.AccountID]
|
|
||||||
if !ok {
|
|
||||||
return status.NewAccountNotFoundError(peer.AccountID)
|
|
||||||
}
|
|
||||||
|
|
||||||
account.Peers[peer.ID] = peer
|
|
||||||
return s.SaveAccount(ctx, account)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *FileStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
account, ok := s.Accounts[accountId]
|
|
||||||
if !ok {
|
|
||||||
return status.NewAccountNotFoundError(accountId)
|
|
||||||
}
|
|
||||||
|
|
||||||
account.Network.Serial++
|
|
||||||
|
|
||||||
return s.SaveAccount(ctx, account)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *FileStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(key)]
|
|
||||||
if !ok {
|
|
||||||
return nil, status.NewSetupKeyNotFoundError()
|
|
||||||
}
|
|
||||||
|
|
||||||
account, err := s.getAccount(accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
setupKey, ok := account.SetupKeys[key]
|
|
||||||
if !ok {
|
|
||||||
return nil, status.Errorf(status.NotFound, "setup key not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
return setupKey, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *FileStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
account, err := s.getAccount(accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var takenIps []net.IP
|
|
||||||
for _, existingPeer := range account.Peers {
|
|
||||||
takenIps = append(takenIps, existingPeer.IP)
|
|
||||||
}
|
|
||||||
|
|
||||||
return takenIps, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *FileStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
account, err := s.getAccount(accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
existingLabels := []string{}
|
|
||||||
for _, peer := range account.Peers {
|
|
||||||
if peer.DNSLabel != "" {
|
|
||||||
existingLabels = append(existingLabels, peer.DNSLabel)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return existingLabels, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *FileStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
account, err := s.getAccount(accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return account.Network, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type StoredAccount struct{}
|
|
||||||
|
|
||||||
// NewFileStore restores a store from the file located in the datadir
|
// NewFileStore restores a store from the file located in the datadir
|
||||||
func NewFileStore(ctx context.Context, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) {
|
func NewFileStore(ctx context.Context, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) {
|
||||||
fs, err := restore(ctx, filepath.Join(dataDir, storeFileName))
|
fs, err := restore(ctx, filepath.Join(dataDir, storeFileName))
|
||||||
@ -213,25 +49,6 @@ func NewFileStore(ctx context.Context, dataDir string, metrics telemetry.AppMetr
|
|||||||
return fs, nil
|
return fs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewFilestoreFromSqliteStore restores a store from Sqlite and stores to Filestore json in the file located in datadir
|
|
||||||
func NewFilestoreFromSqliteStore(ctx context.Context, sqlStore *SqlStore, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) {
|
|
||||||
store, err := NewFileStore(ctx, dataDir, metrics)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = store.SaveInstallationID(ctx, sqlStore.GetInstallationID())
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, account := range sqlStore.GetAllAccounts(ctx) {
|
|
||||||
store.Accounts[account.Id] = account
|
|
||||||
}
|
|
||||||
|
|
||||||
return store, store.persist(ctx, store.storeFile)
|
|
||||||
}
|
|
||||||
|
|
||||||
// restore the state of the store from the file.
|
// restore the state of the store from the file.
|
||||||
// Creates a new empty store file if doesn't exist
|
// Creates a new empty store file if doesn't exist
|
||||||
func restore(ctx context.Context, file string) (*FileStore, error) {
|
func restore(ctx context.Context, file string) (*FileStore, error) {
|
||||||
@ -240,7 +57,6 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
|
|||||||
s := &FileStore{
|
s := &FileStore{
|
||||||
Accounts: make(map[string]*Account),
|
Accounts: make(map[string]*Account),
|
||||||
mux: sync.Mutex{},
|
mux: sync.Mutex{},
|
||||||
globalAccountLock: sync.Mutex{},
|
|
||||||
SetupKeyID2AccountID: make(map[string]string),
|
SetupKeyID2AccountID: make(map[string]string),
|
||||||
PeerKeyID2AccountID: make(map[string]string),
|
PeerKeyID2AccountID: make(map[string]string),
|
||||||
UserID2AccountID: make(map[string]string),
|
UserID2AccountID: make(map[string]string),
|
||||||
@ -416,252 +232,6 @@ func (s *FileStore) persist(ctx context.Context, file string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock
|
|
||||||
func (s *FileStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
|
|
||||||
log.WithContext(ctx).Debugf("acquiring global lock")
|
|
||||||
start := time.Now()
|
|
||||||
s.globalAccountLock.Lock()
|
|
||||||
|
|
||||||
unlock = func() {
|
|
||||||
s.globalAccountLock.Unlock()
|
|
||||||
log.WithContext(ctx).Debugf("released global lock in %v", time.Since(start))
|
|
||||||
}
|
|
||||||
|
|
||||||
took := time.Since(start)
|
|
||||||
log.WithContext(ctx).Debugf("took %v to acquire global lock", took)
|
|
||||||
if s.metrics != nil {
|
|
||||||
s.metrics.StoreMetrics().CountGlobalLockAcquisitionDuration(took)
|
|
||||||
}
|
|
||||||
|
|
||||||
return unlock
|
|
||||||
}
|
|
||||||
|
|
||||||
// AcquireWriteLockByUID acquires an ID lock for writing to a resource and returns a function that releases the lock
|
|
||||||
func (s *FileStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
|
|
||||||
log.WithContext(ctx).Debugf("acquiring lock for ID %s", uniqueID)
|
|
||||||
start := time.Now()
|
|
||||||
value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.Mutex{})
|
|
||||||
mtx := value.(*sync.Mutex)
|
|
||||||
mtx.Lock()
|
|
||||||
|
|
||||||
unlock = func() {
|
|
||||||
mtx.Unlock()
|
|
||||||
log.WithContext(ctx).Debugf("released lock for ID %s in %v", uniqueID, time.Since(start))
|
|
||||||
}
|
|
||||||
|
|
||||||
return unlock
|
|
||||||
}
|
|
||||||
|
|
||||||
// AcquireReadLockByUID acquires an ID lock for reading a resource and returns a function that releases the lock
|
|
||||||
// This method is still returns a write lock as file store can't handle read locks
|
|
||||||
func (s *FileStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
|
|
||||||
return s.AcquireWriteLockByUID(ctx, uniqueID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *FileStore) SaveAccount(ctx context.Context, account *Account) error {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
if account.Id == "" {
|
|
||||||
return status.Errorf(status.InvalidArgument, "account id should not be empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
accountCopy := account.Copy()
|
|
||||||
|
|
||||||
s.Accounts[accountCopy.Id] = accountCopy
|
|
||||||
|
|
||||||
// todo check that account.Id and keyId are not exist already
|
|
||||||
// because if keyId exists for other accounts this can be bad
|
|
||||||
for keyID := range accountCopy.SetupKeys {
|
|
||||||
s.SetupKeyID2AccountID[strings.ToUpper(keyID)] = accountCopy.Id
|
|
||||||
}
|
|
||||||
|
|
||||||
// enforce peer to account index and delete peer to route indexes for rebuild
|
|
||||||
for _, peer := range accountCopy.Peers {
|
|
||||||
s.PeerKeyID2AccountID[peer.Key] = accountCopy.Id
|
|
||||||
s.PeerID2AccountID[peer.ID] = accountCopy.Id
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, user := range accountCopy.Users {
|
|
||||||
s.UserID2AccountID[user.Id] = accountCopy.Id
|
|
||||||
for _, pat := range user.PATs {
|
|
||||||
s.TokenID2UserID[pat.ID] = user.Id
|
|
||||||
s.HashedPAT2TokenID[pat.HashedToken] = pat.ID
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if accountCopy.DomainCategory == PrivateCategory && accountCopy.IsDomainPrimaryAccount {
|
|
||||||
s.PrivateDomain2AccountID[accountCopy.Domain] = accountCopy.Id
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.persist(ctx, s.storeFile)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *FileStore) DeleteAccount(ctx context.Context, account *Account) error {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
if account.Id == "" {
|
|
||||||
return status.Errorf(status.InvalidArgument, "account id should not be empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
for keyID := range account.SetupKeys {
|
|
||||||
delete(s.SetupKeyID2AccountID, strings.ToUpper(keyID))
|
|
||||||
}
|
|
||||||
|
|
||||||
// enforce peer to account index and delete peer to route indexes for rebuild
|
|
||||||
for _, peer := range account.Peers {
|
|
||||||
delete(s.PeerKeyID2AccountID, peer.Key)
|
|
||||||
delete(s.PeerID2AccountID, peer.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, user := range account.Users {
|
|
||||||
for _, pat := range user.PATs {
|
|
||||||
delete(s.TokenID2UserID, pat.ID)
|
|
||||||
delete(s.HashedPAT2TokenID, pat.HashedToken)
|
|
||||||
}
|
|
||||||
delete(s.UserID2AccountID, user.Id)
|
|
||||||
}
|
|
||||||
|
|
||||||
if account.DomainCategory == PrivateCategory && account.IsDomainPrimaryAccount {
|
|
||||||
delete(s.PrivateDomain2AccountID, account.Domain)
|
|
||||||
}
|
|
||||||
|
|
||||||
delete(s.Accounts, account.Id)
|
|
||||||
|
|
||||||
return s.persist(ctx, s.storeFile)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteHashedPAT2TokenIDIndex removes an entry from the indexing map HashedPAT2TokenID
|
|
||||||
func (s *FileStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
delete(s.HashedPAT2TokenID, hashedToken)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteTokenID2UserIDIndex removes an entry from the indexing map TokenID2UserID
|
|
||||||
func (s *FileStore) DeleteTokenID2UserIDIndex(tokenID string) error {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
delete(s.TokenID2UserID, tokenID)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAccountByPrivateDomain returns account by private domain
|
|
||||||
func (s *FileStore) GetAccountByPrivateDomain(_ context.Context, domain string) (*Account, error) {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
accountID, ok := s.PrivateDomain2AccountID[strings.ToLower(domain)]
|
|
||||||
if !ok {
|
|
||||||
return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
|
|
||||||
}
|
|
||||||
|
|
||||||
account, err := s.getAccount(accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return account.Copy(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAccountBySetupKey returns account by setup key id
|
|
||||||
func (s *FileStore) GetAccountBySetupKey(_ context.Context, setupKey string) (*Account, error) {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
|
|
||||||
if !ok {
|
|
||||||
return nil, status.NewSetupKeyNotFoundError()
|
|
||||||
}
|
|
||||||
|
|
||||||
account, err := s.getAccount(accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return account.Copy(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetTokenIDByHashedToken returns the id of a personal access token by its hashed secret
|
|
||||||
func (s *FileStore) GetTokenIDByHashedToken(_ context.Context, token string) (string, error) {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
tokenID, ok := s.HashedPAT2TokenID[token]
|
|
||||||
if !ok {
|
|
||||||
return "", status.Errorf(status.NotFound, "tokenID not found: provided token doesn't exists")
|
|
||||||
}
|
|
||||||
|
|
||||||
return tokenID, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetUserByTokenID returns a User object a tokenID belongs to
|
|
||||||
func (s *FileStore) GetUserByTokenID(_ context.Context, tokenID string) (*User, error) {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
userID, ok := s.TokenID2UserID[tokenID]
|
|
||||||
if !ok {
|
|
||||||
return nil, status.Errorf(status.NotFound, "user not found: provided tokenID doesn't exists")
|
|
||||||
}
|
|
||||||
|
|
||||||
accountID, ok := s.UserID2AccountID[userID]
|
|
||||||
if !ok {
|
|
||||||
return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists")
|
|
||||||
}
|
|
||||||
|
|
||||||
account, err := s.getAccount(accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return account.Users[userID].Copy(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *FileStore) GetUserByUserID(_ context.Context, _ LockingStrength, userID string) (*User, error) {
|
|
||||||
accountID, ok := s.UserID2AccountID[userID]
|
|
||||||
if !ok {
|
|
||||||
return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists")
|
|
||||||
}
|
|
||||||
|
|
||||||
account, err := s.getAccount(accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
user := account.Users[userID].Copy()
|
|
||||||
pat := make([]PersonalAccessToken, 0, len(user.PATs))
|
|
||||||
for _, token := range user.PATs {
|
|
||||||
if token != nil {
|
|
||||||
pat = append(pat, *token)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
user.PATsG = pat
|
|
||||||
|
|
||||||
return user, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *FileStore) GetAccountGroups(_ context.Context, accountID string) ([]*nbgroup.Group, error) {
|
|
||||||
account, err := s.getAccount(accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
groupsSlice := make([]*nbgroup.Group, 0, len(account.Groups))
|
|
||||||
|
|
||||||
for _, group := range account.Groups {
|
|
||||||
groupsSlice = append(groupsSlice, group)
|
|
||||||
}
|
|
||||||
|
|
||||||
return groupsSlice, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAllAccounts returns all accounts
|
// GetAllAccounts returns all accounts
|
||||||
func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) {
|
func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) {
|
||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
@ -673,278 +243,6 @@ func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) {
|
|||||||
return all
|
return all
|
||||||
}
|
}
|
||||||
|
|
||||||
// getAccount returns a reference to the Account. Should not return a copy.
|
|
||||||
func (s *FileStore) getAccount(accountID string) (*Account, error) {
|
|
||||||
account, ok := s.Accounts[accountID]
|
|
||||||
if !ok {
|
|
||||||
return nil, status.NewAccountNotFoundError(accountID)
|
|
||||||
}
|
|
||||||
|
|
||||||
return account, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAccount returns an account for ID
|
|
||||||
func (s *FileStore) GetAccount(_ context.Context, accountID string) (*Account, error) {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
account, err := s.getAccount(accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return account.Copy(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAccountByUser returns a user account
|
|
||||||
func (s *FileStore) GetAccountByUser(_ context.Context, userID string) (*Account, error) {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
accountID, ok := s.UserID2AccountID[userID]
|
|
||||||
if !ok {
|
|
||||||
return nil, status.NewUserNotFoundError(userID)
|
|
||||||
}
|
|
||||||
|
|
||||||
account, err := s.getAccount(accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return account.Copy(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAccountByPeerID returns an account for a given peer ID
|
|
||||||
func (s *FileStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
accountID, ok := s.PeerID2AccountID[peerID]
|
|
||||||
if !ok {
|
|
||||||
return nil, status.Errorf(status.NotFound, "provided peer ID doesn't exists %s", peerID)
|
|
||||||
}
|
|
||||||
|
|
||||||
account, err := s.getAccount(accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// this protection is needed because when we delete a peer, we don't really remove index peerID -> accountID.
|
|
||||||
// check Account.Peers for a match
|
|
||||||
if _, ok := account.Peers[peerID]; !ok {
|
|
||||||
delete(s.PeerID2AccountID, peerID)
|
|
||||||
log.WithContext(ctx).Warnf("removed stale peerID %s to accountID %s index", peerID, accountID)
|
|
||||||
return nil, status.NewPeerNotFoundError(peerID)
|
|
||||||
}
|
|
||||||
|
|
||||||
return account.Copy(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAccountByPeerPubKey returns an account for a given peer WireGuard public key
|
|
||||||
func (s *FileStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
accountID, ok := s.PeerKeyID2AccountID[peerKey]
|
|
||||||
if !ok {
|
|
||||||
return nil, status.NewPeerNotFoundError(peerKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
account, err := s.getAccount(accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// this protection is needed because when we delete a peer, we don't really remove index peerKey -> accountID.
|
|
||||||
// check Account.Peers for a match
|
|
||||||
stale := true
|
|
||||||
for _, peer := range account.Peers {
|
|
||||||
if peer.Key == peerKey {
|
|
||||||
stale = false
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if stale {
|
|
||||||
delete(s.PeerKeyID2AccountID, peerKey)
|
|
||||||
log.WithContext(ctx).Warnf("removed stale peerKey %s to accountID %s index", peerKey, accountID)
|
|
||||||
return nil, status.NewPeerNotFoundError(peerKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
return account.Copy(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *FileStore) GetAccountIDByPeerPubKey(_ context.Context, peerKey string) (string, error) {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
accountID, ok := s.PeerKeyID2AccountID[peerKey]
|
|
||||||
if !ok {
|
|
||||||
return "", status.NewPeerNotFoundError(peerKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
return accountID, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *FileStore) GetAccountIDByUserID(userID string) (string, error) {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
accountID, ok := s.UserID2AccountID[userID]
|
|
||||||
if !ok {
|
|
||||||
return "", status.NewUserNotFoundError(userID)
|
|
||||||
}
|
|
||||||
|
|
||||||
return accountID, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *FileStore) GetAccountIDBySetupKey(_ context.Context, setupKey string) (string, error) {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
|
|
||||||
if !ok {
|
|
||||||
return "", status.NewSetupKeyNotFoundError()
|
|
||||||
}
|
|
||||||
|
|
||||||
return accountID, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, _ LockingStrength, peerKey string) (*nbpeer.Peer, error) {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
accountID, ok := s.PeerKeyID2AccountID[peerKey]
|
|
||||||
if !ok {
|
|
||||||
return nil, status.NewPeerNotFoundError(peerKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
account, err := s.getAccount(accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, peer := range account.Peers {
|
|
||||||
if peer.Key == peerKey {
|
|
||||||
return peer.Copy(), nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, status.NewPeerNotFoundError(peerKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *FileStore) GetAccountSettings(_ context.Context, _ LockingStrength, accountID string) (*Settings, error) {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
account, err := s.getAccount(accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return account.Settings.Copy(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetInstallationID returns the installation ID from the store
|
|
||||||
func (s *FileStore) GetInstallationID() string {
|
|
||||||
return s.InstallationID
|
|
||||||
}
|
|
||||||
|
|
||||||
// SaveInstallationID saves the installation ID
func (s *FileStore) SaveInstallationID(ctx context.Context, ID string) error {
    s.mux.Lock()
    defer s.mux.Unlock()

    s.InstallationID = ID

    return s.persist(ctx, s.storeFile)
}

// SavePeer saves the peer in the account
func (s *FileStore) SavePeer(_ context.Context, accountID string, peer *nbpeer.Peer) error {
    s.mux.Lock()
    defer s.mux.Unlock()

    account, err := s.getAccount(accountID)
    if err != nil {
        return err
    }

    newPeer := peer.Copy()

    account.Peers[peer.ID] = newPeer

    s.PeerKeyID2AccountID[peer.Key] = accountID
    s.PeerID2AccountID[peer.ID] = accountID

    return nil
}

// SavePeerStatus stores the PeerStatus in memory. It doesn't attempt to persist data to speed up things.
// PeerStatus will be saved eventually when some other changes occur.
func (s *FileStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
    s.mux.Lock()
    defer s.mux.Unlock()

    account, err := s.getAccount(accountID)
    if err != nil {
        return err
    }

    peer := account.Peers[peerID]
    if peer == nil {
        return status.Errorf(status.NotFound, "peer %s not found", peerID)
    }

    peer.Status = &peerStatus

    return nil
}

// SavePeerLocation stores the PeerStatus in memory. It doesn't attempt to persist data to speed up things.
// Peer.Location will be saved eventually when some other changes occur.
func (s *FileStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.Peer) error {
    s.mux.Lock()
    defer s.mux.Unlock()

    account, err := s.getAccount(accountID)
    if err != nil {
        return err
    }

    peer := account.Peers[peerWithLocation.ID]
    if peer == nil {
        return status.Errorf(status.NotFound, "peer %s not found", peerWithLocation.ID)
    }

    peer.Location = peerWithLocation.Location

    return nil
}

// SaveUserLastLogin stores the last login time for a user in memory. It doesn't attempt to persist data to speed up things.
func (s *FileStore) SaveUserLastLogin(_ context.Context, accountID, userID string, lastLogin time.Time) error {
    s.mux.Lock()
    defer s.mux.Unlock()

    account, err := s.getAccount(accountID)
    if err != nil {
        return err
    }

    peer := account.Users[userID]
    if peer == nil {
        return status.Errorf(status.NotFound, "user %s not found", userID)
    }

    peer.LastLogin = lastLogin

    return nil
}

func (s *FileStore) GetPostureCheckByChecksDefinition(_ string, _ *posture.ChecksDefinition) (*posture.Checks, error) {
    return nil, status.Errorf(status.Internal, "GetPostureCheckByChecksDefinition is not implemented")
}

// Close the FileStore persisting data to disk
func (s *FileStore) Close(ctx context.Context) error {
    s.mux.Lock()
@ -959,86 +257,3 @@ func (s *FileStore) Close(ctx context.Context) error {
func (s *FileStore) GetStoreEngine() StoreEngine {
    return FileStoreEngine
}

func (s *FileStore) SaveUsers(_ string, _ map[string]*User) error {
    return status.Errorf(status.Internal, "SaveUsers is not implemented")
}

func (s *FileStore) SaveGroups(_ string, _ map[string]*nbgroup.Group) error {
    return status.Errorf(status.Internal, "SaveGroups is not implemented")
}

func (s *FileStore) GetAccountIDByPrivateDomain(_ context.Context, _ LockingStrength, _ string) (string, error) {
    return "", status.Errorf(status.Internal, "GetAccountIDByPrivateDomain is not implemented")
}

func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, _ LockingStrength, accountID string) (string, string, error) {
    s.mux.Lock()
    defer s.mux.Unlock()

    account, err := s.getAccount(accountID)
    if err != nil {
        return "", "", err
    }

    return account.Domain, account.DomainCategory, nil
}

// AccountExists checks whether an account exists by the given ID.
func (s *FileStore) AccountExists(_ context.Context, _ LockingStrength, id string) (bool, error) {
    _, exists := s.Accounts[id]
    return exists, nil
}

func (s *FileStore) GetAccountDNSSettings(_ context.Context, _ LockingStrength, _ string) (*DNSSettings, error) {
    return nil, status.Errorf(status.Internal, "GetAccountDNSSettings is not implemented")
}

func (s *FileStore) GetGroupByID(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) {
    return nil, status.Errorf(status.Internal, "GetGroupByID is not implemented")
}

func (s *FileStore) GetGroupByName(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) {
    return nil, status.Errorf(status.Internal, "GetGroupByName is not implemented")
}

func (s *FileStore) GetAccountPolicies(_ context.Context, _ LockingStrength, _ string) ([]*Policy, error) {
    return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented")
}

func (s *FileStore) GetPolicyByID(_ context.Context, _ LockingStrength, _ string, _ string) (*Policy, error) {
    return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented")
}

func (s *FileStore) GetAccountPostureChecks(_ context.Context, _ LockingStrength, _ string) ([]*posture.Checks, error) {
    return nil, status.Errorf(status.Internal, "GetAccountPostureChecks is not implemented")
}

func (s *FileStore) GetPostureChecksByID(_ context.Context, _ LockingStrength, _ string, _ string) (*posture.Checks, error) {
    return nil, status.Errorf(status.Internal, "GetPostureChecksByID is not implemented")
}

func (s *FileStore) GetAccountRoutes(_ context.Context, _ LockingStrength, _ string) ([]*route.Route, error) {
    return nil, status.Errorf(status.Internal, "GetAccountRoutes is not implemented")
}

func (s *FileStore) GetRouteByID(_ context.Context, _ LockingStrength, _ string, _ string) (*route.Route, error) {
    return nil, status.Errorf(status.Internal, "GetRouteByID is not implemented")
}

func (s *FileStore) GetAccountSetupKeys(_ context.Context, _ LockingStrength, _ string) ([]*SetupKey, error) {
    return nil, status.Errorf(status.Internal, "GetAccountSetupKeys is not implemented")
}

func (s *FileStore) GetSetupKeyByID(_ context.Context, _ LockingStrength, _ string, _ string) (*SetupKey, error) {
    return nil, status.Errorf(status.Internal, "GetSetupKeyByID is not implemented")
}

func (s *FileStore) GetAccountNameServerGroups(_ context.Context, _ LockingStrength, _ string) ([]*dns.NameServerGroup, error) {
    return nil, status.Errorf(status.Internal, "GetAccountNameServerGroups is not implemented")
}

func (s *FileStore) GetNameServerGroupByID(_ context.Context, _ LockingStrength, _ string, _ string) (*dns.NameServerGroup, error) {
    return nil, status.Errorf(status.Internal, "GetNameServerGroupByID is not implemented")
}
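Note on the FileStore methods above: SavePeerStatus and SavePeerLocation only mutate the in-memory account and its indices; nothing touches disk until a call that persists, such as SaveAccount or SaveInstallationID. Below is a minimal usage sketch of that pattern, not taken from this diff: it assumes an already-initialized *FileStore named store, a loaded *Account named account, and placeholder IDs; the logging call is illustrative only.

    // Illustrative only: update a peer's status in memory first.
    newStatus := nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}
    if err := store.SavePeerStatus("account-id", "peer-id", newStatus); err != nil {
        // Returns a NotFound status error if the account or peer is unknown.
        log.Errorf("status update failed: %v", err)
    }
    // The change lives only in memory until a persisting call runs,
    // for example a subsequent SaveAccount.
    if err := store.SaveAccount(context.Background(), account); err != nil {
        log.Errorf("persist failed: %v", err)
    }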
@ -1,655 +0,0 @@
package server

import (
    "context"
    "crypto/sha256"
    "net"
    "path/filepath"
    "testing"
    "time"

    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"

    "github.com/netbirdio/netbird/management/server/group"
    nbpeer "github.com/netbirdio/netbird/management/server/peer"
    "github.com/netbirdio/netbird/util"
)

type accounts struct {
    Accounts map[string]*Account
}

func TestStalePeerIndices(t *testing.T) {
    storeDir := t.TempDir()

    err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json"))
    if err != nil {
        t.Fatal(err)
    }

    store, err := NewFileStore(context.Background(), storeDir, nil)
    if err != nil {
        return
    }

    account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
    require.NoError(t, err)

    peerID := "some_peer"
    peerKey := "some_peer_key"
    account.Peers[peerID] = &nbpeer.Peer{
        ID:  peerID,
        Key: peerKey,
    }

    err = store.SaveAccount(context.Background(), account)
    require.NoError(t, err)

    account.DeletePeer(peerID)

    err = store.SaveAccount(context.Background(), account)
    require.NoError(t, err)

    _, err = store.GetAccountByPeerID(context.Background(), peerID)
    require.Error(t, err, "expecting to get an error when found stale index")

    _, err = store.GetAccountByPeerPubKey(context.Background(), peerKey)
    require.Error(t, err, "expecting to get an error when found stale index")
}

func TestNewStore(t *testing.T) {
    store := newStore(t)
    defer store.Close(context.Background())

    if store.Accounts == nil || len(store.Accounts) != 0 {
        t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
    }

    if store.SetupKeyID2AccountID == nil || len(store.SetupKeyID2AccountID) != 0 {
        t.Errorf("expected to create a new empty SetupKeyID2AccountID map when creating a new FileStore")
    }

    if store.PeerKeyID2AccountID == nil || len(store.PeerKeyID2AccountID) != 0 {
        t.Errorf("expected to create a new empty PeerKeyID2AccountID map when creating a new FileStore")
    }

    if store.UserID2AccountID == nil || len(store.UserID2AccountID) != 0 {
        t.Errorf("expected to create a new empty UserID2AccountID map when creating a new FileStore")
    }

    if store.HashedPAT2TokenID == nil || len(store.HashedPAT2TokenID) != 0 {
        t.Errorf("expected to create a new empty HashedPAT2TokenID map when creating a new FileStore")
    }

    if store.TokenID2UserID == nil || len(store.TokenID2UserID) != 0 {
        t.Errorf("expected to create a new empty TokenID2UserID map when creating a new FileStore")
    }
}

func TestSaveAccount(t *testing.T) {
    store := newStore(t)
    defer store.Close(context.Background())

    account := newAccountWithId(context.Background(), "account_id", "testuser", "")
    setupKey := GenerateDefaultSetupKey()
    account.SetupKeys[setupKey.Key] = setupKey
    account.Peers["testpeer"] = &nbpeer.Peer{
        Key:      "peerkey",
        SetupKey: "peerkeysetupkey",
        IP:       net.IP{127, 0, 0, 1},
        Meta:     nbpeer.PeerSystemMeta{},
        Name:     "peer name",
        Status:   &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
    }

    // SaveAccount should trigger persist
    err := store.SaveAccount(context.Background(), account)
    if err != nil {
        return
    }

    if store.Accounts[account.Id] == nil {
        t.Errorf("expecting Account to be stored after SaveAccount()")
    }

    if store.PeerKeyID2AccountID["peerkey"] == "" {
        t.Errorf("expecting PeerKeyID2AccountID index updated after SaveAccount()")
    }

    if store.UserID2AccountID["testuser"] == "" {
        t.Errorf("expecting UserID2AccountID index updated after SaveAccount()")
    }

    if store.SetupKeyID2AccountID[setupKey.Key] == "" {
        t.Errorf("expecting SetupKeyID2AccountID index updated after SaveAccount()")
    }
}

func TestDeleteAccount(t *testing.T) {
    storeDir := t.TempDir()
    storeFile := filepath.Join(storeDir, "store.json")
    err := util.CopyFileContents("testdata/store.json", storeFile)
    if err != nil {
        t.Fatal(err)
    }

    store, err := NewFileStore(context.Background(), storeDir, nil)
    if err != nil {
        t.Fatal(err)
    }
    defer store.Close(context.Background())

    var account *Account
    for _, a := range store.Accounts {
        account = a
        break
    }

    require.NotNil(t, account, "failed to restore a FileStore file and get at least one account")

    err = store.DeleteAccount(context.Background(), account)
    require.NoError(t, err, "failed to delete account, error: %v", err)

    _, ok := store.Accounts[account.Id]
    require.False(t, ok, "failed to delete account")

    for id := range account.Users {
        _, ok := store.UserID2AccountID[id]
        assert.False(t, ok, "failed to delete UserID2AccountID index")
        for _, pat := range account.Users[id].PATs {
            _, ok := store.HashedPAT2TokenID[pat.HashedToken]
            assert.False(t, ok, "failed to delete HashedPAT2TokenID index")
            _, ok = store.TokenID2UserID[pat.ID]
            assert.False(t, ok, "failed to delete TokenID2UserID index")
        }
    }

    for _, p := range account.Peers {
        _, ok := store.PeerKeyID2AccountID[p.Key]
        assert.False(t, ok, "failed to delete PeerKeyID2AccountID index")
        _, ok = store.PeerID2AccountID[p.ID]
        assert.False(t, ok, "failed to delete PeerID2AccountID index")
    }

    for id := range account.SetupKeys {
        _, ok := store.SetupKeyID2AccountID[id]
        assert.False(t, ok, "failed to delete SetupKeyID2AccountID index")
    }

    _, ok = store.PrivateDomain2AccountID[account.Domain]
    assert.False(t, ok, "failed to delete PrivateDomain2AccountID index")
}

func TestStore(t *testing.T) {
    store := newStore(t)
    defer store.Close(context.Background())

    account := newAccountWithId(context.Background(), "account_id", "testuser", "")
    account.Peers["testpeer"] = &nbpeer.Peer{
        Key:      "peerkey",
        SetupKey: "peerkeysetupkey",
        IP:       net.IP{127, 0, 0, 1},
        Meta:     nbpeer.PeerSystemMeta{},
        Name:     "peer name",
        Status:   &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
    }
    account.Groups["all"] = &group.Group{
        ID:    "all",
        Name:  "all",
        Peers: []string{"testpeer"},
    }
    account.Policies = append(account.Policies, &Policy{
        ID:      "all",
        Name:    "all",
        Enabled: true,
        Rules: []*PolicyRule{
            {
                ID:           "all",
                Name:         "all",
                Sources:      []string{"all"},
                Destinations: []string{"all"},
            },
        },
    })
    account.Policies = append(account.Policies, &Policy{
        ID:      "dmz",
        Name:    "dmz",
        Enabled: true,
        Rules: []*PolicyRule{
            {
                ID:           "dmz",
                Name:         "dmz",
                Enabled:      true,
                Sources:      []string{"all"},
                Destinations: []string{"all"},
            },
        },
    })

    // SaveAccount should trigger persist
    err := store.SaveAccount(context.Background(), account)
    if err != nil {
        return
    }

    restored, err := NewFileStore(context.Background(), store.storeFile, nil)
    if err != nil {
        return
    }

    restoredAccount := restored.Accounts[account.Id]
    if restoredAccount == nil {
        t.Errorf("failed to restore a FileStore file - missing Account %s", account.Id)
        return
    }

    if restoredAccount.Peers["testpeer"] == nil {
        t.Errorf("failed to restore a FileStore file - missing Peer testpeer")
    }

    if restoredAccount.CreatedBy != "testuser" {
        t.Errorf("failed to restore a FileStore file - missing Account CreatedBy")
    }

    if restoredAccount.Users["testuser"] == nil {
        t.Errorf("failed to restore a FileStore file - missing User testuser")
    }

    if restoredAccount.Network == nil {
        t.Errorf("failed to restore a FileStore file - missing Network")
    }

    if restoredAccount.Groups["all"] == nil {
        t.Errorf("failed to restore a FileStore file - missing Group all")
    }

    if len(restoredAccount.Policies) != 2 {
        t.Errorf("failed to restore a FileStore file - missing Policies")
        return
    }

    assert.Equal(t, account.Policies[0], restoredAccount.Policies[0], "failed to restore a FileStore file - missing Policy all")
    assert.Equal(t, account.Policies[1], restoredAccount.Policies[1], "failed to restore a FileStore file - missing Policy dmz")
}

func TestRestore(t *testing.T) {
    storeDir := t.TempDir()

    err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json"))
    if err != nil {
        t.Fatal(err)
    }

    store, err := NewFileStore(context.Background(), storeDir, nil)
    if err != nil {
        return
    }

    account := store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"]

    require.NotNil(t, account, "failed to restore a FileStore file - missing account bf1c8084-ba50-4ce7-9439-34653001fc3b")

    require.NotNil(t, account.Users["edafee4e-63fb-11ec-90d6-0242ac120003"], "failed to restore a FileStore file - missing Account User edafee4e-63fb-11ec-90d6-0242ac120003")

    require.NotNil(t, account.Users["f4f6d672-63fb-11ec-90d6-0242ac120003"], "failed to restore a FileStore file - missing Account User f4f6d672-63fb-11ec-90d6-0242ac120003")

    require.NotNil(t, account.Network, "failed to restore a FileStore file - missing Account Network")

    require.NotNil(t, account.SetupKeys["A2C8E62B-38F5-4553-B31E-DD66C696CEBB"], "failed to restore a FileStore file - missing Account SetupKey A2C8E62B-38F5-4553-B31E-DD66C696CEBB")

    require.NotNil(t, account.Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"], "failed to restore a FileStore wrong PATs length")

    require.Len(t, store.UserID2AccountID, 2, "failed to restore a FileStore wrong UserID2AccountID mapping length")

    require.Len(t, store.SetupKeyID2AccountID, 1, "failed to restore a FileStore wrong SetupKeyID2AccountID mapping length")

    require.Len(t, store.PrivateDomain2AccountID, 1, "failed to restore a FileStore wrong PrivateDomain2AccountID mapping length")

    require.Len(t, store.HashedPAT2TokenID, 1, "failed to restore a FileStore wrong HashedPAT2TokenID mapping length")

    require.Len(t, store.TokenID2UserID, 1, "failed to restore a FileStore wrong TokenID2UserID mapping length")
}

func TestRestoreGroups_Migration(t *testing.T) {
    storeDir := t.TempDir()

    err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json"))
    if err != nil {
        t.Fatal(err)
    }

    store, err := NewFileStore(context.Background(), storeDir, nil)
    if err != nil {
        return
    }

    // create default group
    account := store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"]
    account.Groups = map[string]*group.Group{
        "cfefqs706sqkneg59g3g": {
            ID:   "cfefqs706sqkneg59g3g",
            Name: "All",
        },
    }
    err = store.SaveAccount(context.Background(), account)
    require.NoError(t, err, "failed to save account")

    // restore account with default group with empty Issue field
    if store, err = NewFileStore(context.Background(), storeDir, nil); err != nil {
        return
    }
    account = store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"]

    require.Contains(t, account.Groups, "cfefqs706sqkneg59g3g", "failed to restore a FileStore file - missing Account Groups")
    require.Equal(t, group.GroupIssuedAPI, account.Groups["cfefqs706sqkneg59g3g"].Issued, "default group should has API issued mark")
}

func TestGetAccountByPrivateDomain(t *testing.T) {
    storeDir := t.TempDir()

    err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json"))
    if err != nil {
        t.Fatal(err)
    }

    store, err := NewFileStore(context.Background(), storeDir, nil)
    if err != nil {
        return
    }

    existingDomain := "test.com"

    account, err := store.GetAccountByPrivateDomain(context.Background(), existingDomain)
    require.NoError(t, err, "should found account")
    require.Equal(t, existingDomain, account.Domain, "domains should match")

    _, err = store.GetAccountByPrivateDomain(context.Background(), "missing-domain.com")
    require.Error(t, err, "should return error on domain lookup")
}

func TestFileStore_GetAccount(t *testing.T) {
    storeDir := t.TempDir()
    storeFile := filepath.Join(storeDir, "store.json")
    err := util.CopyFileContents("testdata/store.json", storeFile)
    if err != nil {
        t.Fatal(err)
    }

    accounts := &accounts{}
    _, err = util.ReadJson(storeFile, accounts)
    if err != nil {
        t.Fatal(err)
    }

    store, err := NewFileStore(context.Background(), storeDir, nil)
    if err != nil {
        t.Fatal(err)
    }

    expected := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"]
    if expected == nil {
        t.Fatalf("expected account doesn't exist")
        return
    }

    account, err := store.GetAccount(context.Background(), expected.Id)
    if err != nil {
        t.Fatal(err)
    }

    assert.Equal(t, expected.IsDomainPrimaryAccount, account.IsDomainPrimaryAccount)
    assert.Equal(t, expected.DomainCategory, account.DomainCategory)
    assert.Equal(t, expected.Domain, account.Domain)
    assert.Equal(t, expected.CreatedBy, account.CreatedBy)
    assert.Equal(t, expected.Network.Identifier, account.Network.Identifier)
    assert.Len(t, account.Peers, len(expected.Peers))
    assert.Len(t, account.Users, len(expected.Users))
    assert.Len(t, account.SetupKeys, len(expected.SetupKeys))
    assert.Len(t, account.Routes, len(expected.Routes))
    assert.Len(t, account.NameServerGroups, len(expected.NameServerGroups))
}

func TestFileStore_GetTokenIDByHashedToken(t *testing.T) {
    storeDir := t.TempDir()
    storeFile := filepath.Join(storeDir, "store.json")
    err := util.CopyFileContents("testdata/store.json", storeFile)
    if err != nil {
        t.Fatal(err)
    }

    accounts := &accounts{}
    _, err = util.ReadJson(storeFile, accounts)
    if err != nil {
        t.Fatal(err)
    }

    store, err := NewFileStore(context.Background(), storeDir, nil)
    if err != nil {
        t.Fatal(err)
    }

    hashedToken := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].HashedToken
    tokenID, err := store.GetTokenIDByHashedToken(context.Background(), hashedToken)
    if err != nil {
        t.Fatal(err)
    }

    expectedTokenID := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].ID
    assert.Equal(t, expectedTokenID, tokenID)
}

func TestFileStore_DeleteHashedPAT2TokenIDIndex(t *testing.T) {
    store := newStore(t)
    defer store.Close(context.Background())
    store.HashedPAT2TokenID["someHashedToken"] = "someTokenId"

    err := store.DeleteHashedPAT2TokenIDIndex("someHashedToken")
    if err != nil {
        t.Fatal(err)
    }

    assert.Empty(t, store.HashedPAT2TokenID["someHashedToken"])
}

func TestFileStore_DeleteTokenID2UserIDIndex(t *testing.T) {
    store := newStore(t)
    store.TokenID2UserID["someTokenId"] = "someUserId"

    err := store.DeleteTokenID2UserIDIndex("someTokenId")
    if err != nil {
        t.Fatal(err)
    }

    assert.Empty(t, store.TokenID2UserID["someTokenId"])
}

func TestFileStore_GetTokenIDByHashedToken_Failure(t *testing.T) {
    storeDir := t.TempDir()
    storeFile := filepath.Join(storeDir, "store.json")
    err := util.CopyFileContents("testdata/store.json", storeFile)
    if err != nil {
        t.Fatal(err)
    }

    accounts := &accounts{}
    _, err = util.ReadJson(storeFile, accounts)
    if err != nil {
        t.Fatal(err)
    }

    store, err := NewFileStore(context.Background(), storeDir, nil)
    if err != nil {
        t.Fatal(err)
    }

    wrongToken := sha256.Sum256([]byte("someNotValidTokenThatFails1234"))
    _, err = store.GetTokenIDByHashedToken(context.Background(), string(wrongToken[:]))

    assert.Error(t, err, "GetTokenIDByHashedToken should throw error if token invalid")
}

func TestFileStore_GetUserByTokenID(t *testing.T) {
    storeDir := t.TempDir()
    storeFile := filepath.Join(storeDir, "store.json")
    err := util.CopyFileContents("testdata/store.json", storeFile)
    if err != nil {
        t.Fatal(err)
    }

    accounts := &accounts{}
    _, err = util.ReadJson(storeFile, accounts)
    if err != nil {
        t.Fatal(err)
    }

    store, err := NewFileStore(context.Background(), storeDir, nil)
    if err != nil {
        t.Fatal(err)
    }

    tokenID := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].ID
    user, err := store.GetUserByTokenID(context.Background(), tokenID)
    if err != nil {
        t.Fatal(err)
    }

    assert.Equal(t, "f4f6d672-63fb-11ec-90d6-0242ac120003", user.Id)
}

func TestFileStore_GetUserByTokenID_Failure(t *testing.T) {
    storeDir := t.TempDir()
    storeFile := filepath.Join(storeDir, "store.json")
    err := util.CopyFileContents("testdata/store.json", storeFile)
    if err != nil {
        t.Fatal(err)
    }

    accounts := &accounts{}
    _, err = util.ReadJson(storeFile, accounts)
    if err != nil {
        t.Fatal(err)
    }

    store, err := NewFileStore(context.Background(), storeDir, nil)
    if err != nil {
        t.Fatal(err)
    }

    wrongTokenID := "someNonExistingTokenID"
    _, err = store.GetUserByTokenID(context.Background(), wrongTokenID)

    assert.Error(t, err, "GetUserByTokenID should throw error if tokenID invalid")
}

func TestFileStore_SavePeerStatus(t *testing.T) {
    storeDir := t.TempDir()

    err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json"))
    if err != nil {
        t.Fatal(err)
    }

    store, err := NewFileStore(context.Background(), storeDir, nil)
    if err != nil {
        return
    }

    account, err := store.getAccount("bf1c8084-ba50-4ce7-9439-34653001fc3b")
    if err != nil {
        t.Fatal(err)
    }

    // save status of non-existing peer
    newStatus := nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}
    err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus)
    assert.Error(t, err)

    // save new status of existing peer
    account.Peers["testpeer"] = &nbpeer.Peer{
        Key:      "peerkey",
        ID:       "testpeer",
        SetupKey: "peerkeysetupkey",
        IP:       net.IP{127, 0, 0, 1},
        Meta:     nbpeer.PeerSystemMeta{},
        Name:     "peer name",
        Status:   &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()},
    }

    err = store.SaveAccount(context.Background(), account)
    if err != nil {
        t.Fatal(err)
    }

    err = store.SavePeerStatus(account.Id, "testpeer", newStatus)
    if err != nil {
        t.Fatal(err)
    }
    account, err = store.getAccount(account.Id)
    if err != nil {
        t.Fatal(err)
    }

    actual := account.Peers["testpeer"].Status
    assert.Equal(t, newStatus, *actual)
}

func TestFileStore_SavePeerLocation(t *testing.T) {
    storeDir := t.TempDir()

    err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json"))
    if err != nil {
        t.Fatal(err)
    }

    store, err := NewFileStore(context.Background(), storeDir, nil)
    if err != nil {
        return
    }
    account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
    require.NoError(t, err)

    peer := &nbpeer.Peer{
        AccountID: account.Id,
        ID:        "testpeer",
        Location: nbpeer.Location{
            ConnectionIP: net.ParseIP("10.0.0.0"),
            CountryCode:  "YY",
            CityName:     "City",
            GeoNameID:    1,
        },
        Meta: nbpeer.PeerSystemMeta{},
    }
    // error is expected as peer is not in store yet
    err = store.SavePeerLocation(account.Id, peer)
    assert.Error(t, err)

    account.Peers[peer.ID] = peer
    err = store.SaveAccount(context.Background(), account)
    require.NoError(t, err)

    peer.Location.ConnectionIP = net.ParseIP("35.1.1.1")
    peer.Location.CountryCode = "DE"
    peer.Location.CityName = "Berlin"
    peer.Location.GeoNameID = 2950159

    err = store.SavePeerLocation(account.Id, account.Peers[peer.ID])
    assert.NoError(t, err)

    account, err = store.GetAccount(context.Background(), account.Id)
    require.NoError(t, err)

    actual := account.Peers[peer.ID].Location
    assert.Equal(t, peer.Location, actual)
}

func newStore(t *testing.T) *FileStore {
    t.Helper()
    store, err := NewFileStore(context.Background(), t.TempDir(), nil)
    if err != nil {
        t.Errorf("failed creating a new store")
    }

    return store
}
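The tests above were built on the JSON FileStore and its testdata/store.json fixtures; this commit removes them as the test suites move to SQLite fixtures. A minimal sketch of the replacement setup pattern, with the NewSqliteTestStore signature taken from the calls later in this diff and the fixture path used only as an example:

    // Sketch: open a throwaway SQLite-backed test store from a fixture file
    // and make sure its resources are released when the test ends.
    store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
    if err != nil {
        t.Fatal(err)
    }
    defer cleanup()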
@ -6,7 +6,6 @@ import (
    "io"
    "net"
    "os"
    "path/filepath"
    "runtime"
    "sync"
    "sync/atomic"
@ -89,14 +88,7 @@ func getServerKey(client mgmtProto.ManagementServiceClient) (*wgtypes.Key, error

func Test_SyncProtocol(t *testing.T) {
    dir := t.TempDir()
    err := util.CopyFileContents("testdata/store_with_expired_peers.json", filepath.Join(dir, "store.json"))
    if err != nil {
        t.Fatal(err)
    }
    defer func() {
        os.Remove(filepath.Join(dir, "store.json")) //nolint
    }()
    mgmtServer, _, mgmtAddr, err := startManagementForTest(t, &Config{
    mgmtServer, _, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sqlite", &Config{
        Stuns: []*Host{{
            Proto: "udp",
            URI:   "stun:stun.wiretrustee.com:3468",
@ -117,6 +109,7 @@ func Test_SyncProtocol(t *testing.T) {
        Datadir:    dir,
        HttpConfig: nil,
    })
    defer cleanup()
    if err != nil {
        t.Fatal(err)
        return
@ -412,18 +405,18 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
    }
}

func startManagementForTest(t TestingT, config *Config) (*grpc.Server, *DefaultAccountManager, string, error) {
func startManagementForTest(t *testing.T, testFile string, config *Config) (*grpc.Server, *DefaultAccountManager, string, func(), error) {
    t.Helper()
    lis, err := net.Listen("tcp", "localhost:0")
    if err != nil {
        return nil, nil, "", err
        return nil, nil, "", nil, err
    }
    s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
    store, cleanUp, err := NewTestStoreFromJson(context.Background(), config.Datadir)
    store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), testFile)
    if err != nil {
        return nil, nil, "", err
        t.Fatal(err)
    }
    t.Cleanup(cleanUp)

    peersUpdateManager := NewPeersUpdateManager(nil)
    eventStore := &activity.InMemoryEventStore{}
@ -437,7 +430,8 @@ func startManagementForTest(t TestingT, config *Config) (*grpc.Server, *DefaultA
        eventStore, nil, false, MocIntegratedValidator{}, metrics)

    if err != nil {
        return nil, nil, "", err
        cleanup()
        return nil, nil, "", cleanup, err
    }

    secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
@ -445,7 +439,7 @@ func startManagementForTest(t TestingT, config *Config) (*grpc.Server, *DefaultA
    ephemeralMgr := NewEphemeralManager(store, accountManager)
    mgmtServer, err := NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, ephemeralMgr)
    if err != nil {
        return nil, nil, "", err
        return nil, nil, "", cleanup, err
    }
    mgmtProto.RegisterManagementServiceServer(s, mgmtServer)

@ -455,7 +449,7 @@ func startManagementForTest(t TestingT, config *Config) (*grpc.Server, *DefaultA
        }
    }()

    return s, accountManager, lis.Addr().String(), nil
    return s, accountManager, lis.Addr().String(), cleanup, nil
}

func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.ClientConn, error) {
@ -475,6 +469,7 @@ func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.Clie

    return mgmtProto.NewManagementServiceClient(conn), conn, nil
}

func Test_SyncStatusRace(t *testing.T) {
    if os.Getenv("CI") == "true" && os.Getenv("NETBIRD_STORE_ENGINE") == "postgres" {
        t.Skip("Skipping on CI and Postgres store")
@ -488,15 +483,8 @@ func Test_SyncStatusRace(t *testing.T) {
func testSyncStatusRace(t *testing.T) {
    t.Helper()
    dir := t.TempDir()
    err := util.CopyFileContents("testdata/store_with_expired_peers.json", filepath.Join(dir, "store.json"))
    if err != nil {
        t.Fatal(err)
    }
    defer func() {
        os.Remove(filepath.Join(dir, "store.json")) //nolint
    }()

    mgmtServer, am, mgmtAddr, err := startManagementForTest(t, &Config{
    mgmtServer, am, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sqlite", &Config{
        Stuns: []*Host{{
            Proto: "udp",
            URI:   "stun:stun.wiretrustee.com:3468",
@ -517,6 +505,7 @@ func testSyncStatusRace(t *testing.T) {
        Datadir:    dir,
        HttpConfig: nil,
    })
    defer cleanup()
    if err != nil {
        t.Fatal(err)
        return
@ -665,15 +654,8 @@ func Test_LoginPerformance(t *testing.T) {
        t.Run(bc.name, func(t *testing.T) {
            t.Helper()
            dir := t.TempDir()
            err := util.CopyFileContents("testdata/store_with_expired_peers.json", filepath.Join(dir, "store.json"))
            if err != nil {
                t.Fatal(err)
            }
            defer func() {
                os.Remove(filepath.Join(dir, "store.json")) //nolint
            }()

            mgmtServer, am, _, err := startManagementForTest(t, &Config{
            mgmtServer, am, _, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sqlite", &Config{
                Stuns: []*Host{{
                    Proto: "udp",
                    URI:   "stun:stun.wiretrustee.com:3468",
@ -694,6 +676,7 @@ func Test_LoginPerformance(t *testing.T) {
                Datadir:    dir,
                HttpConfig: nil,
            })
            defer cleanup()
            if err != nil {
                t.Fatal(err)
                return
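The refactored startManagementForTest helper above now takes the SQLite fixture path and returns a cleanup function in addition to the server, account manager, and address. A hedged sketch of a caller, assuming a prepared *Config named config and using the same fixture path as the tests in this diff:

    // Sketch: start a management server for a test against a SQLite fixture.
    mgmtServer, am, addr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sqlite", config)
    if err != nil {
        t.Fatal(err)
    }
    // Release the test store and server resources when the test finishes.
    defer cleanup()
    _ = mgmtServer
    _ = am
    _ = addr

One design note: checking err before deferring cleanup avoids calling a nil cleanup function on the early error path that returns nil for it.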
@ -5,7 +5,6 @@ import (
    "math/rand"
    "net"
    "os"
    "path/filepath"
    "runtime"
    sync2 "sync"
    "time"
@ -52,8 +51,6 @@ var _ = Describe("Management service", func() {
        dataDir, err = os.MkdirTemp("", "wiretrustee_mgmt_test_tmp_*")
        Expect(err).NotTo(HaveOccurred())

        err = util.CopyFileContents("testdata/store.json", filepath.Join(dataDir, "store.json"))
        Expect(err).NotTo(HaveOccurred())
        var listener net.Listener

        config := &server.Config{}
@ -61,7 +58,7 @@ var _ = Describe("Management service", func() {
        Expect(err).NotTo(HaveOccurred())
        config.Datadir = dataDir

        s, listener = startServer(config)
        s, listener = startServer(config, dataDir, "testdata/store.sqlite")
        addr = listener.Addr().String()
        client, conn = createRawClient(addr)

@ -530,12 +527,12 @@ func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.Clie
    return mgmtProto.NewManagementServiceClient(conn), conn
}

func startServer(config *server.Config) (*grpc.Server, net.Listener) {
func startServer(config *server.Config, dataDir string, testFile string) (*grpc.Server, net.Listener) {
    lis, err := net.Listen("tcp", ":0")
    Expect(err).NotTo(HaveOccurred())
    s := grpc.NewServer()

    store, _, err := server.NewTestStoreFromJson(context.Background(), config.Datadir)
    store, _, err := server.NewTestStoreFromSqlite(context.Background(), testFile, dataDir)
    if err != nil {
        log.Fatalf("failed creating a store: %s: %v", config.Datadir, err)
    }
@ -27,7 +27,8 @@ type MockAccountManager struct {
    CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType,
        expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
    GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error)
    GetAccountIDByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (string, error)
    AccountExistsFunc func(ctx context.Context, accountID string) (bool, error)
    GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error)
    GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
    ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
    GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
@ -57,7 +58,7 @@ type MockAccountManager struct {
    MarkPATUsedFunc func(ctx context.Context, pat string) error
    UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
    UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
    CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups,accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
    CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
    GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
    SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error
    DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error
@ -193,14 +194,22 @@ func (am *MockAccountManager) CreateSetupKey(
    return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented")
}

// GetAccountIDByUserOrAccountID mock implementation of GetAccountIDByUserOrAccountID from server.AccountManager interface
func (am *MockAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userId, accountId, domain string) (string, error) {
    if am.GetAccountIDByUserOrAccountIdFunc != nil {
        return am.GetAccountIDByUserOrAccountIdFunc(ctx, userId, accountId, domain)
// AccountExists mock implementation of AccountExists from server.AccountManager interface
func (am *MockAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) {
    if am.AccountExistsFunc != nil {
        return am.AccountExistsFunc(ctx, accountID)
    }
    return false, status.Errorf(codes.Unimplemented, "method AccountExists is not implemented")
}

// GetAccountIDByUserID mock implementation of GetAccountIDByUserID from server.AccountManager interface
func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userId, domain string) (string, error) {
    if am.GetAccountIDByUserIdFunc != nil {
        return am.GetAccountIDByUserIdFunc(ctx, userId, domain)
    }
    return "", status.Errorf(
        codes.Unimplemented,
        "method GetAccountIDByUserOrAccountID is not implemented",
        "method GetAccountIDByUserID is not implemented",
    )
}

@ -435,7 +444,7 @@ func (am *MockAccountManager) UpdatePeer(ctx context.Context, accountID, userID
// CreateRoute mock implementation of CreateRoute from server.AccountManager interface
func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupID []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
    if am.CreateRouteFunc != nil {
        return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups,accessControlGroupID, enabled, userID, keepRoute)
        return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, accessControlGroupID, enabled, userID, keepRoute)
    }
    return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented")
}
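The mock above replaces the combined GetAccountIDByUserOrAccountID hook with two narrower ones, AccountExistsFunc and GetAccountIDByUserIdFunc. A hedged sketch of wiring them in a test, assuming the mock package is imported as mock_server and using a placeholder account ID:

    // Illustrative wiring of the new mock hooks; IDs are hypothetical.
    am := &mock_server.MockAccountManager{
        AccountExistsFunc: func(ctx context.Context, accountID string) (bool, error) {
            return accountID == "existing-account", nil
        },
        GetAccountIDByUserIdFunc: func(ctx context.Context, userId, domain string) (string, error) {
            return "existing-account", nil
        },
    }
    exists, err := am.AccountExists(context.Background(), "existing-account")
    // exists == true, err == nil; unset hooks return codes.Unimplemented.
    _ = exists
    _ = err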
@ -775,7 +775,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
func createNSStore(t *testing.T) (Store, error) {
    t.Helper()
    dataDir := t.TempDir()
    store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir)
    store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir)
    if err != nil {
        return nil, err
    }
@ -707,6 +707,11 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
    updateRemotePeers := false

    if login.UserID != "" {
        if peer.UserID != login.UserID {
            log.Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, login.UserID)
            return nil, nil, nil, status.Errorf(status.Unauthenticated, "invalid user")
        }

        changed, err := am.handleUserPeer(ctx, peer, settings)
        if err != nil {
            return nil, nil, nil, err
@ -1004,7 +1004,11 @@ func Test_RegisterPeerByUser(t *testing.T) {
        t.Skip("The SQLite store is not properly supported by Windows yet")
    }

    store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
    store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
    if err != nil {
        t.Fatal(err)
    }
    defer cleanup()

    eventStore := &activity.InMemoryEventStore{}

@ -1065,7 +1069,11 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
        t.Skip("The SQLite store is not properly supported by Windows yet")
    }

    store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
    store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
    if err != nil {
        t.Fatal(err)
    }
    defer cleanup()

    eventStore := &activity.InMemoryEventStore{}

@ -1127,7 +1135,11 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
        t.Skip("The SQLite store is not properly supported by Windows yet")
    }

    store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
    store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
    if err != nil {
        t.Fatal(err)
    }
    defer cleanup()

    eventStore := &activity.InMemoryEventStore{}
@ -1258,7 +1258,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
func createRouterStore(t *testing.T) (Store, error) {
    t.Helper()
    dataDir := t.TempDir()
    store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir)
    store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir)
    if err != nil {
        return nil, err
    }
@ -1738,7 +1738,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
    }
    assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules)

    //peerD is also the routing peer for route1, should contain same routes firewall rules as peerA
    // peerD is also the routing peer for route1, should contain same routes firewall rules as peerA
    routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers)
    assert.Len(t, routesFirewallRules, 2)
    assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules)
@ -10,6 +10,7 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@ -63,8 +64,14 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine StoreEngine, metr
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
conns := runtime.NumCPU()
|
|
||||||
sql.SetMaxOpenConns(conns) // TODO: make it configurable
|
conns, err := strconv.Atoi(os.Getenv("NB_SQL_MAX_OPEN_CONNS"))
|
||||||
|
if err != nil {
|
||||||
|
conns = runtime.NumCPU()
|
||||||
|
}
|
||||||
|
sql.SetMaxOpenConns(conns)
|
||||||
|
|
||||||
|
log.Infof("Set max open db connections to %d", conns)
|
||||||
|
|
||||||
if err := migrate(ctx, db); err != nil {
|
if err := migrate(ctx, db); err != nil {
|
||||||
return nil, fmt.Errorf("migrate: %w", err)
|
return nil, fmt.Errorf("migrate: %w", err)
|
||||||
@ -378,15 +385,26 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error {
|
|||||||
Create(&usersToSave).Error
|
Create(&usersToSave).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveGroups saves the given list of groups to the database.
|
// SaveUser saves the given user to the database.
|
||||||
// It updates existing groups if a conflict occurs.
|
func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error {
|
||||||
func (s *SqlStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error {
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user)
|
||||||
groupsToSave := make([]nbgroup.Group, 0, len(groups))
|
if result.Error != nil {
|
||||||
for _, group := range groups {
|
return status.Errorf(status.Internal, "failed to save user to store: %v", result.Error)
|
||||||
group.AccountID = accountID
|
|
||||||
groupsToSave = append(groupsToSave, *group)
|
|
||||||
}
|
}
|
||||||
return s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&groupsToSave).Error
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveGroups saves the given list of groups to the database.
|
||||||
|
func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error {
|
||||||
|
if len(groups) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups)
|
||||||
|
if result.Error != nil {
|
||||||
|
return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteHashedPAT2TokenIDIndex is noop in SqlStore
|
// DeleteHashedPAT2TokenIDIndex is noop in SqlStore
|
||||||
@@ -420,7 +438,7 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength
             return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
         }
         log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error)
-        return "", status.Errorf(status.Internal, "issue getting account from store")
+        return "", status.NewGetAccountFromStoreError(result.Error)
     }
 
     return accountID, nil
@@ -433,7 +451,7 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*
         if errors.Is(result.Error, gorm.ErrRecordNotFound) {
             return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
         }
-        return nil, status.NewSetupKeyNotFoundError()
+        return nil, status.NewSetupKeyNotFoundError(result.Error)
     }
 
     if key.AccountID == "" {
@@ -451,7 +469,7 @@ func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken stri
             return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
         }
         log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error)
-        return "", status.Errorf(status.Internal, "issue getting account from store")
+        return "", status.NewGetAccountFromStoreError(result.Error)
     }
 
     return token.ID, nil
@@ -465,7 +483,7 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
             return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
         }
         log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error)
-        return nil, status.Errorf(status.Internal, "issue getting account from store")
+        return nil, status.NewGetAccountFromStoreError(result.Error)
     }
 
     if token.UserID == "" {
@@ -549,7 +567,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
         if errors.Is(result.Error, gorm.ErrRecordNotFound) {
             return nil, status.NewAccountNotFoundError(accountID)
         }
-        return nil, status.Errorf(status.Internal, "issue getting account from store")
+        return nil, status.NewGetAccountFromStoreError(result.Error)
     }
 
     // we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us
@@ -612,7 +630,7 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Accoun
         if errors.Is(result.Error, gorm.ErrRecordNotFound) {
             return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
         }
-        return nil, status.Errorf(status.Internal, "issue getting account from store")
+        return nil, status.NewGetAccountFromStoreError(result.Error)
     }
 
     if user.AccountID == "" {
@@ -629,7 +647,7 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acco
         if errors.Is(result.Error, gorm.ErrRecordNotFound) {
             return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
         }
-        return nil, status.Errorf(status.Internal, "issue getting account from store")
+        return nil, status.NewGetAccountFromStoreError(result.Error)
     }
 
     if peer.AccountID == "" {
@@ -647,7 +665,7 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (
         if errors.Is(result.Error, gorm.ErrRecordNotFound) {
             return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
         }
-        return nil, status.Errorf(status.Internal, "issue getting account from store")
+        return nil, status.NewGetAccountFromStoreError(result.Error)
     }
 
     if peer.AccountID == "" {
@@ -665,7 +683,7 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
         if errors.Is(result.Error, gorm.ErrRecordNotFound) {
             return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
         }
-        return "", status.Errorf(status.Internal, "issue getting account from store")
+        return "", status.NewGetAccountFromStoreError(result.Error)
     }
 
     return accountID, nil
@@ -678,7 +696,7 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
         if errors.Is(result.Error, gorm.ErrRecordNotFound) {
             return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
         }
-        return "", status.Errorf(status.Internal, "issue getting account from store")
+        return "", status.NewGetAccountFromStoreError(result.Error)
     }
 
     return accountID, nil
@@ -691,7 +709,7 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string)
         if errors.Is(result.Error, gorm.ErrRecordNotFound) {
             return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
         }
-        return "", status.NewSetupKeyNotFoundError()
+        return "", status.NewSetupKeyNotFoundError(result.Error)
     }
 
     if accountID == "" {
@@ -712,7 +730,7 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength
         if errors.Is(result.Error, gorm.ErrRecordNotFound) {
             return nil, status.Errorf(status.NotFound, "no peers found for the account")
         }
-        return nil, status.Errorf(status.Internal, "issue getting IPs from store")
+        return nil, status.Errorf(status.Internal, "issue getting IPs from store: %s", result.Error)
     }
 
     // Convert the JSON strings to net.IP objects
@@ -740,7 +758,7 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
             return nil, status.Errorf(status.NotFound, "no peers found for the account")
         }
         log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error)
-        return nil, status.Errorf(status.Internal, "issue getting dns labels from store")
+        return nil, status.Errorf(status.Internal, "issue getting dns labels from store: %s", result.Error)
     }
 
     return labels, nil
@@ -753,7 +771,7 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt
         if errors.Is(err, gorm.ErrRecordNotFound) {
             return nil, status.NewAccountNotFoundError(accountID)
         }
-        return nil, status.Errorf(status.Internal, "issue getting network from store")
+        return nil, status.Errorf(status.Internal, "issue getting network from store: %s", err)
     }
     return accountNetwork.Network, nil
 }
@@ -765,7 +783,7 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking
         if errors.Is(result.Error, gorm.ErrRecordNotFound) {
             return nil, status.Errorf(status.NotFound, "peer not found")
         }
-        return nil, status.Errorf(status.Internal, "issue getting peer from store")
+        return nil, status.Errorf(status.Internal, "issue getting peer from store: %s", result.Error)
     }
 
     return &peer, nil
@@ -777,7 +795,7 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS
         if errors.Is(err, gorm.ErrRecordNotFound) {
             return nil, status.Errorf(status.NotFound, "settings not found")
         }
-        return nil, status.Errorf(status.Internal, "issue getting settings from store")
+        return nil, status.Errorf(status.Internal, "issue getting settings from store: %s", err)
     }
     return accountSettings.Settings, nil
 }
@@ -915,6 +933,28 @@ func NewPostgresqlStoreFromFileStore(ctx context.Context, fileStore *FileStore,
     return store, nil
 }
 
+// NewPostgresqlStoreFromSqlStore restores a store from SqlStore and stores Postgres DB.
+func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) {
+    store, err := NewPostgresqlStore(ctx, dsn, metrics)
+    if err != nil {
+        return nil, err
+    }
+
+    err = store.SaveInstallationID(ctx, sqliteStore.GetInstallationID())
+    if err != nil {
+        return nil, err
+    }
+
+    for _, account := range sqliteStore.GetAllAccounts(ctx) {
+        err := store.SaveAccount(ctx, account)
+        if err != nil {
+            return nil, err
+        }
+    }
+
+    return store, nil
+}
+
 func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) {
     var setupKey SetupKey
     result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
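A hedged sketch of how the new helper might be driven when moving an existing SQLite store into Postgres; NewSqliteStore, Close and NewPostgresqlStoreFromSqlStore come from this diff, while the wrapper and its error handling are illustrative only:

// migrateSqliteToPostgres is an illustrative wrapper, not part of this change.
func migrateSqliteToPostgres(ctx context.Context, dataDir, dsn string) (*SqlStore, error) {
	// Open the existing SQLite-backed store located in dataDir.
	sqliteStore, err := NewSqliteStore(ctx, dataDir, nil)
	if err != nil {
		return nil, err
	}
	defer sqliteStore.Close(ctx)

	// Copy the installation ID and all accounts into a fresh Postgres store.
	return NewPostgresqlStoreFromSqlStore(ctx, sqliteStore, dsn, nil)
}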
@@ -923,7 +963,7 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking
         if errors.Is(result.Error, gorm.ErrRecordNotFound) {
             return nil, status.Errorf(status.NotFound, "setup key not found")
         }
-        return nil, status.NewSetupKeyNotFoundError()
+        return nil, status.NewSetupKeyNotFoundError(result.Error)
     }
     return &setupKey, nil
 }
@@ -955,7 +995,7 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
         if errors.Is(result.Error, gorm.ErrRecordNotFound) {
             return status.Errorf(status.NotFound, "group 'All' not found for account")
         }
-        return status.Errorf(status.Internal, "issue finding group 'All'")
+        return status.Errorf(status.Internal, "issue finding group 'All': %s", result.Error)
     }
 
     for _, existingPeerID := range group.Peers {
@@ -967,7 +1007,7 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
     group.Peers = append(group.Peers, peerID)
 
     if err := s.db.Save(&group).Error; err != nil {
-        return status.Errorf(status.Internal, "issue updating group 'All'")
+        return status.Errorf(status.Internal, "issue updating group 'All': %s", err)
     }
 
     return nil
@@ -981,7 +1021,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId
         if errors.Is(result.Error, gorm.ErrRecordNotFound) {
             return status.Errorf(status.NotFound, "group not found for account")
         }
-        return status.Errorf(status.Internal, "issue finding group")
+        return status.Errorf(status.Internal, "issue finding group: %s", result.Error)
     }
 
     for _, existingPeerID := range group.Peers {
@@ -993,15 +1033,20 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId
     group.Peers = append(group.Peers, peerId)
 
     if err := s.db.Save(&group).Error; err != nil {
-        return status.Errorf(status.Internal, "issue updating group")
+        return status.Errorf(status.Internal, "issue updating group: %s", err)
     }
 
     return nil
 }
 
+// GetUserPeers retrieves peers for a user.
+func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) {
+    return getRecords[*nbpeer.Peer](s.db.WithContext(ctx).Where("user_id = ?", userID), lockStrength, accountID)
+}
+
 func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
     if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
-        return status.Errorf(status.Internal, "issue adding peer to account")
+        return status.Errorf(status.Internal, "issue adding peer to account: %s", err)
     }
 
     return nil
@@ -1010,7 +1055,7 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) erro
 func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
     result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
     if result.Error != nil {
-        return status.Errorf(status.Internal, "issue incrementing network serial count")
+        return status.Errorf(status.Internal, "issue incrementing network serial count: %s", result.Error)
     }
     return nil
 }
@@ -1105,6 +1150,15 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
     return &group, nil
 }
 
+// SaveGroup saves a group to the store.
+func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error {
+    result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group)
+    if result.Error != nil {
+        return status.Errorf(status.Internal, "failed to save group to store: %v", result.Error)
+    }
+    return nil
+}
+
 // GetAccountPolicies retrieves policies for an account.
 func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
     return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID)
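The new GetUserPeers read helper can be combined with LockingStrengthShare for read-only queries; a hypothetical consumer is sketched below (the Name field on the peer is assumed from the peer model used elsewhere in this diff):

// peerNamesForUser is a hypothetical read-only helper built on GetUserPeers.
func peerNamesForUser(ctx context.Context, store Store, accountID, userID string) ([]string, error) {
	peers, err := store.GetUserPeers(ctx, LockingStrengthShare, accountID, userID)
	if err != nil {
		return nil, err
	}
	names := make([]string, 0, len(peers))
	for _, peer := range peers {
		names = append(names, peer.Name) // Name is assumed to exist on nbpeer.Peer
	}
	return names, nil
}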
@@ -7,7 +7,6 @@ import (
     "net"
     "net/netip"
     "os"
-    "path/filepath"
     "runtime"
     "testing"
     "time"
@@ -25,7 +24,6 @@ import (
     "github.com/netbirdio/netbird/management/server/status"
 
     nbpeer "github.com/netbirdio/netbird/management/server/peer"
-    "github.com/netbirdio/netbird/util"
 )
 
 func TestSqlite_NewStore(t *testing.T) {
@@ -347,7 +345,11 @@ func TestSqlite_GetAccount(t *testing.T) {
         t.Skip("The SQLite store is not properly supported by Windows yet")
     }
 
-    store := newSqliteStoreFromFile(t, "testdata/store.json")
+    store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite")
+    if err != nil {
+        t.Fatal(err)
+    }
+    defer cleanup()
 
     id := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
 
@@ -367,7 +369,11 @@ func TestSqlite_SavePeer(t *testing.T) {
         t.Skip("The SQLite store is not properly supported by Windows yet")
     }
 
-    store := newSqliteStoreFromFile(t, "testdata/store.json")
+    store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite")
+    if err != nil {
+        t.Fatal(err)
+    }
+    defer cleanup()
 
     account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
     require.NoError(t, err)
@@ -415,7 +421,11 @@ func TestSqlite_SavePeerStatus(t *testing.T) {
         t.Skip("The SQLite store is not properly supported by Windows yet")
     }
 
-    store := newSqliteStoreFromFile(t, "testdata/store.json")
+    store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite")
+    defer cleanup()
+    if err != nil {
+        t.Fatal(err)
+    }
 
     account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
     require.NoError(t, err)
@@ -468,8 +478,11 @@ func TestSqlite_SavePeerLocation(t *testing.T) {
         t.Skip("The SQLite store is not properly supported by Windows yet")
     }
 
-    store := newSqliteStoreFromFile(t, "testdata/store.json")
+    store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite")
+    defer cleanup()
+    if err != nil {
+        t.Fatal(err)
+    }
     account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
     require.NoError(t, err)
 
@@ -519,8 +532,11 @@ func TestSqlite_TestGetAccountByPrivateDomain(t *testing.T) {
         t.Skip("The SQLite store is not properly supported by Windows yet")
     }
 
-    store := newSqliteStoreFromFile(t, "testdata/store.json")
+    store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite")
+    defer cleanup()
+    if err != nil {
+        t.Fatal(err)
+    }
     existingDomain := "test.com"
 
     account, err := store.GetAccountByPrivateDomain(context.Background(), existingDomain)
@@ -539,8 +555,11 @@ func TestSqlite_GetTokenIDByHashedToken(t *testing.T) {
         t.Skip("The SQLite store is not properly supported by Windows yet")
     }
 
-    store := newSqliteStoreFromFile(t, "testdata/store.json")
+    store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite")
+    defer cleanup()
+    if err != nil {
+        t.Fatal(err)
+    }
     hashed := "SoMeHaShEdToKeN"
     id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
 
@@ -560,8 +579,11 @@ func TestSqlite_GetUserByTokenID(t *testing.T) {
         t.Skip("The SQLite store is not properly supported by Windows yet")
     }
 
-    store := newSqliteStoreFromFile(t, "testdata/store.json")
+    store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite")
+    defer cleanup()
+    if err != nil {
+        t.Fatal(err)
+    }
     id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
 
     user, err := store.GetUserByTokenID(context.Background(), id)
@@ -668,24 +690,9 @@ func newSqliteStore(t *testing.T) *SqlStore {
     t.Helper()
 
     store, err := NewSqliteStore(context.Background(), t.TempDir(), nil)
-    require.NoError(t, err)
-    require.NotNil(t, store)
-
-    return store
-}
-
-func newSqliteStoreFromFile(t *testing.T, filename string) *SqlStore {
-    t.Helper()
-
-    storeDir := t.TempDir()
-
-    err := util.CopyFileContents(filename, filepath.Join(storeDir, "store.json"))
-    require.NoError(t, err)
-
-    fStore, err := NewFileStore(context.Background(), storeDir, nil)
-    require.NoError(t, err)
-
-    store, err := NewSqliteStoreFromFileStore(context.Background(), fStore, storeDir, nil)
+    t.Cleanup(func() {
+        store.Close(context.Background())
+    })
     require.NoError(t, err)
     require.NotNil(t, store)
 
@@ -733,32 +740,31 @@ func newPostgresqlStore(t *testing.T) *SqlStore {
     return store
 }
 
-func newPostgresqlStoreFromFile(t *testing.T, filename string) *SqlStore {
+func newPostgresqlStoreFromSqlite(t *testing.T, filename string) *SqlStore {
     t.Helper()
 
-    storeDir := t.TempDir()
-    err := util.CopyFileContents(filename, filepath.Join(storeDir, "store.json"))
-    require.NoError(t, err)
+    store, cleanUpQ, err := NewSqliteTestStore(context.Background(), t.TempDir(), filename)
+    t.Cleanup(cleanUpQ)
+    if err != nil {
+        return nil
+    }
 
-    fStore, err := NewFileStore(context.Background(), storeDir, nil)
-    require.NoError(t, err)
-
-    cleanUp, err := testutil.CreatePGDB()
+    cleanUpP, err := testutil.CreatePGDB()
     if err != nil {
         t.Fatal(err)
     }
-    t.Cleanup(cleanUp)
+    t.Cleanup(cleanUpP)
 
     postgresDsn, ok := os.LookupEnv(postgresDsnEnv)
     if !ok {
         t.Fatalf("could not initialize postgresql store: %s is not set", postgresDsnEnv)
     }
 
-    store, err := NewPostgresqlStoreFromFileStore(context.Background(), fStore, postgresDsn, nil)
+    pstore, err := NewPostgresqlStoreFromSqlStore(context.Background(), store, postgresDsn, nil)
     require.NoError(t, err)
     require.NotNil(t, store)
 
-    return store
+    return pstore
 }
 
 func TestPostgresql_NewStore(t *testing.T) {
@@ -924,7 +930,7 @@ func TestPostgresql_SavePeerStatus(t *testing.T) {
         t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
     }
 
-    store := newPostgresqlStoreFromFile(t, "testdata/store.json")
+    store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite")
 
     account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
     require.NoError(t, err)
@@ -963,7 +969,7 @@ func TestPostgresql_TestGetAccountByPrivateDomain(t *testing.T) {
         t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
     }
 
-    store := newPostgresqlStoreFromFile(t, "testdata/store.json")
+    store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite")
 
     existingDomain := "test.com"
 
@@ -980,7 +986,7 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) {
         t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
     }
 
-    store := newPostgresqlStoreFromFile(t, "testdata/store.json")
+    store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite")
 
     hashed := "SoMeHaShEdToKeN"
     id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
@@ -995,7 +1001,7 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) {
         t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS)
     }
 
-    store := newPostgresqlStoreFromFile(t, "testdata/store.json")
+    store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite")
 
     id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
 
@@ -1009,12 +1015,15 @@ func TestSqlite_GetTakenIPs(t *testing.T) {
         t.Skip("The SQLite store is not properly supported by Windows yet")
     }
 
-    store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
-    defer store.Close(context.Background())
+    store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
+    defer cleanup()
+    if err != nil {
+        t.Fatal(err)
+    }
 
     existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
 
-    _, err := store.GetAccount(context.Background(), existingAccountID)
+    _, err = store.GetAccount(context.Background(), existingAccountID)
     require.NoError(t, err)
 
     takenIPs, err := store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID)
@@ -1054,12 +1063,15 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
         t.Skip("The SQLite store is not properly supported by Windows yet")
     }
 
-    store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
-    defer store.Close(context.Background())
+    store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
+    if err != nil {
+        return
+    }
+    t.Cleanup(cleanup)
 
     existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
 
-    _, err := store.GetAccount(context.Background(), existingAccountID)
+    _, err = store.GetAccount(context.Background(), existingAccountID)
     require.NoError(t, err)
 
     labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
@@ -1096,12 +1108,15 @@ func TestSqlite_GetAccountNetwork(t *testing.T) {
         t.Skip("The SQLite store is not properly supported by Windows yet")
     }
 
-    store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
-    defer store.Close(context.Background())
+    store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
+    t.Cleanup(cleanup)
+    if err != nil {
+        t.Fatal(err)
+    }
 
     existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
 
-    _, err := store.GetAccount(context.Background(), existingAccountID)
+    _, err = store.GetAccount(context.Background(), existingAccountID)
     require.NoError(t, err)
 
     network, err := store.GetAccountNetwork(context.Background(), LockingStrengthShare, existingAccountID)
@@ -1118,12 +1133,15 @@ func TestSqlite_GetSetupKeyBySecret(t *testing.T) {
     if runtime.GOOS == "windows" {
         t.Skip("The SQLite store is not properly supported by Windows yet")
     }
-    store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
-    defer store.Close(context.Background())
+    store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
+    t.Cleanup(cleanup)
+    if err != nil {
+        t.Fatal(err)
+    }
 
     existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
 
-    _, err := store.GetAccount(context.Background(), existingAccountID)
+    _, err = store.GetAccount(context.Background(), existingAccountID)
     require.NoError(t, err)
 
     setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
@@ -1137,12 +1155,16 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
     if runtime.GOOS == "windows" {
         t.Skip("The SQLite store is not properly supported by Windows yet")
    }
-    store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
-    defer store.Close(context.Background())
+
+    store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
+    t.Cleanup(cleanup)
+    if err != nil {
+        t.Fatal(err)
+    }
 
     existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
 
-    _, err := store.GetAccount(context.Background(), existingAccountID)
+    _, err = store.GetAccount(context.Background(), existingAccountID)
     require.NoError(t, err)
 
     setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
@@ -1163,3 +1185,33 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
     require.NoError(t, err)
     assert.Equal(t, 2, setupKey.UsedTimes)
 }
+
+func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
+    store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
+    t.Cleanup(cleanup)
+    if err != nil {
+        t.Fatal(err)
+    }
+    group := &nbgroup.Group{
+        ID:        "group-id",
+        AccountID: "account-id",
+        Name:      "group-name",
+        Issued:    "api",
+        Peers:     nil,
+    }
+    err = store.ExecuteInTransaction(context.Background(), func(transaction Store) error {
+        err := transaction.SaveGroup(context.Background(), LockingStrengthUpdate, group)
+        if err != nil {
+            t.Fatal("failed to save group")
+            return err
+        }
+        group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.ID, group.AccountID)
+        if err != nil {
+            t.Fatal("failed to get group")
+            return err
+        }
+        t.Logf("group: %v", group)
+        return nil
+    })
+    assert.NoError(t, err)
+}
@@ -102,8 +102,12 @@ func NewPeerLoginExpiredError() error {
 }
 
 // NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key
-func NewSetupKeyNotFoundError() error {
-    return Errorf(NotFound, "setup key not found")
+func NewSetupKeyNotFoundError(err error) error {
+    return Errorf(NotFound, "setup key not found: %s", err)
+}
+
+func NewGetAccountFromStoreError(err error) error {
+    return Errorf(Internal, "issue getting account from store: %s", err)
 }
 
 // NewGetUserFromStoreError creates a new Error with Internal type for an issue getting user from store
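The wrapping pattern these constructors support in the store code is roughly the following; the helper is an illustrative sketch, not part of the change:

// storeLookupError shows the error-wrapping convention this change moves to:
// a missing row stays a NotFound error, every other database error is wrapped
// so the underlying cause is preserved in the message.
func storeLookupError(err error) error {
	if errors.Is(err, gorm.ErrRecordNotFound) {
		return status.Errorf(status.NotFound, "account not found: index lookup failed")
	}
	return status.NewGetAccountFromStoreError(err)
}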
@@ -12,10 +12,11 @@ import (
     "strings"
     "time"
 
-    "github.com/netbirdio/netbird/dns"
     log "github.com/sirupsen/logrus"
     "gorm.io/gorm"
 
+    "github.com/netbirdio/netbird/dns"
+
     nbgroup "github.com/netbirdio/netbird/management/server/group"
 
     "github.com/netbirdio/netbird/management/server/telemetry"
@@ -59,6 +60,7 @@ type Store interface {
     GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
     GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
     SaveUsers(accountID string, users map[string]*User) error
+    SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
     SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
     GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
     DeleteHashedPAT2TokenIDIndex(hashedToken string) error
@@ -67,7 +69,8 @@ type Store interface {
     GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
     GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error)
     GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
-    SaveGroups(accountID string, groups map[string]*nbgroup.Group) error
+    SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error
+    SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error
 
     GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
     GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
@@ -81,6 +84,7 @@ type Store interface {
     AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
     AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
     GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
+    GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
     SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
     SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
     SavePeerLocation(accountID string, peer *nbpeer.Peer) error
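Since SqlStore gains matching methods in this diff, it is expected to keep satisfying the extended interface; an illustrative compile-time assertion (not present in the change) would be:

// Compile-time check that *SqlStore still implements Store after the
// SaveUser, SaveGroup/SaveGroups and GetUserPeers additions (illustrative).
var _ Store = (*SqlStore)(nil)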
@@ -236,23 +240,29 @@ func getMigrations(ctx context.Context) []migrationFunc {
     }
 }
 
-// NewTestStoreFromJson is only used in tests
-func NewTestStoreFromJson(ctx context.Context, dataDir string) (Store, func(), error) {
-    fstore, err := NewFileStore(ctx, dataDir, nil)
-    if err != nil {
-        return nil, nil, err
-    }
-
+// NewTestStoreFromSqlite is only used in tests
+func NewTestStoreFromSqlite(ctx context.Context, filename string, dataDir string) (Store, func(), error) {
     // if store engine is not set in the config we first try to evaluate NETBIRD_STORE_ENGINE
     kind := getStoreEngineFromEnv()
     if kind == "" {
         kind = SqliteStoreEngine
     }
 
-    var (
-        store   Store
-        cleanUp func()
-    )
+    var store *SqlStore
+    var err error
+    var cleanUp func()
+
+    if filename == "" {
+        store, err = NewSqliteStore(ctx, dataDir, nil)
+        cleanUp = func() {
+            store.Close(ctx)
+        }
+    } else {
+        store, cleanUp, err = NewSqliteTestStore(ctx, dataDir, filename)
+    }
+    if err != nil {
+        return nil, nil, err
+    }
 
     if kind == PostgresStoreEngine {
         cleanUp, err = testutil.CreatePGDB()
@@ -265,21 +275,32 @@ func NewTestStoreFromJson(ctx context.Context, dataDir string) (Store, func(), e
             return nil, nil, fmt.Errorf("%s is not set", postgresDsnEnv)
         }
 
-        store, err = NewPostgresqlStoreFromFileStore(ctx, fstore, dsn, nil)
+        store, err = NewPostgresqlStoreFromSqlStore(ctx, store, dsn, nil)
         if err != nil {
             return nil, nil, err
         }
-    } else {
-        store, err = NewSqliteStoreFromFileStore(ctx, fstore, dataDir, nil)
-        if err != nil {
-            return nil, nil, err
-        }
-        cleanUp = func() { store.Close(ctx) }
     }
 
     return store, cleanUp, nil
 }
 
+func NewSqliteTestStore(ctx context.Context, dataDir string, testFile string) (*SqlStore, func(), error) {
+    err := util.CopyFileContents(testFile, filepath.Join(dataDir, "store.db"))
+    if err != nil {
+        return nil, nil, err
+    }
+
+    store, err := NewSqliteStore(ctx, dataDir, nil)
+    if err != nil {
+        return nil, nil, err
+    }
+
+    return store, func() {
+        store.Close(ctx)
+        os.Remove(filepath.Join(dataDir, "store.db"))
+    }, nil
+}
+
 // MigrateFileStoreToSqlite migrates the file store to the SQLite store.
 func MigrateFileStoreToSqlite(ctx context.Context, dataDir string) error {
     fileStorePath := path.Join(dataDir, storeFileName)
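A hedged sketch of a test wired through the renamed constructor; the fixture path mirrors the .sqlite files added in this diff, and the test body (including the GetAccount call on the Store interface) is illustrative only:

// TestExample_StoreFromSqliteFixture is illustrative and not part of the change.
func TestExample_StoreFromSqliteFixture(t *testing.T) {
	store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "testdata/store.sqlite", t.TempDir())
	if err != nil {
		t.Fatal(err)
	}
	t.Cleanup(cleanUp)

	// The fixture ships with this account ID; GetAccount is assumed to be part
	// of the Store interface as used by the other tests in this diff.
	if _, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b"); err != nil {
		t.Fatal(err)
	}
}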
@@ -14,12 +14,6 @@ type benchCase struct {
     size int
 }
 
-var newFs = func(b *testing.B) Store {
-    b.Helper()
-    store, _ := NewFileStore(context.Background(), b.TempDir(), nil)
-    return store
-}
-
 var newSqlite = func(b *testing.B) Store {
     b.Helper()
     store, _ := NewSqliteStore(context.Background(), b.TempDir(), nil)
@@ -28,13 +22,9 @@ var newSqlite = func(b *testing.B) Store {
 
 func BenchmarkTest_StoreWrite(b *testing.B) {
     cases := []benchCase{
-        {name: "FileStore_Write", storeFn: newFs, size: 100},
         {name: "SqliteStore_Write", storeFn: newSqlite, size: 100},
-        {name: "FileStore_Write", storeFn: newFs, size: 500},
         {name: "SqliteStore_Write", storeFn: newSqlite, size: 500},
-        {name: "FileStore_Write", storeFn: newFs, size: 1000},
         {name: "SqliteStore_Write", storeFn: newSqlite, size: 1000},
-        {name: "FileStore_Write", storeFn: newFs, size: 2000},
         {name: "SqliteStore_Write", storeFn: newSqlite, size: 2000},
     }
 
@@ -61,11 +51,8 @@ func BenchmarkTest_StoreWrite(b *testing.B) {
 
 func BenchmarkTest_StoreRead(b *testing.B) {
     cases := []benchCase{
-        {name: "FileStore_Read", storeFn: newFs, size: 100},
         {name: "SqliteStore_Read", storeFn: newSqlite, size: 100},
-        {name: "FileStore_Read", storeFn: newFs, size: 500},
         {name: "SqliteStore_Read", storeFn: newSqlite, size: 500},
-        {name: "FileStore_Read", storeFn: newFs, size: 1000},
         {name: "SqliteStore_Read", storeFn: newSqlite, size: 1000},
     }
 
@@ -89,3 +76,11 @@ func BenchmarkTest_StoreRead(b *testing.B) {
         })
     }
 }
+
+func newStore(t *testing.T) Store {
+    t.Helper()
+
+    store := newSqliteStore(t)
+
+    return store
+}
management/server/testdata/extended-store.json (vendored, 120 lines deleted)
@@ -1,120 +0,0 @@
{
  "Accounts": {
    "bf1c8084-ba50-4ce7-9439-34653001fc3b": {
      "Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b",
      "CreatedBy": "",
      "Domain": "test.com",
      "DomainCategory": "private",
      "IsDomainPrimaryAccount": true,
      "SetupKeys": {
        "A2C8E62B-38F5-4553-B31E-DD66C696CEBB": {
          "Id": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB",
          "AccountID": "",
          "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB",
          "Name": "Default key",
          "Type": "reusable",
          "CreatedAt": "2021-08-19T20:46:20.005936822+02:00",
          "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00",
          "UpdatedAt": "0001-01-01T00:00:00Z",
          "Revoked": false,
          "UsedTimes": 0,
          "LastUsed": "0001-01-01T00:00:00Z",
          "AutoGroups": ["cfefqs706sqkneg59g2g"],
          "UsageLimit": 0,
          "Ephemeral": false
        },
        "A2C8E62B-38F5-4553-B31E-DD66C696CEBC": {
          "Id": "A2C8E62B-38F5-4553-B31E-DD66C696CEBC",
          "AccountID": "",
          "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBC",
          "Name": "Faulty key with non existing group",
          "Type": "reusable",
          "CreatedAt": "2021-08-19T20:46:20.005936822+02:00",
          "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00",
          "UpdatedAt": "0001-01-01T00:00:00Z",
          "Revoked": false,
          "UsedTimes": 0,
          "LastUsed": "0001-01-01T00:00:00Z",
          "AutoGroups": ["abcd"],
          "UsageLimit": 0,
          "Ephemeral": false
        }
      },
      "Network": {
        "id": "af1c8024-ha40-4ce2-9418-34653101fc3c",
        "Net": {
          "IP": "100.64.0.0",
          "Mask": "//8AAA=="
        },
        "Dns": "",
        "Serial": 0
      },
      "Peers": {},
      "Users": {
        "edafee4e-63fb-11ec-90d6-0242ac120003": {
          "Id": "edafee4e-63fb-11ec-90d6-0242ac120003",
          "AccountID": "",
          "Role": "admin",
          "IsServiceUser": false,
          "ServiceUserName": "",
          "AutoGroups": ["cfefqs706sqkneg59g3g"],
          "PATs": {},
          "Blocked": false,
          "LastLogin": "0001-01-01T00:00:00Z"
        },
        "f4f6d672-63fb-11ec-90d6-0242ac120003": {
          "Id": "f4f6d672-63fb-11ec-90d6-0242ac120003",
          "AccountID": "",
          "Role": "user",
          "IsServiceUser": false,
          "ServiceUserName": "",
          "AutoGroups": null,
          "PATs": {
            "9dj38s35-63fb-11ec-90d6-0242ac120003": {
              "ID": "9dj38s35-63fb-11ec-90d6-0242ac120003",
              "UserID": "",
              "Name": "",
              "HashedToken": "SoMeHaShEdToKeN",
              "ExpirationDate": "2023-02-27T00:00:00Z",
              "CreatedBy": "user",
              "CreatedAt": "2023-01-01T00:00:00Z",
              "LastUsed": "2023-02-01T00:00:00Z"
            }
          },
          "Blocked": false,
          "LastLogin": "0001-01-01T00:00:00Z"
        }
      },
      "Groups": {
        "cfefqs706sqkneg59g4g": {
          "ID": "cfefqs706sqkneg59g4g",
          "Name": "All",
          "Peers": []
        },
        "cfefqs706sqkneg59g3g": {
          "ID": "cfefqs706sqkneg59g3g",
          "Name": "AwesomeGroup1",
          "Peers": []
        },
        "cfefqs706sqkneg59g2g": {
          "ID": "cfefqs706sqkneg59g2g",
          "Name": "AwesomeGroup2",
          "Peers": []
        }
      },
      "Rules": null,
      "Policies": [],
      "Routes": null,
      "NameServerGroups": null,
      "DNSSettings": null,
      "Settings": {
        "PeerLoginExpirationEnabled": false,
        "PeerLoginExpiration": 86400000000000,
        "GroupsPropagationEnabled": false,
        "JWTGroupsEnabled": false,
        "JWTGroupsClaimName": ""
      }
    }
  },
  "InstallationID": ""
}

management/server/testdata/extended-store.sqlite (vendored, new binary file, not shown)
management/server/testdata/store.json (vendored, 120 lines deleted)
@@ -1,120 +0,0 @@
{
  "Accounts": {
    "bf1c8084-ba50-4ce7-9439-34653001fc3b": {
      "Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b",
      "CreatedBy": "",
      "Domain": "test.com",
      "DomainCategory": "private",
      "IsDomainPrimaryAccount": true,
      "SetupKeys": {
        "A2C8E62B-38F5-4553-B31E-DD66C696CEBB": {
          "Id": "",
          "AccountID": "",
          "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB",
          "Name": "Default key",
          "Type": "reusable",
          "CreatedAt": "2021-08-19T20:46:20.005936822+02:00",
          "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00",
          "UpdatedAt": "0001-01-01T00:00:00Z",
          "Revoked": false,
          "UsedTimes": 0,
          "LastUsed": "0001-01-01T00:00:00Z",
          "AutoGroups": ["cq9bbkjjuspi5gd38epg"],
          "UsageLimit": 0,
          "Ephemeral": false
        }
      },
      "Network": {
        "id": "af1c8024-ha40-4ce2-9418-34653101fc3c",
        "Net": {
          "IP": "100.64.0.0",
          "Mask": "//8AAA=="
        },
        "Dns": "",
        "Serial": 0
      },
      "Peers": {},
      "Users": {
        "edafee4e-63fb-11ec-90d6-0242ac120003": {
          "Id": "edafee4e-63fb-11ec-90d6-0242ac120003",
          "AccountID": "",
          "Role": "admin",
          "IsServiceUser": false,
          "ServiceUserName": "",
          "AutoGroups": null,
          "PATs": {},
          "Blocked": false,
          "LastLogin": "0001-01-01T00:00:00Z"
        },
        "f4f6d672-63fb-11ec-90d6-0242ac120003": {
          "Id": "f4f6d672-63fb-11ec-90d6-0242ac120003",
          "AccountID": "",
          "Role": "user",
          "IsServiceUser": false,
          "ServiceUserName": "",
          "AutoGroups": null,
          "PATs": {
            "9dj38s35-63fb-11ec-90d6-0242ac120003": {
              "ID": "9dj38s35-63fb-11ec-90d6-0242ac120003",
              "UserID": "",
              "Name": "",
              "HashedToken": "SoMeHaShEdToKeN",
              "ExpirationDate": "2023-02-27T00:00:00Z",
              "CreatedBy": "user",
              "CreatedAt": "2023-01-01T00:00:00Z",
              "LastUsed": "2023-02-01T00:00:00Z"
            }
          },
          "Blocked": false,
          "LastLogin": "0001-01-01T00:00:00Z"
        }
      },
      "Groups": {
        "cq9bbkjjuspi5gd38epg": {
          "ID": "cq9bbkjjuspi5gd38epg",
          "Name": "All",
          "Peers": []
        }
      },
      "Rules": null,
      "Policies": [
        {
          "ID": "cq9bbkjjuspi5gd38eq0",
          "Name": "Default",
          "Description": "This is a default rule that allows connections between all the resources",
          "Enabled": true,
          "Rules": [
            {
              "ID": "cq9bbkjjuspi5gd38eq0",
              "Name": "Default",
              "Description": "This is a default rule that allows connections between all the resources",
              "Enabled": true,
              "Action": "accept",
              "Destinations": [
                "cq9bbkjjuspi5gd38epg"
              ],
              "Sources": [
                "cq9bbkjjuspi5gd38epg"
              ],
              "Bidirectional": true,
              "Protocol": "all",
              "Ports": null
            }
          ],
          "SourcePostureChecks": null
        }
      ],
      "Routes": null,
      "NameServerGroups": null,
      "DNSSettings": null,
      "Settings": {
        "PeerLoginExpirationEnabled": false,
        "PeerLoginExpiration": 86400000000000,
        "GroupsPropagationEnabled": false,
        "JWTGroupsEnabled": false,
        "JWTGroupsClaimName": ""
      }
    }
  },
  "InstallationID": ""
}

management/server/testdata/store.sqlite (vendored, new binary file, not shown)
116
management/server/testdata/store_policy_migrate.json
vendored
116
management/server/testdata/store_policy_migrate.json
vendored
@ -1,116 +0,0 @@
|
|||||||
{
  "Accounts": {
    "bf1c8084-ba50-4ce7-9439-34653001fc3b": {
      "Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b",
      "Domain": "test.com",
      "DomainCategory": "private",
      "IsDomainPrimaryAccount": true,
      "SetupKeys": {
        "A2C8E62B-38F5-4553-B31E-DD66C696CEBB": {
          "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB",
          "Name": "Default key",
          "Type": "reusable",
          "CreatedAt": "2021-08-19T20:46:20.005936822+02:00",
          "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00",
          "Revoked": false,
          "UsedTimes": 0
        }
      },
      "Network": {
        "Id": "af1c8024-ha40-4ce2-9418-34653101fc3c",
        "Net": { "IP": "100.64.0.0", "Mask": "//8AAA==" },
        "Dns": null
      },
      "Peers": {
        "cfefqs706sqkneg59g4g": {
          "ID": "cfefqs706sqkneg59g4g",
          "Key": "MI5mHfJhbggPfD3FqEIsXm8X5bSWeUI2LhO9MpEEtWA=",
          "SetupKey": "",
          "IP": "100.103.179.238",
          "Meta": { "Hostname": "Ubuntu-2204-jammy-amd64-base", "GoOS": "linux", "Kernel": "Linux", "Core": "22.04", "Platform": "x86_64", "OS": "Ubuntu", "WtVersion": "development", "UIVersion": "" },
          "Name": "crocodile",
          "DNSLabel": "crocodile",
          "Status": { "LastSeen": "2023-02-13T12:37:12.635454796Z", "Connected": true },
          "UserID": "edafee4e-63fb-11ec-90d6-0242ac120003",
          "SSHKey": "AAAAC3NzaC1lZDI1NTE5AAAAIJN1NM4bpB9K",
          "SSHEnabled": false
        },
        "cfeg6sf06sqkneg59g50": {
          "ID": "cfeg6sf06sqkneg59g50",
          "Key": "zMAOKUeIYIuun4n0xPR1b3IdYZPmsyjYmB2jWCuloC4=",
          "SetupKey": "",
          "IP": "100.103.26.180",
          "Meta": { "Hostname": "borg", "GoOS": "linux", "Kernel": "Linux", "Core": "22.04", "Platform": "x86_64", "OS": "Ubuntu", "WtVersion": "development", "UIVersion": "" },
          "Name": "dingo",
          "DNSLabel": "dingo",
          "Status": { "LastSeen": "2023-02-21T09:37:42.565899199Z", "Connected": false },
          "UserID": "f4f6d672-63fb-11ec-90d6-0242ac120003",
          "SSHKey": "AAAAC3NzaC1lZDI1NTE5AAAAILHW",
          "SSHEnabled": true
        }
      },
      "Groups": {
        "cfefqs706sqkneg59g3g": { "ID": "cfefqs706sqkneg59g3g", "Name": "All", "Peers": [ "cfefqs706sqkneg59g4g", "cfeg6sf06sqkneg59g50" ] }
      },
      "Rules": {
        "cfefqs706sqkneg59g40": {
          "ID": "cfefqs706sqkneg59g40",
          "Name": "Default",
          "Description": "This is a default rule that allows connections between all the resources",
          "Disabled": false,
          "Source": [ "cfefqs706sqkneg59g3g" ],
          "Destination": [ "cfefqs706sqkneg59g3g" ],
          "Flow": 0
        }
      },
      "Users": {
        "edafee4e-63fb-11ec-90d6-0242ac120003": { "Id": "edafee4e-63fb-11ec-90d6-0242ac120003", "Role": "admin" },
        "f4f6d672-63fb-11ec-90d6-0242ac120003": { "Id": "f4f6d672-63fb-11ec-90d6-0242ac120003", "Role": "user" }
      }
    }
  }
}
BIN management/server/testdata/store_policy_migrate.sqlite vendored Normal file
Binary file not shown.
@@ -1,130 +0,0 @@
{
  "Accounts": {
    "bf1c8084-ba50-4ce7-9439-34653001fc3b": {
      "Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b",
      "Domain": "test.com",
      "DomainCategory": "private",
      "IsDomainPrimaryAccount": true,
      "Settings": { "PeerLoginExpirationEnabled": true, "PeerLoginExpiration": 3600000000000 },
      "SetupKeys": {
        "A2C8E62B-38F5-4553-B31E-DD66C696CEBB": {
          "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB",
          "Name": "Default key",
          "Type": "reusable",
          "CreatedAt": "2021-08-19T20:46:20.005936822+02:00",
          "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00",
          "Revoked": false,
          "UsedTimes": 0
        }
      },
      "Network": {
        "Id": "af1c8024-ha40-4ce2-9418-34653101fc3c",
        "Net": { "IP": "100.64.0.0", "Mask": "//8AAA==" },
        "Dns": null
      },
      "Peers": {
        "cfvprsrlo1hqoo49ohog": {
          "ID": "cfvprsrlo1hqoo49ohog",
          "Key": "5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=",
          "SetupKey": "72546A29-6BC8-4311-BCFC-9CDBF33F1A48",
          "IP": "100.64.114.31",
          "Meta": { "Hostname": "f2a34f6a4731", "GoOS": "linux", "Kernel": "Linux", "Core": "11", "Platform": "unknown", "OS": "Debian GNU/Linux", "WtVersion": "0.12.0", "UIVersion": "" },
          "Name": "f2a34f6a4731",
          "DNSLabel": "f2a34f6a4731",
          "Status": { "LastSeen": "2023-03-02T09:21:02.189035775+01:00", "Connected": false, "LoginExpired": false },
          "UserID": "",
          "SSHKey": "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk",
          "SSHEnabled": false,
          "LoginExpirationEnabled": true,
          "LastLogin": "2023-03-01T19:48:19.817799698+01:00"
        },
        "cg05lnblo1hkg2j514p0": {
          "ID": "cg05lnblo1hkg2j514p0",
          "Key": "RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=",
          "SetupKey": "",
          "IP": "100.64.39.54",
          "Meta": { "Hostname": "expiredhost", "GoOS": "linux", "Kernel": "Linux", "Core": "22.04", "Platform": "x86_64", "OS": "Ubuntu", "WtVersion": "development", "UIVersion": "" },
          "Name": "expiredhost",
          "DNSLabel": "expiredhost",
          "Status": { "LastSeen": "2023-03-02T09:19:57.276717255+01:00", "Connected": false, "LoginExpired": true },
          "UserID": "edafee4e-63fb-11ec-90d6-0242ac120003",
          "SSHKey": "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK",
          "SSHEnabled": false,
          "LoginExpirationEnabled": true,
          "LastLogin": "2023-03-02T09:14:21.791679181+01:00"
        },
        "cg3161rlo1hs9cq94gdg": {
          "ID": "cg3161rlo1hs9cq94gdg",
          "Key": "mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=",
          "SetupKey": "",
          "IP": "100.64.117.96",
          "Meta": { "Hostname": "testhost", "GoOS": "linux", "Kernel": "Linux", "Core": "22.04", "Platform": "x86_64", "OS": "Ubuntu", "WtVersion": "development", "UIVersion": "" },
          "Name": "testhost",
          "DNSLabel": "testhost",
          "Status": { "LastSeen": "2023-03-06T18:21:27.252010027+01:00", "Connected": false, "LoginExpired": false },
          "UserID": "edafee4e-63fb-11ec-90d6-0242ac120003",
          "SSHKey": "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM",
          "SSHEnabled": false,
          "LoginExpirationEnabled": false,
          "LastLogin": "2023-03-07T09:02:47.442857106+01:00"
        }
      },
      "Users": {
        "edafee4e-63fb-11ec-90d6-0242ac120003": { "Id": "edafee4e-63fb-11ec-90d6-0242ac120003", "Role": "admin" },
        "f4f6d672-63fb-11ec-90d6-0242ac120003": { "Id": "f4f6d672-63fb-11ec-90d6-0242ac120003", "Role": "user" }
      }
    }
  }
}
BIN management/server/testdata/store_with_expired_peers.sqlite vendored Normal file
Binary file not shown.
154 management/server/testdata/storev1.json vendored
@@ -1,154 +0,0 @@
{
  "Accounts": {
    "auth0|61bf82ddeab084006aa1bccd": {
      "Id": "auth0|61bf82ddeab084006aa1bccd",
      "SetupKeys": {
        "1B2B50B0-B3E8-4B0C-A426-525EDB8481BD": {
          "Id": "831727121",
          "Key": "1B2B50B0-B3E8-4B0C-A426-525EDB8481BD",
          "Name": "One-off key",
          "Type": "one-off",
          "CreatedAt": "2021-12-24T16:09:45.926075752+01:00",
          "ExpiresAt": "2022-01-23T16:09:45.926075752+01:00",
          "Revoked": false,
          "UsedTimes": 1,
          "LastUsed": "2021-12-24T16:12:45.763424077+01:00"
        },
        "EB51E9EB-A11F-4F6E-8E49-C982891B405A": {
          "Id": "1769568301",
          "Key": "EB51E9EB-A11F-4F6E-8E49-C982891B405A",
          "Name": "Default key",
          "Type": "reusable",
          "CreatedAt": "2021-12-24T16:09:45.926073628+01:00",
          "ExpiresAt": "2022-01-23T16:09:45.926073628+01:00",
          "Revoked": false,
          "UsedTimes": 1,
          "LastUsed": "2021-12-24T16:13:06.236748538+01:00"
        }
      },
      "Network": {
        "Id": "a443c07a-5765-4a78-97fc-390d9c1d0e49",
        "Net": { "IP": "100.64.0.0", "Mask": "/8AAAA==" },
        "Dns": ""
      },
      "Peers": {
        "oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=": {
          "Key": "oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=",
          "SetupKey": "EB51E9EB-A11F-4F6E-8E49-C982891B405A",
          "IP": "100.64.0.2",
          "Meta": { "Hostname": "braginini", "GoOS": "linux", "Kernel": "Linux", "Core": "21.04", "Platform": "x86_64", "OS": "Ubuntu", "WtVersion": "" },
          "Name": "braginini",
          "Status": { "LastSeen": "2021-12-24T16:13:11.244342541+01:00", "Connected": false }
        },
        "xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=": {
          "Key": "xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=",
          "SetupKey": "1B2B50B0-B3E8-4B0C-A426-525EDB8481BD",
          "IP": "100.64.0.1",
          "Meta": { "Hostname": "braginini", "GoOS": "linux", "Kernel": "Linux", "Core": "21.04", "Platform": "x86_64", "OS": "Ubuntu", "WtVersion": "" },
          "Name": "braginini",
          "Status": { "LastSeen": "2021-12-24T16:12:49.089339333+01:00", "Connected": false }
        }
      }
    },
    "google-oauth2|103201118415301331038": {
      "Id": "google-oauth2|103201118415301331038",
      "SetupKeys": {
        "5AFB60DB-61F2-4251-8E11-494847EE88E9": {
          "Id": "2485964613",
          "Key": "5AFB60DB-61F2-4251-8E11-494847EE88E9",
          "Name": "Default key",
          "Type": "reusable",
          "CreatedAt": "2021-12-24T16:10:02.238476+01:00",
          "ExpiresAt": "2022-01-23T16:10:02.238476+01:00",
          "Revoked": false,
          "UsedTimes": 1,
          "LastUsed": "2021-12-24T16:12:05.994307717+01:00"
        },
        "A72E4DC2-00DE-4542-8A24-62945438104E": {
          "Id": "3504804807",
          "Key": "A72E4DC2-00DE-4542-8A24-62945438104E",
          "Name": "One-off key",
          "Type": "one-off",
          "CreatedAt": "2021-12-24T16:10:02.238478209+01:00",
          "ExpiresAt": "2022-01-23T16:10:02.238478209+01:00",
          "Revoked": false,
          "UsedTimes": 1,
          "LastUsed": "2021-12-24T16:11:27.015741738+01:00"
        }
      },
      "Network": {
        "Id": "b6d0b152-364e-40c1-a8a1-fa7bcac2267f",
        "Net": { "IP": "100.64.0.0", "Mask": "/8AAAA==" },
        "Dns": ""
      },
      "Peers": {
        "6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=": {
          "Key": "6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=",
          "SetupKey": "5AFB60DB-61F2-4251-8E11-494847EE88E9",
          "IP": "100.64.0.2",
          "Meta": { "Hostname": "braginini", "GoOS": "linux", "Kernel": "Linux", "Core": "21.04", "Platform": "x86_64", "OS": "Ubuntu", "WtVersion": "" },
          "Name": "braginini",
          "Status": { "LastSeen": "2021-12-24T16:12:05.994305438+01:00", "Connected": false }
        },
        "Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=": {
          "Key": "Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=",
          "SetupKey": "A72E4DC2-00DE-4542-8A24-62945438104E",
          "IP": "100.64.0.1",
          "Meta": { "Hostname": "braginini", "GoOS": "linux", "Kernel": "Linux", "Core": "21.04", "Platform": "x86_64", "OS": "Ubuntu", "WtVersion": "" },
          "Name": "braginini",
          "Status": { "LastSeen": "2021-12-24T16:11:27.015739803+01:00", "Connected": false }
        }
      }
    }
  }
}
BIN management/server/testdata/storev1.sqlite vendored Normal file
Binary file not shown.
@@ -9,14 +9,14 @@ import (
 	"time"
 
 	"github.com/google/uuid"
-	log "github.com/sirupsen/logrus"
 
 	"github.com/netbirdio/netbird/management/server/activity"
+	nbgroup "github.com/netbirdio/netbird/management/server/group"
 	"github.com/netbirdio/netbird/management/server/idp"
 	"github.com/netbirdio/netbird/management/server/integration_reference"
 	"github.com/netbirdio/netbird/management/server/jwtclaims"
 	nbpeer "github.com/netbirdio/netbird/management/server/peer"
 	"github.com/netbirdio/netbird/management/server/status"
+	log "github.com/sirupsen/logrus"
 )
 
 const (
@@ -1274,6 +1274,74 @@ func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, accoun
 	return map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}, hadPeers, nil
 }
 
+// updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them.
+func (am *DefaultAccountManager) updateUserPeersInGroups(accountGroups map[string]*nbgroup.Group, peers []*nbpeer.Peer, groupsToAdd,
+	groupsToRemove []string) (groupsToUpdate []*nbgroup.Group, err error) {
+
+	if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 {
+		return
+	}
+
+	userPeerIDMap := make(map[string]struct{}, len(peers))
+	for _, peer := range peers {
+		userPeerIDMap[peer.ID] = struct{}{}
+	}
+
+	for _, gid := range groupsToAdd {
+		group, ok := accountGroups[gid]
+		if !ok {
+			return nil, errors.New("group not found")
+		}
+		addUserPeersToGroup(userPeerIDMap, group)
+		groupsToUpdate = append(groupsToUpdate, group)
+	}
+
+	for _, gid := range groupsToRemove {
+		group, ok := accountGroups[gid]
+		if !ok {
+			return nil, errors.New("group not found")
+		}
+		removeUserPeersFromGroup(userPeerIDMap, group)
+		groupsToUpdate = append(groupsToUpdate, group)
+	}
+
+	return groupsToUpdate, nil
+}
+
+// addUserPeersToGroup adds the user's peers to the group.
+func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) {
+	groupPeers := make(map[string]struct{}, len(group.Peers))
+	for _, pid := range group.Peers {
+		groupPeers[pid] = struct{}{}
+	}
+
+	for pid := range userPeerIDs {
+		groupPeers[pid] = struct{}{}
+	}
+
+	group.Peers = make([]string, 0, len(groupPeers))
+	for pid := range groupPeers {
+		group.Peers = append(group.Peers, pid)
+	}
+}
+
+// removeUserPeersFromGroup removes user's peers from the group.
+func removeUserPeersFromGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) {
+	// skip removing peers from group All
+	if group.Name == "All" {
+		return
+	}
+
+	updatedPeers := make([]string, 0, len(group.Peers))
+	for _, pid := range group.Peers {
+		if _, found := userPeerIDs[pid]; !found {
+			updatedPeers = append(updatedPeers, pid)
+		}
+	}
+
+	group.Peers = updatedPeers
+}
+
 func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) {
 	for _, user := range userData {
 		if user.ID == userID {
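The helpers added above treat a group's peer list as a set: adding a user's peers is a set union, removing them is a set difference, and the built-in "All" group is never shrunk. A minimal, self-contained sketch of the same technique, using a simplified stand-in Group type rather than the real *nbgroup.Group (all names here are illustrative only):

```go
package main

import "fmt"

// Group is a simplified stand-in for the management server's group type,
// used only to illustrate the set-based membership update shown in the diff.
type Group struct {
	Name  string
	Peers []string
}

// addPeers merges the user's peer IDs into the group, de-duplicating via a set.
func addPeers(userPeerIDs map[string]struct{}, g *Group) {
	set := make(map[string]struct{}, len(g.Peers))
	for _, pid := range g.Peers {
		set[pid] = struct{}{}
	}
	for pid := range userPeerIDs {
		set[pid] = struct{}{}
	}
	g.Peers = g.Peers[:0]
	for pid := range set {
		g.Peers = append(g.Peers, pid)
	}
}

// removePeers drops the user's peer IDs from the group, leaving the "All" group untouched.
func removePeers(userPeerIDs map[string]struct{}, g *Group) {
	if g.Name == "All" {
		return
	}
	kept := g.Peers[:0]
	for _, pid := range g.Peers {
		if _, isUserPeer := userPeerIDs[pid]; !isUserPeer {
			kept = append(kept, pid)
		}
	}
	g.Peers = kept
}

func main() {
	userPeers := map[string]struct{}{"peer-1": {}, "peer-2": {}}
	g := &Group{Name: "developers", Peers: []string{"peer-2", "peer-3"}}

	addPeers(userPeers, g)
	fmt.Println(g.Peers) // peer-1, peer-2, peer-3 (map order is not deterministic)

	removePeers(userPeers, g)
	fmt.Println(g.Peers) // [peer-3]
}
```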
@@ -62,8 +62,10 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
 	assert.Equal(t, pat.CreatedBy, mockUserID)
 
-	fileStore := am.Store.(*FileStore)
-	tokenID := fileStore.HashedPAT2TokenID[pat.HashedToken]
+	tokenID, err := am.Store.GetTokenIDByHashedToken(context.Background(), pat.HashedToken)
+	if err != nil {
+		t.Fatalf("Error when getting token ID by hashed token: %s", err)
+	}
 
 	if tokenID == "" {
 		t.Fatal("GetTokenIDByHashedToken failed after adding PAT")
@@ -71,11 +73,12 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
 	assert.Equal(t, pat.ID, tokenID)
 
-	userID := fileStore.TokenID2UserID[tokenID]
-	if userID == "" {
-		t.Fatal("GetUserByTokenId failed after adding PAT")
+	user, err := am.Store.GetUserByTokenID(context.Background(), tokenID)
+	if err != nil {
+		t.Fatalf("Error when getting user by token ID: %s", err)
 	}
-	assert.Equal(t, mockUserID, userID)
+
+	assert.Equal(t, mockUserID, user.Id)
 }
 
 func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
@@ -192,9 +195,12 @@ func TestUser_DeletePAT(t *testing.T) {
 		t.Fatalf("Error when adding PAT to user: %s", err)
 	}
 
-	assert.Nil(t, store.Accounts[mockAccountID].Users[mockUserID].PATs[mockTokenID1])
-	assert.Empty(t, store.HashedPAT2TokenID[mockToken1])
-	assert.Empty(t, store.TokenID2UserID[mockTokenID1])
+	account, err = store.GetAccount(context.Background(), mockAccountID)
+	if err != nil {
+		t.Fatalf("Error when getting account: %s", err)
+	}
+
+	assert.Nil(t, account.Users[mockUserID].PATs[mockTokenID1])
 }
 
 func TestUser_GetPAT(t *testing.T) {
@@ -353,13 +359,16 @@ func TestUser_CreateServiceUser(t *testing.T) {
 		t.Fatalf("Error when creating service user: %s", err)
 	}
 
-	assert.Equal(t, 2, len(store.Accounts[mockAccountID].Users))
-	assert.NotNil(t, store.Accounts[mockAccountID].Users[user.ID])
-	assert.True(t, store.Accounts[mockAccountID].Users[user.ID].IsServiceUser)
-	assert.Equal(t, mockServiceUserName, store.Accounts[mockAccountID].Users[user.ID].ServiceUserName)
-	assert.Equal(t, UserRole(mockRole), store.Accounts[mockAccountID].Users[user.ID].Role)
-	assert.Equal(t, []string{"group1", "group2"}, store.Accounts[mockAccountID].Users[user.ID].AutoGroups)
-	assert.Equal(t, map[string]*PersonalAccessToken{}, store.Accounts[mockAccountID].Users[user.ID].PATs)
+	account, err = store.GetAccount(context.Background(), mockAccountID)
+	assert.NoError(t, err)
+
+	assert.Equal(t, 2, len(account.Users))
+	assert.NotNil(t, account.Users[user.ID])
+	assert.True(t, account.Users[user.ID].IsServiceUser)
+	assert.Equal(t, mockServiceUserName, account.Users[user.ID].ServiceUserName)
+	assert.Equal(t, UserRole(mockRole), account.Users[user.ID].Role)
+	assert.Equal(t, []string{"group1", "group2"}, account.Users[user.ID].AutoGroups)
+	assert.Equal(t, map[string]*PersonalAccessToken{}, account.Users[user.ID].PATs)
 
 	assert.Zero(t, user.Email)
 	assert.True(t, user.IsServiceUser)
@@ -397,12 +406,15 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) {
 		t.Fatalf("Error when creating user: %s", err)
 	}
 
+	account, err = store.GetAccount(context.Background(), mockAccountID)
+	assert.NoError(t, err)
+
 	assert.True(t, user.IsServiceUser)
-	assert.Equal(t, 2, len(store.Accounts[mockAccountID].Users))
-	assert.True(t, store.Accounts[mockAccountID].Users[user.ID].IsServiceUser)
-	assert.Equal(t, mockServiceUserName, store.Accounts[mockAccountID].Users[user.ID].ServiceUserName)
-	assert.Equal(t, UserRole(mockRole), store.Accounts[mockAccountID].Users[user.ID].Role)
-	assert.Equal(t, []string{"group1", "group2"}, store.Accounts[mockAccountID].Users[user.ID].AutoGroups)
+	assert.Equal(t, 2, len(account.Users))
+	assert.True(t, account.Users[user.ID].IsServiceUser)
+	assert.Equal(t, mockServiceUserName, account.Users[user.ID].ServiceUserName)
+	assert.Equal(t, UserRole(mockRole), account.Users[user.ID].Role)
+	assert.Equal(t, []string{"group1", "group2"}, account.Users[user.ID].AutoGroups)
 
 	assert.Equal(t, mockServiceUserName, user.Name)
 	assert.Equal(t, mockRole, user.Role)
@@ -553,12 +565,15 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) {
 			err = am.DeleteUser(context.Background(), mockAccountID, mockUserID, mockServiceUserID)
 			tt.assertErrFunc(t, err, tt.assertErrMessage)
 
+			account, err2 := store.GetAccount(context.Background(), mockAccountID)
+			assert.NoError(t, err2)
+
 			if err != nil {
-				assert.Equal(t, 2, len(store.Accounts[mockAccountID].Users))
-				assert.NotNil(t, store.Accounts[mockAccountID].Users[mockServiceUserID])
+				assert.Equal(t, 2, len(account.Users))
+				assert.NotNil(t, account.Users[mockServiceUserID])
 			} else {
-				assert.Equal(t, 1, len(store.Accounts[mockAccountID].Users))
-				assert.Nil(t, store.Accounts[mockAccountID].Users[mockServiceUserID])
+				assert.Equal(t, 1, len(account.Users))
+				assert.Nil(t, account.Users[mockServiceUserID])
 			}
 		})
 	}
@@ -801,10 +816,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
 		assert.NoError(t, err)
 	}
 
-	accID, err := am.GetAccountIDByUserOrAccountID(context.Background(), "", account.Id, "")
-	assert.NoError(t, err)
-
-	acc, err := am.Store.GetAccount(context.Background(), accID)
+	acc, err := am.Store.GetAccount(context.Background(), account.Id)
 	assert.NoError(t, err)
 
 	for _, id := range tc.expectedDeleted {
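The test updates above stop reaching into FileStore internals (store.Accounts, HashedPAT2TokenID, TokenID2UserID) and instead go through the Store interface, so the same assertions hold for both the SQLite- and Postgres-backed stores. A hedged sketch of that lookup pattern as a reusable helper; the helper itself is hypothetical and assumes it sits alongside these tests with context and testify's assert already imported, but the store methods and fields are the ones used in the diff:

```go
// verifyPAT is an illustrative helper (not part of the change) showing the
// store-interface lookups the updated tests now rely on.
func verifyPAT(t *testing.T, am *DefaultAccountManager, accountID, userID string, pat *PersonalAccessToken) {
	t.Helper()

	// Resolve the token through the store instead of reading HashedPAT2TokenID directly.
	tokenID, err := am.Store.GetTokenIDByHashedToken(context.Background(), pat.HashedToken)
	if err != nil {
		t.Fatalf("resolving token by hash: %s", err)
	}
	assert.Equal(t, pat.ID, tokenID)

	// Resolve the owning user through the store instead of TokenID2UserID.
	user, err := am.Store.GetUserByTokenID(context.Background(), tokenID)
	if err != nil {
		t.Fatalf("resolving user by token ID: %s", err)
	}
	assert.Equal(t, userID, user.Id)

	// Re-read the account instead of touching store.Accounts directly.
	account, err := am.Store.GetAccount(context.Background(), accountID)
	if err != nil {
		t.Fatalf("reading account: %s", err)
	}
	assert.NotNil(t, account.Users[userID].PATs[pat.ID])
}
```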
53 util/file.go
@@ -1,11 +1,15 @@
 package util
 
 import (
+	"bytes"
 	"context"
 	"encoding/json"
+	"fmt"
 	"io"
 	"os"
 	"path/filepath"
+	"strings"
+	"text/template"
 
 	log "github.com/sirupsen/logrus"
 )
@@ -160,6 +164,55 @@ func ReadJson(file string, res interface{}) (interface{}, error) {
 	return res, nil
 }
 
+// ReadJsonWithEnvSub reads JSON config file and maps to a provided interface with environment variable substitution
+func ReadJsonWithEnvSub(file string, res interface{}) (interface{}, error) {
+	envVars := getEnvMap()
+
+	f, err := os.Open(file)
+	if err != nil {
+		return nil, err
+	}
+	defer f.Close()
+
+	bs, err := io.ReadAll(f)
+	if err != nil {
+		return nil, err
+	}
+
+	t, err := template.New("").Parse(string(bs))
+	if err != nil {
+		return nil, fmt.Errorf("error parsing template: %v", err)
+	}
+
+	var output bytes.Buffer
+	// Execute the template, substituting environment variables
+	err = t.Execute(&output, envVars)
+	if err != nil {
+		return nil, fmt.Errorf("error executing template: %v", err)
+	}
+
+	err = json.Unmarshal(output.Bytes(), &res)
+	if err != nil {
+		return nil, fmt.Errorf("failed parsing Json file after template was executed, err: %v", err)
+	}
+
+	return res, nil
+}
+
+// getEnvMap Convert the output of os.Environ() to a map
+func getEnvMap() map[string]string {
+	envMap := make(map[string]string)
+
+	for _, env := range os.Environ() {
+		parts := strings.SplitN(env, "=", 2)
+		if len(parts) == 2 {
+			envMap[parts[0]] = parts[1]
+		}
+	}
+
+	return envMap
+}
+
 // CopyFileContents copies contents of the given src file to the dst file
 func CopyFileContents(src, dst string) (err error) {
 	in, err := os.Open(src)
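ReadJsonWithEnvSub pipes the raw file through text/template with the process environment as template data before unmarshalling, so `{{ .VAR }}` placeholders inside a JSON config are replaced with environment values. A small usage sketch; the config struct, file name, and DATABASE_DSN variable are made up for illustration:

```go
package main

import (
	"fmt"
	"log"
	"os"

	"github.com/netbirdio/netbird/util"
)

type dbConfig struct {
	DSN string `json:"DSN"`
}

func main() {
	// Hypothetical config file containing a template placeholder.
	if err := os.WriteFile("db.json", []byte(`{ "DSN": "{{ .DATABASE_DSN }}" }`), 0o600); err != nil {
		log.Fatal(err)
	}
	defer os.Remove("db.json")

	os.Setenv("DATABASE_DSN", "postgres://user:pass@localhost:5432/netbird")

	var cfg dbConfig
	// Placeholders are filled from the environment before JSON parsing;
	// variables that are not set render as "<no value>".
	if _, err := util.ReadJsonWithEnvSub("db.json", &cfg); err != nil {
		log.Fatal(err)
	}
	fmt.Println(cfg.DSN)
}
```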
126 util/file_suite_test.go Normal file
@@ -0,0 +1,126 @@
package util_test

import (
	"crypto/md5"
	"encoding/hex"
	"io"
	"os"

	. "github.com/onsi/ginkgo"
	. "github.com/onsi/gomega"

	"github.com/netbirdio/netbird/util"
)

var _ = Describe("Client", func() {

	var (
		tmpDir string
	)

	type TestConfig struct {
		SomeMap   map[string]string
		SomeArray []string
		SomeField int
	}

	BeforeEach(func() {
		var err error
		tmpDir, err = os.MkdirTemp("", "wiretrustee_util_test_tmp_*")
		Expect(err).NotTo(HaveOccurred())
	})

	AfterEach(func() {
		err := os.RemoveAll(tmpDir)
		Expect(err).NotTo(HaveOccurred())
	})

	Describe("Config", func() {
		Context("in JSON format", func() {
			It("should be written and read successfully", func() {

				m := make(map[string]string)
				m["key1"] = "value1"
				m["key2"] = "value2"

				arr := []string{"value1", "value2"}

				written := &TestConfig{
					SomeMap:   m,
					SomeArray: arr,
					SomeField: 99,
				}

				err := util.WriteJson(tmpDir+"/testconfig.json", written)
				Expect(err).NotTo(HaveOccurred())

				read, err := util.ReadJson(tmpDir+"/testconfig.json", &TestConfig{})
				Expect(err).NotTo(HaveOccurred())
				Expect(read).NotTo(BeNil())
				Expect(read.(*TestConfig).SomeMap["key1"]).To(BeEquivalentTo(written.SomeMap["key1"]))
				Expect(read.(*TestConfig).SomeMap["key2"]).To(BeEquivalentTo(written.SomeMap["key2"]))
				Expect(read.(*TestConfig).SomeArray).To(ContainElements(arr))
				Expect(read.(*TestConfig).SomeField).To(BeEquivalentTo(written.SomeField))

			})
		})
	})

	Describe("Copying file contents", func() {
		Context("from one file to another", func() {
			It("should be successful", func() {

				src := tmpDir + "/copytest_src"
				dst := tmpDir + "/copytest_dst"

				err := util.WriteJson(src, []string{"1", "2", "3"})
				Expect(err).NotTo(HaveOccurred())

				err = util.CopyFileContents(src, dst)
				Expect(err).NotTo(HaveOccurred())

				hashSrc := md5.New()
				hashDst := md5.New()

				srcFile, err := os.Open(src)
				Expect(err).NotTo(HaveOccurred())

				dstFile, err := os.Open(dst)
				Expect(err).NotTo(HaveOccurred())

				_, err = io.Copy(hashSrc, srcFile)
				Expect(err).NotTo(HaveOccurred())

				_, err = io.Copy(hashDst, dstFile)
				Expect(err).NotTo(HaveOccurred())

				err = srcFile.Close()
				Expect(err).NotTo(HaveOccurred())

				err = dstFile.Close()
				Expect(err).NotTo(HaveOccurred())

				Expect(hex.EncodeToString(hashSrc.Sum(nil)[:16])).To(BeEquivalentTo(hex.EncodeToString(hashDst.Sum(nil)[:16])))
			})
		})
	})

	Describe("Handle config file without full path", func() {
		Context("config file handling", func() {
			It("should be successful", func() {
				written := &TestConfig{
					SomeField: 123,
				}
				cfgFile := "test_cfg.json"
				defer os.Remove(cfgFile)

				err := util.WriteJson(cfgFile, written)
				Expect(err).NotTo(HaveOccurred())

				read, err := util.ReadJson(cfgFile, &TestConfig{})
				Expect(err).NotTo(HaveOccurred())
				Expect(read).NotTo(BeNil())
			})
		})
	})
})
@@ -1,126 +1,198 @@
(The Ginkgo-based suite that previously lived in this file was moved verbatim to util/file_suite_test.go, shown above; the file now holds the table-driven test below.)
package util

import (
	"os"
	"reflect"
	"strings"
	"testing"
)

func TestReadJsonWithEnvSub(t *testing.T) {
	type Config struct {
		CertFile     string `json:"CertFile"`
		Credentials  string `json:"Credentials"`
		NestedOption struct {
			URL string `json:"URL"`
		} `json:"NestedOption"`
	}

	type testCase struct {
		name           string
		envVars        map[string]string
		jsonTemplate   string
		expectedResult Config
		expectError    bool
		errorContains  string
	}

	tests := []testCase{
		{
			name: "All environment variables set",
			envVars: map[string]string{
				"CERT_FILE":   "/etc/certs/env_cert.crt",
				"CREDENTIALS": "env_credentials",
				"URL":         "https://env.testing.com",
			},
			jsonTemplate: `{
				"CertFile": "{{ .CERT_FILE }}",
				"Credentials": "{{ .CREDENTIALS }}",
				"NestedOption": {
					"URL": "{{ .URL }}"
				}
			}`,
			expectedResult: Config{
				CertFile:    "/etc/certs/env_cert.crt",
				Credentials: "env_credentials",
				NestedOption: struct {
					URL string `json:"URL"`
				}{
					URL: "https://env.testing.com",
				},
			},
			expectError: false,
		},
		{
			name: "Missing environment variable",
			envVars: map[string]string{
				"CERT_FILE":   "/etc/certs/env_cert.crt",
				"CREDENTIALS": "env_credentials",
				// "URL" is intentionally missing
			},
			jsonTemplate: `{
				"CertFile": "{{ .CERT_FILE }}",
				"Credentials": "{{ .CREDENTIALS }}",
				"NestedOption": {
					"URL": "{{ .URL }}"
				}
			}`,
			expectedResult: Config{
				CertFile:    "/etc/certs/env_cert.crt",
				Credentials: "env_credentials",
				NestedOption: struct {
					URL string `json:"URL"`
				}{
					URL: "<no value>",
				},
			},
			expectError: false,
		},
		{
			name: "Invalid JSON template",
			envVars: map[string]string{
				"CERT_FILE":   "/etc/certs/env_cert.crt",
				"CREDENTIALS": "env_credentials",
				"URL":         "https://env.testing.com",
			},
			jsonTemplate: `{
				"CertFile": "{{ .CERT_FILE }}",
				"Credentials": "{{ .CREDENTIALS }",
				"NestedOption": {
					"URL": "{{ .URL }}"
				}
			}`, // Note the missing closing brace in "{{ .CREDENTIALS }"
			expectedResult: Config{},
			expectError:    true,
			errorContains:  "unexpected \"}\" in operand",
		},
		{
			name: "No substitutions",
			envVars: map[string]string{
				"CERT_FILE":   "/etc/certs/env_cert.crt",
				"CREDENTIALS": "env_credentials",
				"URL":         "https://env.testing.com",
			},
			jsonTemplate: `{
				"CertFile": "/etc/certs/cert.crt",
				"Credentials": "admnlknflkdasdf",
				"NestedOption" : {
					"URL": "https://testing.com"
				}
			}`,
			expectedResult: Config{
				CertFile:    "/etc/certs/cert.crt",
				Credentials: "admnlknflkdasdf",
				NestedOption: struct {
					URL string `json:"URL"`
				}{
					URL: "https://testing.com",
				},
			},
			expectError: false,
		},
		{
			name: "Should fail when Invalid characters in variables",
			envVars: map[string]string{
				"CERT_FILE":   `"/etc/certs/"cert".crt"`,
				"CREDENTIALS": `env_credentia{ls}`,
				"URL":         `https://env.testing.com?param={{value}}`,
			},
			jsonTemplate: `{
				"CertFile": "{{ .CERT_FILE }}",
				"Credentials": "{{ .CREDENTIALS }}",
				"NestedOption": {
					"URL": "{{ .URL }}"
				}
			}`,
			expectedResult: Config{
				CertFile:    `"/etc/certs/"cert".crt"`,
				Credentials: `env_credentia{ls}`,
				NestedOption: struct {
					URL string `json:"URL"`
				}{
					URL: `https://env.testing.com?param={{value}}`,
				},
			},
			expectError: true,
		},
	}

	for _, tc := range tests {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			for key, value := range tc.envVars {
				t.Setenv(key, value)
			}

			tempFile, err := os.CreateTemp("", "config*.json")
			if err != nil {
				t.Fatalf("Failed to create temp file: %v", err)
			}
			defer func() {
				err = os.Remove(tempFile.Name())
				if err != nil {
					t.Logf("Failed to remove temp file: %v", err)
				}
			}()

			_, err = tempFile.WriteString(tc.jsonTemplate)
			if err != nil {
				t.Fatalf("Failed to write to temp file: %v", err)
			}
			err = tempFile.Close()
			if err != nil {
				t.Fatalf("Failed to close temp file: %v", err)
			}

			var result Config

			_, err = ReadJsonWithEnvSub(tempFile.Name(), &result)

			if tc.expectError {
				if err == nil {
					t.Fatalf("Expected error but got none")
				}
				if !strings.Contains(err.Error(), tc.errorContains) {
					t.Errorf("Expected error containing '%s', but got '%v'", tc.errorContains, err)
				}
			} else {
				if err != nil {
					t.Fatalf("ReadJsonWithEnvSub failed: %v", err)
				}
				if !reflect.DeepEqual(result, tc.expectedResult) {
					t.Errorf("Result does not match expected.\nGot: %+v\nExpected: %+v", result, tc.expectedResult)
				}
			}
		})
	}
}
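The "Missing environment variable" case expects the literal string "<no value>" because that is what Go's text/template prints for an absent map key under its default missingkey option. A tiny standalone illustration of that behaviour (the snippet is illustrative and not part of the change):

```go
package main

import (
	"os"
	"text/template"
)

func main() {
	// With a map as template data, an absent key renders as "<no value>" by default,
	// which is exactly what the "Missing environment variable" test case asserts.
	data := map[string]string{"CERT_FILE": "/etc/certs/env_cert.crt"}
	tmpl := template.Must(template.New("cfg").Parse(`{"CertFile": "{{ .CERT_FILE }}", "URL": "{{ .URL }}"}`))
	_ = tmpl.Execute(os.Stdout, data) // {"CertFile": "/etc/certs/env_cert.crt", "URL": "<no value>"}
}
```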