[management] Add API of new network concept (#3012)

This commit is contained in:
Pascal Fischer
2024-12-11 12:58:45 +01:00
committed by GitHub
parent 9f859a240e
commit 60ee31df3e
92 changed files with 5320 additions and 3562 deletions

View File

@ -10,6 +10,7 @@ import (
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/util"
@ -71,7 +72,7 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
t.Fatal(err)
}
s := grpc.NewServer()
store, cleanUp, err := mgmt.NewTestStoreFromSQL(context.Background(), testFile, t.TempDir())
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, t.TempDir())
if err != nil {
t.Fatal(err)
}

View File

@ -39,6 +39,7 @@ import (
mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route"
@ -1196,7 +1197,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, cleanUp, err := server.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir)
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir)
if err != nil {
return nil, "", err
}

View File

@ -20,6 +20,7 @@ import (
mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server"
@ -110,7 +111,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
return nil, "", err
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, cleanUp, err := server.NewTestStoreFromSQL(context.Background(), "", config.Datadir)
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", config.Datadir)
if err != nil {
return nil, "", err
}

View File

@ -11,6 +11,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/client/system"
@ -57,7 +58,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
t.Fatal(err)
}
s := grpc.NewServer()
store, cleanUp, err := mgmt.NewTestStoreFromSQL(context.Background(), "../server/testdata/store.sql", t.TempDir())
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../server/testdata/store.sql", t.TempDir())
if err != nil {
t.Fatal(err)
}

View File

@ -46,6 +46,7 @@ import (
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/metrics"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
@ -150,7 +151,7 @@ var (
if err != nil {
return err
}
store, err := server.NewStore(ctx, config.StoreConfig.Engine, config.Datadir, appMetrics)
store, err := store.NewStore(ctx, config.StoreConfig.Engine, config.Datadir, appMetrics)
if err != nil {
return fmt.Errorf("failed creating Store: %s: %v", config.Datadir, err)
}
@ -400,7 +401,7 @@ func notifyStop(ctx context.Context, msg string) {
}
}
func getInstallationID(ctx context.Context, store server.Store) (string, error) {
func getInstallationID(ctx context.Context, store store.Store) (string, error) {
installationID := store.GetInstallationID()
if installationID != "" {
return installationID, nil

View File

@ -9,7 +9,7 @@ import (
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/util"
)
@ -32,7 +32,7 @@ var upCmd = &cobra.Command{
//nolint
ctx := context.WithValue(cmd.Context(), formatter.ExecutionContextKey, formatter.SystemSource)
if err := server.MigrateFileStoreToSqlite(ctx, mgmtDataDir); err != nil {
if err := store.MigrateFileStoreToSqlite(ctx, mgmtDataDir); err != nil {
return err
}
log.WithContext(ctx).Info("Migration finished successfully")

File diff suppressed because it is too large Load Diff

View File

@ -7,6 +7,9 @@ import (
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
// AccountRequest holds the result channel to return the requested account.
@ -17,19 +20,19 @@ type AccountRequest struct {
// AccountResult holds the account data or an error.
type AccountResult struct {
Account *Account
Account *types.Account
Err error
}
type AccountRequestBuffer struct {
store Store
store store.Store
getAccountRequests map[string][]*AccountRequest
mu sync.Mutex
getAccountRequestCh chan *AccountRequest
bufferInterval time.Duration
}
func NewAccountRequestBuffer(ctx context.Context, store Store) *AccountRequestBuffer {
func NewAccountRequestBuffer(ctx context.Context, store store.Store) *AccountRequestBuffer {
bufferIntervalStr := os.Getenv("NB_GET_ACCOUNT_BUFFER_INTERVAL")
bufferInterval, err := time.ParseDuration(bufferIntervalStr)
if err != nil {
@ -52,7 +55,7 @@ func NewAccountRequestBuffer(ctx context.Context, store Store) *AccountRequestBu
return &ac
}
func (ac *AccountRequestBuffer) GetAccountWithBackpressure(ctx context.Context, accountID string) (*Account, error) {
func (ac *AccountRequestBuffer) GetAccountWithBackpressure(ctx context.Context, accountID string) (*types.Account, error) {
req := &AccountRequest{
AccountID: accountID,
ResultChan: make(chan *AccountResult, 1),

View File

@ -16,7 +16,11 @@ import (
"time"
"github.com/golang-jwt/jwt"
"github.com/netbirdio/netbird/management/server/networks"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -29,7 +33,9 @@ import (
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
@ -74,7 +80,7 @@ func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)
func (MocIntegratedValidator) Stop(_ context.Context) {
}
func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *Account, userID string) {
func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *types.Account, userID string) {
t.Helper()
peer := &nbpeer.Peer{
Key: "BhRPtynAAYRDy08+q4HTMsos8fs4plTP4NOSh7C1ry8=",
@ -102,7 +108,7 @@ func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *Ac
}
}
func verifyNewAccountHasDefaultFields(t *testing.T, account *Account, createdBy string, domain string, expectedUsers []string) {
func verifyNewAccountHasDefaultFields(t *testing.T, account *types.Account, createdBy string, domain string, expectedUsers []string) {
t.Helper()
if len(account.Peers) != 0 {
t.Errorf("expected account to have len(Peers) = %v, got %v", 0, len(account.Peers))
@ -157,7 +163,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
// peerID3 := "peer-3"
tt := []struct {
name string
accountSettings Settings
accountSettings types.Settings
peerID string
expectedPeers []string
expectedOfflinePeers []string
@ -165,7 +171,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
}{
{
name: "Should return ALL peers when global peer login expiration disabled",
accountSettings: Settings{PeerLoginExpirationEnabled: false, PeerLoginExpiration: time.Hour},
accountSettings: types.Settings{PeerLoginExpirationEnabled: false, PeerLoginExpiration: time.Hour},
peerID: peerID1,
expectedPeers: []string{peerID2},
expectedOfflinePeers: []string{},
@ -203,7 +209,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
},
{
name: "Should return no peers when global peer login expiration enabled and peers expired",
accountSettings: Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: time.Hour},
accountSettings: types.Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: time.Hour},
peerID: peerID1,
expectedPeers: []string{},
expectedOfflinePeers: []string{peerID2},
@ -397,12 +403,12 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
netIP := net.IP{100, 64, 0, 0}
netMask := net.IPMask{255, 255, 0, 0}
network := &Network{
network := &types.Network{
Identifier: "network",
Net: net.IPNet{IP: netIP, Mask: netMask},
Dns: "netbird.selfhosted",
Serial: 0,
mu: sync.Mutex{},
Mu: sync.Mutex{},
}
for _, testCase := range tt {
@ -486,12 +492,12 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
}
initUnknown := defaultInitAccount
initUnknown.DomainCategory = UnknownCategory
initUnknown.DomainCategory = types.UnknownCategory
initUnknown.Domain = unknownDomain
privateInitAccount := defaultInitAccount
privateInitAccount.Domain = privateDomain
privateInitAccount.DomainCategory = PrivateCategory
privateInitAccount.DomainCategory = types.PrivateCategory
testCases := []struct {
name string
@ -501,7 +507,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
inputUpdateClaimAccount bool
testingFunc require.ComparisonAssertionFunc
expectedMSG string
expectedUserRole UserRole
expectedUserRole types.UserRole
expectedDomainCategory string
expectedDomain string
expectedPrimaryDomainStatus bool
@ -513,12 +519,12 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
inputClaims: jwtclaims.AuthorizationClaims{
Domain: publicDomain,
UserId: "pub-domain-user",
DomainCategory: PublicCategory,
DomainCategory: types.PublicCategory,
},
inputInitUserParams: defaultInitAccount,
testingFunc: require.NotEqual,
expectedMSG: "account IDs shouldn't match",
expectedUserRole: UserRoleOwner,
expectedUserRole: types.UserRoleOwner,
expectedDomainCategory: "",
expectedDomain: publicDomain,
expectedPrimaryDomainStatus: false,
@ -530,12 +536,12 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
inputClaims: jwtclaims.AuthorizationClaims{
Domain: unknownDomain,
UserId: "unknown-domain-user",
DomainCategory: UnknownCategory,
DomainCategory: types.UnknownCategory,
},
inputInitUserParams: initUnknown,
testingFunc: require.NotEqual,
expectedMSG: "account IDs shouldn't match",
expectedUserRole: UserRoleOwner,
expectedUserRole: types.UserRoleOwner,
expectedDomain: unknownDomain,
expectedDomainCategory: "",
expectedPrimaryDomainStatus: false,
@ -547,14 +553,14 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
inputClaims: jwtclaims.AuthorizationClaims{
Domain: privateDomain,
UserId: "pvt-domain-user",
DomainCategory: PrivateCategory,
DomainCategory: types.PrivateCategory,
},
inputInitUserParams: defaultInitAccount,
testingFunc: require.NotEqual,
expectedMSG: "account IDs shouldn't match",
expectedUserRole: UserRoleOwner,
expectedUserRole: types.UserRoleOwner,
expectedDomain: privateDomain,
expectedDomainCategory: PrivateCategory,
expectedDomainCategory: types.PrivateCategory,
expectedPrimaryDomainStatus: true,
expectedCreatedBy: "pvt-domain-user",
expectedUsers: []string{"pvt-domain-user"},
@ -564,15 +570,15 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
inputClaims: jwtclaims.AuthorizationClaims{
Domain: privateDomain,
UserId: "new-pvt-domain-user",
DomainCategory: PrivateCategory,
DomainCategory: types.PrivateCategory,
},
inputUpdateAttrs: true,
inputInitUserParams: privateInitAccount,
testingFunc: require.Equal,
expectedMSG: "account IDs should match",
expectedUserRole: UserRoleUser,
expectedUserRole: types.UserRoleUser,
expectedDomain: privateDomain,
expectedDomainCategory: PrivateCategory,
expectedDomainCategory: types.PrivateCategory,
expectedPrimaryDomainStatus: true,
expectedCreatedBy: defaultInitAccount.UserId,
expectedUsers: []string{defaultInitAccount.UserId, "new-pvt-domain-user"},
@ -582,14 +588,14 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
inputClaims: jwtclaims.AuthorizationClaims{
Domain: defaultInitAccount.Domain,
UserId: defaultInitAccount.UserId,
DomainCategory: PrivateCategory,
DomainCategory: types.PrivateCategory,
},
inputInitUserParams: defaultInitAccount,
testingFunc: require.Equal,
expectedMSG: "account IDs should match",
expectedUserRole: UserRoleOwner,
expectedUserRole: types.UserRoleOwner,
expectedDomain: defaultInitAccount.Domain,
expectedDomainCategory: PrivateCategory,
expectedDomainCategory: types.PrivateCategory,
expectedPrimaryDomainStatus: true,
expectedCreatedBy: defaultInitAccount.UserId,
expectedUsers: []string{defaultInitAccount.UserId},
@ -599,15 +605,15 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
inputClaims: jwtclaims.AuthorizationClaims{
Domain: defaultInitAccount.Domain,
UserId: defaultInitAccount.UserId,
DomainCategory: PrivateCategory,
DomainCategory: types.PrivateCategory,
},
inputUpdateClaimAccount: true,
inputInitUserParams: defaultInitAccount,
testingFunc: require.Equal,
expectedMSG: "account IDs should match",
expectedUserRole: UserRoleOwner,
expectedUserRole: types.UserRoleOwner,
expectedDomain: defaultInitAccount.Domain,
expectedDomainCategory: PrivateCategory,
expectedDomainCategory: types.PrivateCategory,
expectedPrimaryDomainStatus: true,
expectedCreatedBy: defaultInitAccount.UserId,
expectedUsers: []string{defaultInitAccount.UserId},
@ -617,12 +623,12 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
inputClaims: jwtclaims.AuthorizationClaims{
Domain: "",
UserId: "pvt-domain-user",
DomainCategory: PrivateCategory,
DomainCategory: types.PrivateCategory,
},
inputInitUserParams: defaultInitAccount,
testingFunc: require.NotEqual,
expectedMSG: "account IDs shouldn't match",
expectedUserRole: UserRoleOwner,
expectedUserRole: types.UserRoleOwner,
expectedDomain: "",
expectedDomainCategory: "",
expectedPrimaryDomainStatus: false,
@ -752,22 +758,26 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
}
func TestAccountManager_GetAccountFromPAT(t *testing.T) {
store := newStore(t)
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
if err != nil {
t.Fatalf("Error when creating store: %s", err)
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
hashedToken := sha256.Sum256([]byte(token))
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
account.Users["someUser"] = &User{
account.Users["someUser"] = &types.User{
Id: "someUser",
PATs: map[string]*PersonalAccessToken{
PATs: map[string]*types.PersonalAccessToken{
"tokenId": {
ID: "tokenId",
HashedToken: encodedHashedToken,
},
},
}
err := store.SaveAccount(context.Background(), account)
err = store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
@ -787,15 +797,20 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) {
}
func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
store := newStore(t)
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
if err != nil {
t.Fatalf("Error when creating store: %s", err)
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
hashedToken := sha256.Sum256([]byte(token))
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
account.Users["someUser"] = &User{
account.Users["someUser"] = &types.User{
Id: "someUser",
PATs: map[string]*PersonalAccessToken{
PATs: map[string]*types.PersonalAccessToken{
"tokenId": {
ID: "tokenId",
HashedToken: encodedHashedToken,
@ -803,7 +818,7 @@ func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
},
},
}
err := store.SaveAccount(context.Background(), account)
err = store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
@ -905,7 +920,7 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) {
return
}
exists, err := manager.Store.AccountExists(context.Background(), LockingStrengthShare, accountID)
exists, err := manager.Store.AccountExists(context.Background(), store.LockingStrengthShare, accountID)
assert.NoError(t, err)
assert.True(t, exists, "expected to get existing account after creation using userid")
@ -915,7 +930,7 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) {
}
}
func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*Account, error) {
func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*types.Account, error) {
account := newAccountWithId(context.Background(), accountID, userID, domain)
err := am.Store.SaveAccount(context.Background(), account)
if err != nil {
@ -991,13 +1006,13 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
claims := jwtclaims.AuthorizationClaims{
Domain: "example.com",
UserId: "pvt-domain-user",
DomainCategory: PrivateCategory,
DomainCategory: types.PrivateCategory,
}
publicClaims := jwtclaims.AuthorizationClaims{
Domain: "test.com",
UserId: "public-domain-user",
DomainCategory: PublicCategory,
DomainCategory: types.PublicCategory,
}
am, err := createManager(b)
@ -1075,13 +1090,13 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
}
func genUsers(p string, n int) map[string]*User {
users := map[string]*User{}
func genUsers(p string, n int) map[string]*types.User {
users := map[string]*types.User{}
now := time.Now()
for i := 0; i < n; i++ {
users[fmt.Sprintf("%s-%d", p, i)] = &User{
users[fmt.Sprintf("%s-%d", p, i)] = &types.User{
Id: fmt.Sprintf("%s-%d", p, i),
Role: UserRoleAdmin,
Role: types.UserRoleAdmin,
LastLogin: now,
CreatedAt: now,
Issued: "api",
@ -1106,7 +1121,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
serial := account.Network.CurrentSerial() // should be 0
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false)
if err != nil {
t.Fatal("error creating setup key")
return
@ -1243,15 +1258,15 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
return
}
_, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
_, err := manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
Enabled: true,
Rules: []*PolicyRule{
Rules: []*types.PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
Action: types.PolicyTrafficActionAccept,
},
},
})
@ -1335,15 +1350,15 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
}
}()
_, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
_, err := manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
Enabled: true,
Rules: []*PolicyRule{
Rules: []*types.PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
Action: types.PolicyTrafficActionAccept,
},
},
})
@ -1368,15 +1383,15 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
return
}
_, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
_, err := manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
Enabled: true,
Rules: []*PolicyRule{
Rules: []*types.PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
Action: types.PolicyTrafficActionAccept,
},
},
})
@ -1427,15 +1442,15 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
return
}
policy, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
policy, err := manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
Enabled: true,
Rules: []*PolicyRule{
Rules: []*types.PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
Action: types.PolicyTrafficActionAccept,
},
},
})
@ -1483,7 +1498,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
t.Fatal(err)
}
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false)
if err != nil {
t.Fatal("error creating setup key")
return
@ -1558,7 +1573,7 @@ func TestGetUsersFromAccount(t *testing.T) {
t.Fatal(err)
}
users := map[string]*User{"1": {Id: "1", Role: UserRoleOwner}, "2": {Id: "2", Role: "user"}, "3": {Id: "3", Role: "user"}}
users := map[string]*types.User{"1": {Id: "1", Role: types.UserRoleOwner}, "2": {Id: "2", Role: "user"}, "3": {Id: "3", Role: "user"}}
accountId := "test_account_id"
account, err := createAccount(manager, accountId, users["1"].Id, "")
@ -1590,7 +1605,7 @@ func TestFileStore_GetRoutesByPrefix(t *testing.T) {
if err != nil {
t.Fatal(err)
}
account := &Account{
account := &types.Account{
Routes: map[route.ID]*route.Route{
"route-1": {
ID: "route-1",
@ -1637,7 +1652,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
if err != nil {
t.Fatal(err)
}
account := &Account{
account := &types.Account{
Peers: map[string]*nbpeer.Peer{
"peer-1": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}},
},
@ -1682,7 +1697,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
},
}
routes := account.getRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}})
routes := account.GetRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}})
assert.Len(t, routes, 2)
routeIDs := make(map[route.ID]struct{}, 2)
@ -1692,26 +1707,26 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
assert.Contains(t, routeIDs, route.ID("route-2"))
assert.Contains(t, routeIDs, route.ID("route-3"))
emptyRoutes := account.getRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}})
emptyRoutes := account.GetRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}})
assert.Len(t, emptyRoutes, 0)
}
func TestAccount_Copy(t *testing.T) {
account := &Account{
account := &types.Account{
Id: "account1",
CreatedBy: "tester",
CreatedAt: time.Now().UTC(),
Domain: "test.com",
DomainCategory: "public",
IsDomainPrimaryAccount: true,
SetupKeys: map[string]*SetupKey{
SetupKeys: map[string]*types.SetupKey{
"setup1": {
Id: "setup1",
AutoGroups: []string{"group1"},
},
},
Network: &Network{
Network: &types.Network{
Identifier: "net1",
},
Peers: map[string]*nbpeer.Peer{
@ -1724,12 +1739,12 @@ func TestAccount_Copy(t *testing.T) {
},
},
},
Users: map[string]*User{
Users: map[string]*types.User{
"user1": {
Id: "user1",
Role: UserRoleAdmin,
Role: types.UserRoleAdmin,
AutoGroups: []string{"group1"},
PATs: map[string]*PersonalAccessToken{
PATs: map[string]*types.PersonalAccessToken{
"pat1": {
ID: "pat1",
Name: "First PAT",
@ -1748,11 +1763,11 @@ func TestAccount_Copy(t *testing.T) {
Peers: []string{"peer1"},
},
},
Policies: []*Policy{
Policies: []*types.Policy{
{
ID: "policy1",
Enabled: true,
Rules: make([]*PolicyRule, 0),
Rules: make([]*types.PolicyRule, 0),
SourcePostureChecks: make([]string, 0),
},
},
@ -1772,19 +1787,19 @@ func TestAccount_Copy(t *testing.T) {
NameServers: []nbdns.NameServer{},
},
},
DNSSettings: DNSSettings{DisabledManagementGroups: []string{}},
DNSSettings: types.DNSSettings{DisabledManagementGroups: []string{}},
PostureChecks: []*posture.Checks{
{
ID: "posture Checks1",
},
},
Settings: &Settings{},
Networks: []*networks.Network{
Settings: &types.Settings{},
Networks: []*networkTypes.Network{
{
ID: "network1",
},
},
NetworkRouters: []*networks.NetworkRouter{
NetworkRouters: []*routerTypes.NetworkRouter{
{
ID: "router1",
NetworkID: "network1",
@ -1793,7 +1808,7 @@ func TestAccount_Copy(t *testing.T) {
Metric: 0,
},
},
NetworkResources: []*networks.NetworkResource{
NetworkResources: []*resourceTypes.NetworkResource{
{
ID: "resource1",
NetworkID: "network1",
@ -1854,7 +1869,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account")
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID)
require.NoError(t, err, "unable to get account settings")
assert.NotNil(t, settings)
@ -1887,7 +1902,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected")
account, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
account, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
})
@ -1935,7 +1950,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
LoginExpirationEnabled: true,
})
require.NoError(t, err, "unable to add peer")
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
})
@ -2004,7 +2019,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
},
}
// enabling PeerLoginExpirationEnabled should trigger the expiration job
account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &types.Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
})
@ -2017,7 +2032,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
wg.Add(1)
// disabling PeerLoginExpirationEnabled should trigger cancel
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &types.Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: false,
})
@ -2035,7 +2050,7 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account")
updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: false,
})
@ -2043,19 +2058,19 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)
settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID)
require.NoError(t, err, "unable to get account settings")
assert.False(t, settings.PeerLoginExpirationEnabled)
assert.Equal(t, settings.PeerLoginExpiration, time.Hour)
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
PeerLoginExpiration: time.Second,
PeerLoginExpirationEnabled: false,
})
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour")
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
PeerLoginExpiration: time.Hour * 24 * 181,
PeerLoginExpirationEnabled: false,
})
@ -2128,9 +2143,9 @@ func TestAccount_GetExpiredPeers(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
account := &Account{
account := &types.Account{
Peers: testCase.peers,
Settings: &Settings{
Settings: &types.Settings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: time.Hour,
},
@ -2212,9 +2227,9 @@ func TestAccount_GetInactivePeers(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
account := &Account{
account := &types.Account{
Peers: testCase.peers,
Settings: &Settings{
Settings: &types.Settings{
PeerInactivityExpirationEnabled: true,
PeerInactivityExpiration: time.Second,
},
@ -2279,7 +2294,7 @@ func TestAccount_GetPeersWithExpiration(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
account := &Account{
account := &types.Account{
Peers: testCase.peers,
}
@ -2348,7 +2363,7 @@ func TestAccount_GetPeersWithInactivity(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
account := &Account{
account := &types.Account{
Peers: testCase.peers,
}
@ -2512,9 +2527,9 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
account := &Account{
account := &types.Account{
Peers: testCase.peers,
Settings: &Settings{PeerLoginExpiration: testCase.expiration, PeerLoginExpirationEnabled: testCase.expirationEnabled},
Settings: &types.Settings{PeerLoginExpiration: testCase.expiration, PeerLoginExpirationEnabled: testCase.expirationEnabled},
}
expiration, ok := account.GetNextPeerExpiration()
@ -2672,9 +2687,9 @@ func TestAccount_GetNextInactivePeerExpiration(t *testing.T) {
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
account := &Account{
account := &types.Account{
Peers: testCase.peers,
Settings: &Settings{PeerInactivityExpiration: testCase.expiration, PeerInactivityExpirationEnabled: testCase.expirationEnabled},
Settings: &types.Settings{PeerInactivityExpiration: testCase.expiration, PeerInactivityExpirationEnabled: testCase.expirationEnabled},
}
expiration, ok := account.GetNextInactivePeerExpiration()
@ -2693,7 +2708,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
require.NoError(t, err, "unable to create account manager")
// create a new account
account := &Account{
account := &types.Account{
Id: "accountID",
Peers: map[string]*nbpeer.Peer{
"peer1": {ID: "peer1", Key: "key1", UserID: "user1"},
@ -2705,8 +2720,8 @@ func TestAccount_SetJWTGroups(t *testing.T) {
Groups: map[string]*group.Group{
"group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{}},
},
Settings: &Settings{GroupsPropagationEnabled: true, JWTGroupsEnabled: true, JWTGroupsClaimName: "groups"},
Users: map[string]*User{
Settings: &types.Settings{GroupsPropagationEnabled: true, JWTGroupsEnabled: true, JWTGroupsClaimName: "groups"},
Users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "accountID"},
"user2": {Id: "user2", AccountID: "accountID"},
},
@ -2722,7 +2737,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err := manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user")
assert.Empty(t, user.AutoGroups, "auto groups must be empty")
})
@ -2735,18 +2750,18 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err := manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0)
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthShare, "accountID", "group1")
assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
})
t.Run("jwt match existing api group in user auto groups", func(t *testing.T) {
account.Users["user1"].AutoGroups = []string{"group1"}
assert.NoError(t, manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, account.Users["user1"]))
assert.NoError(t, manager.Store.SaveUser(context.Background(), store.LockingStrengthUpdate, account.Users["user1"]))
claims := jwtclaims.AuthorizationClaims{
UserId: "user1",
@ -2755,11 +2770,11 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1)
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthShare, "accountID", "group1")
assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
})
@ -2772,7 +2787,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
})
@ -2785,7 +2800,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
})
@ -2798,11 +2813,11 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
groups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, "accountID")
groups, err := manager.Store.GetAccountGroups(context.Background(), store.LockingStrengthShare, "accountID")
assert.NoError(t, err)
assert.Len(t, groups, 3, "new group3 should be added")
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user2")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1, "new group should be added")
})
@ -2815,7 +2830,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain")
assert.Contains(t, user.AutoGroups, "group1", " group1 should still be present")
@ -2823,7 +2838,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
}
func TestAccount_UserGroupsAddToPeers(t *testing.T) {
account := &Account{
account := &types.Account{
Peers: map[string]*nbpeer.Peer{
"peer1": {ID: "peer1", Key: "key1", UserID: "user1"},
"peer2": {ID: "peer2", Key: "key2", UserID: "user1"},
@ -2836,7 +2851,7 @@ func TestAccount_UserGroupsAddToPeers(t *testing.T) {
"group2": {ID: "group2", Name: "group2", Issued: group.GroupIssuedAPI, Peers: []string{}},
"group3": {ID: "group3", Name: "group3", Issued: group.GroupIssuedAPI, Peers: []string{}},
},
Users: map[string]*User{"user1": {Id: "user1"}, "user2": {Id: "user2"}},
Users: map[string]*types.User{"user1": {Id: "user1"}, "user2": {Id: "user2"}},
}
t.Run("add groups", func(t *testing.T) {
@ -2859,7 +2874,7 @@ func TestAccount_UserGroupsAddToPeers(t *testing.T) {
}
func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) {
account := &Account{
account := &types.Account{
Peers: map[string]*nbpeer.Peer{
"peer1": {ID: "peer1", Key: "key1", UserID: "user1"},
"peer2": {ID: "peer2", Key: "key2", UserID: "user1"},
@ -2872,7 +2887,7 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) {
"group2": {ID: "group2", Name: "group2", Issued: group.GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3", "peer4", "peer5"}},
"group3": {ID: "group3", Name: "group3", Issued: group.GroupIssuedAPI, Peers: []string{"peer4", "peer5"}},
},
Users: map[string]*User{"user1": {Id: "user1"}, "user2": {Id: "user2"}},
Users: map[string]*types.User{"user1": {Id: "user1"}, "user2": {Id: "user2"}},
}
t.Run("remove groups", func(t *testing.T) {
@ -2915,10 +2930,10 @@ func createManager(t TB) (*DefaultAccountManager, error) {
return manager, nil
}
func createStore(t TB) (Store, error) {
func createStore(t TB) (store.Store, error) {
t.Helper()
dataDir := t.TempDir()
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir)
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", dataDir)
if err != nil {
return nil, err
}
@ -2941,7 +2956,7 @@ func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool {
}
}
func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *Account, *nbpeer.Peer, *nbpeer.Peer, *nbpeer.Peer) {
func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *types.Account, *nbpeer.Peer, *nbpeer.Peer, *nbpeer.Peer) {
t.Helper()
manager, err := createManager(t)
@ -2954,12 +2969,12 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *Account, *nbpee
t.Fatal(err)
}
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false)
if err != nil {
t.Fatal("error creating setup key")
}
getPeer := func(manager *DefaultAccountManager, setupKey *SetupKey) *nbpeer.Peer {
getPeer := func(manager *DefaultAccountManager, setupKey *types.SetupKey) *nbpeer.Peer {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)

View File

@ -5,6 +5,7 @@ import (
"net/url"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/util"
)
@ -156,7 +157,7 @@ type ProviderConfig struct {
// StoreConfig contains Store configuration
type StoreConfig struct {
Engine StoreEngine
Engine store.Engine
}
// ReverseProxy contains reverse proxy configuration in front of management.

View File

@ -2,9 +2,7 @@ package server
import (
"context"
"fmt"
"slices"
"strconv"
"sync"
log "github.com/sirupsen/logrus"
@ -12,12 +10,12 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
)
const defaultTTL = 300
// DNSConfigCache is a thread-safe cache for DNS configuration components
type DNSConfigCache struct {
CustomZones sync.Map
@ -62,26 +60,9 @@ func (c *DNSConfigCache) SetNameServerGroup(key string, value *proto.NameServerG
c.NameServerGroups.Store(key, value)
}
type lookupMap map[string]struct{}
// DNSSettings defines dns settings at the account level
type DNSSettings struct {
// DisabledManagementGroups groups whose DNS management is disabled
DisabledManagementGroups []string `gorm:"serializer:json"`
}
// Copy returns a copy of the DNS settings
func (d DNSSettings) Copy() DNSSettings {
settings := DNSSettings{
DisabledManagementGroups: make([]string, len(d.DisabledManagementGroups)),
}
copy(settings.DisabledManagementGroups, d.DisabledManagementGroups)
return settings
}
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) {
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
if err != nil {
return nil, err
}
@ -94,16 +75,16 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
return nil, status.NewAdminPermissionError()
}
return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
return am.Store.GetAccountDNSSettings(ctx, store.LockingStrengthShare, accountID)
}
// SaveDNSSettings validates a user role and updates the account's DNS settings
func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error {
if dnsSettingsToSave == nil {
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
}
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
if err != nil {
return err
}
@ -119,18 +100,18 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
var updateAccountPeers bool
var eventsToStore []func()
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil {
return err
}
oldSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthUpdate, accountID)
oldSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthUpdate, accountID)
if err != nil {
return err
}
addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
addedGroups := util.Difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
removedGroups := util.Difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
updateAccountPeers, err = areDNSSettingChangesAffectPeers(ctx, transaction, accountID, addedGroups, removedGroups)
if err != nil {
@ -140,11 +121,11 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups)
eventsToStore = append(eventsToStore, events...)
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveDNSSettings(ctx, LockingStrengthUpdate, accountID, dnsSettingsToSave)
return transaction.SaveDNSSettings(ctx, store.LockingStrengthUpdate, accountID, dnsSettingsToSave)
})
if err != nil {
return err
@ -162,11 +143,11 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
}
// prepareDNSSettingsEvents prepares a list of event functions to be stored.
func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string) []func() {
func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, transaction store.Store, accountID, userID string, addedGroups, removedGroups []string) []func() {
var eventsToStore []func()
modifiedGroups := slices.Concat(addedGroups, removedGroups)
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups)
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, modifiedGroups)
if err != nil {
log.WithContext(ctx).Debugf("failed to get groups for dns settings events: %v", err)
return nil
@ -203,7 +184,7 @@ func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, t
}
// areDNSSettingChangesAffectPeers checks if the DNS settings changes affect any peers.
func areDNSSettingChangesAffectPeers(ctx context.Context, transaction Store, accountID string, addedGroups, removedGroups []string) (bool, error) {
func areDNSSettingChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, addedGroups, removedGroups []string) (bool, error) {
hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, addedGroups)
if err != nil {
return false, err
@ -217,12 +198,12 @@ func areDNSSettingChangesAffectPeers(ctx context.Context, transaction Store, acc
}
// validateDNSSettings validates the DNS settings.
func validateDNSSettings(ctx context.Context, transaction Store, accountID string, settings *DNSSettings) error {
func validateDNSSettings(ctx context.Context, transaction store.Store, accountID string, settings *types.DNSSettings) error {
if len(settings.DisabledManagementGroups) == 0 {
return nil
}
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, settings.DisabledManagementGroups)
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, settings.DisabledManagementGroups)
if err != nil {
return err
}
@ -298,81 +279,3 @@ func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameSe
}
return protoGroup
}
func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup {
groupList := account.getPeerGroups(peerID)
var peerNSGroups []*nbdns.NameServerGroup
for _, nsGroup := range account.NameServerGroups {
if !nsGroup.Enabled {
continue
}
for _, gID := range nsGroup.Groups {
_, found := groupList[gID]
if found {
if !peerIsNameserver(account.GetPeer(peerID), nsGroup) {
peerNSGroups = append(peerNSGroups, nsGroup.Copy())
break
}
}
}
}
return peerNSGroups
}
// peerIsNameserver returns true if the peer is a nameserver for a nsGroup
func peerIsNameserver(peer *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool {
for _, ns := range nsGroup.NameServers {
if peer.IP.Equal(ns.IP.AsSlice()) {
return true
}
}
return false
}
func addPeerLabelsToAccount(ctx context.Context, account *Account, peerLabels lookupMap) {
for _, peer := range account.Peers {
label, err := getPeerHostLabel(peer.Name, peerLabels)
if err != nil {
log.WithContext(ctx).Errorf("got an error while generating a peer host label. Peer name %s, error: %v. Trying with the peer's meta hostname", peer.Name, err)
label, err = getPeerHostLabel(peer.Meta.Hostname, peerLabels)
if err != nil {
log.WithContext(ctx).Errorf("got another error while generating a peer host label with hostname. Peer hostname %s, error: %v. Skipping", peer.Meta.Hostname, err)
continue
}
}
peer.DNSLabel = label
peerLabels[label] = struct{}{}
}
}
func getPeerHostLabel(name string, peerLabels lookupMap) (string, error) {
label, err := nbdns.GetParsedDomainLabel(name)
if err != nil {
return "", err
}
uniqueLabel := getUniqueHostLabel(label, peerLabels)
if uniqueLabel == "" {
return "", fmt.Errorf("couldn't find a unique valid label for %s, parsed label %s", name, label)
}
return uniqueLabel, nil
}
// getUniqueHostLabel look for a unique host label, and if doesn't find add a suffix up to 999
func getUniqueHostLabel(name string, peerLabels lookupMap) string {
_, found := peerLabels[name]
if !found {
return name
}
for i := 1; i < 1000; i++ {
nameWithSuffix := name + "-" + strconv.Itoa(i)
_, found = peerLabels[nameWithSuffix]
if !found {
return nameWithSuffix
}
}
return ""
}

View File

@ -11,7 +11,9 @@ import (
"github.com/stretchr/testify/assert"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/stretchr/testify/require"
@ -53,7 +55,7 @@ func TestGetDNSSettings(t *testing.T) {
t.Fatal("DNS settings for new accounts shouldn't return nil")
}
account.DNSSettings = DNSSettings{
account.DNSSettings = types.DNSSettings{
DisabledManagementGroups: []string{group1ID},
}
@ -86,20 +88,20 @@ func TestSaveDNSSettings(t *testing.T) {
testCases := []struct {
name string
userID string
inputSettings *DNSSettings
inputSettings *types.DNSSettings
shouldFail bool
}{
{
name: "Saving As Admin Should Be OK",
userID: dnsAdminUserID,
inputSettings: &DNSSettings{
inputSettings: &types.DNSSettings{
DisabledManagementGroups: []string{dnsGroup1ID},
},
},
{
name: "Should Not Update Settings As Regular User",
userID: dnsRegularUserID,
inputSettings: &DNSSettings{
inputSettings: &types.DNSSettings{
DisabledManagementGroups: []string{dnsGroup1ID},
},
shouldFail: true,
@ -113,7 +115,7 @@ func TestSaveDNSSettings(t *testing.T) {
{
name: "Should Not Update Settings If Group Is Invalid",
userID: dnsAdminUserID,
inputSettings: &DNSSettings{
inputSettings: &types.DNSSettings{
DisabledManagementGroups: []string{"non-existing-group"},
},
shouldFail: true,
@ -210,10 +212,10 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics)
}
func createDNSStore(t *testing.T) (Store, error) {
func createDNSStore(t *testing.T) (store.Store, error) {
t.Helper()
dataDir := t.TempDir()
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir)
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", dataDir)
if err != nil {
return nil, err
}
@ -222,7 +224,7 @@ func createDNSStore(t *testing.T) (Store, error) {
return store, nil
}
func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) {
func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account, error) {
t.Helper()
peer1 := &nbpeer.Peer{
Key: dnsPeer1Key,
@ -259,9 +261,9 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
account := newAccountWithId(context.Background(), dnsAccountID, dnsAdminUserID, domain)
account.Users[dnsRegularUserID] = &User{
account.Users[dnsRegularUserID] = &types.User{
Id: dnsRegularUserID,
Role: UserRoleUser,
Role: types.UserRoleUser,
}
err := am.Store.SaveAccount(context.Background(), account)
@ -510,7 +512,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
close(done)
}()
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &types.DNSSettings{
DisabledManagementGroups: []string{"groupA"},
})
assert.NoError(t, err)
@ -589,7 +591,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
close(done)
}()
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &types.DNSSettings{
DisabledManagementGroups: []string{"groupA", "groupB"},
})
assert.NoError(t, err)
@ -609,7 +611,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
close(done)
}()
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &types.DNSSettings{
DisabledManagementGroups: []string{"groupA"},
})
assert.NoError(t, err)
@ -629,7 +631,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
close(done)
}()
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &types.DNSSettings{
DisabledManagementGroups: []string{},
})
assert.NoError(t, err)

View File

@ -9,6 +9,8 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
const (
@ -21,7 +23,7 @@ var (
type ephemeralPeer struct {
id string
account *Account
account *types.Account
deadline time.Time
next *ephemeralPeer
}
@ -32,7 +34,7 @@ type ephemeralPeer struct {
// EphemeralManager keep a list of ephemeral peers. After ephemeralLifeTime inactivity the peer will be deleted
// automatically. Inactivity means the peer disconnected from the Management server.
type EphemeralManager struct {
store Store
store store.Store
accountManager AccountManager
headPeer *ephemeralPeer
@ -42,7 +44,7 @@ type EphemeralManager struct {
}
// NewEphemeralManager instantiate new EphemeralManager
func NewEphemeralManager(store Store, accountManager AccountManager) *EphemeralManager {
func NewEphemeralManager(store store.Store, accountManager AccountManager) *EphemeralManager {
return &EphemeralManager{
store: store,
accountManager: accountManager,
@ -177,7 +179,7 @@ func (e *EphemeralManager) cleanup(ctx context.Context) {
}
}
func (e *EphemeralManager) addPeer(id string, account *Account, deadline time.Time) {
func (e *EphemeralManager) addPeer(id string, account *types.Account, deadline time.Time) {
ep := &ephemeralPeer{
id: id,
account: account,

View File

@ -8,18 +8,20 @@ import (
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
type MockStore struct {
Store
account *Account
store.Store
account *types.Account
}
func (s *MockStore) GetAllAccounts(_ context.Context) []*Account {
return []*Account{s.account}
func (s *MockStore) GetAllAccounts(_ context.Context) []*types.Account {
return []*types.Account{s.account}
}
func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*Account, error) {
func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*types.Account, error) {
_, ok := s.account.Peers[peerId]
if ok {
return s.account, nil

View File

@ -10,6 +10,9 @@ import (
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/management/server/activity"
@ -28,7 +31,7 @@ func (e *GroupLinkError) Error() string {
// CheckGroupPermissions validates if a user has the necessary permissions to view groups
func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
if err != nil {
return err
}
@ -49,7 +52,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err
}
return am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
return am.Store.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID)
}
// GetAllGroups returns all groups in an account
@ -57,12 +60,12 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err
}
return am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
return am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
}
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
return am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, groupName)
return am.Store.GetGroupByName(ctx, store.LockingStrengthShare, accountID, groupName)
}
// SaveGroup object of the peers
@ -76,7 +79,7 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI
// Note: This function does not acquire the global lock.
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*nbgroup.Group) error {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
if err != nil {
return err
}
@ -93,7 +96,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
var groupsToSave []*nbgroup.Group
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
groupIDs := make([]string, 0, len(groups))
for _, newGroup := range groups {
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
@ -113,11 +116,11 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave)
return transaction.SaveGroups(ctx, store.LockingStrengthUpdate, groupsToSave)
})
if err != nil {
return err
@ -135,16 +138,16 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
}
// prepareGroupEvents prepares a list of event functions to be stored.
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transaction Store, accountID, userID string, newGroup *nbgroup.Group) []func() {
func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transaction store.Store, accountID, userID string, newGroup *nbgroup.Group) []func() {
var eventsToStore []func()
addedPeers := make([]string, 0)
removedPeers := make([]string, 0)
oldGroup, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID)
oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, newGroup.ID)
if err == nil && oldGroup != nil {
addedPeers = difference(newGroup.Peers, oldGroup.Peers)
removedPeers = difference(oldGroup.Peers, newGroup.Peers)
addedPeers = util.Difference(newGroup.Peers, oldGroup.Peers)
removedPeers = util.Difference(oldGroup.Peers, newGroup.Peers)
} else {
addedPeers = append(addedPeers, newGroup.Peers...)
eventsToStore = append(eventsToStore, func() {
@ -153,7 +156,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
}
modifiedPeers := slices.Concat(addedPeers, removedPeers)
peers, err := transaction.GetPeersByIDs(ctx, LockingStrengthShare, accountID, modifiedPeers)
peers, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthShare, accountID, modifiedPeers)
if err != nil {
log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err)
return nil
@ -194,21 +197,6 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
return eventsToStore
}
// difference returns the elements in `a` that aren't in `b`.
func difference(a, b []string) []string {
mb := make(map[string]struct{}, len(b))
for _, x := range b {
mb[x] = struct{}{}
}
var diff []string
for _, x := range a {
if _, found := mb[x]; !found {
diff = append(diff, x)
}
}
return diff
}
// DeleteGroup object of the peers.
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
@ -223,7 +211,7 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use
// If an error occurs while deleting a group, the function skips it and continues deleting other groups.
// Errors are collected and returned at the end.
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
if err != nil {
return err
}
@ -240,9 +228,9 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
var groupIDsToDelete []string
var deletedGroups []*nbgroup.Group
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
for _, groupID := range groupIDs {
group, err := transaction.GetGroupByID(ctx, LockingStrengthUpdate, accountID, groupID)
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthUpdate, accountID, groupID)
if err != nil {
allErrors = errors.Join(allErrors, err)
continue
@ -257,11 +245,11 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
deletedGroups = append(deletedGroups, group)
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete)
return transaction.DeleteGroups(ctx, store.LockingStrengthUpdate, accountID, groupIDsToDelete)
})
if err != nil {
return err
@ -283,8 +271,8 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
var updateAccountPeers bool
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID)
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID)
if err != nil {
return err
}
@ -298,11 +286,11 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveGroup(ctx, LockingStrengthUpdate, group)
return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group)
})
if err != nil {
return err
@ -324,8 +312,8 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
var updateAccountPeers bool
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID)
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID)
if err != nil {
return err
}
@ -339,11 +327,11 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveGroup(ctx, LockingStrengthUpdate, group)
return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group)
})
if err != nil {
return err
@ -357,13 +345,13 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
}
// validateNewGroup validates the new group for existence and required fields.
func validateNewGroup(ctx context.Context, transaction Store, accountID string, newGroup *nbgroup.Group) error {
func validateNewGroup(ctx context.Context, transaction store.Store, accountID string, newGroup *nbgroup.Group) error {
if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
}
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
existingGroup, err := transaction.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name)
existingGroup, err := transaction.GetGroupByName(ctx, store.LockingStrengthShare, accountID, newGroup.Name)
if err != nil {
if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound {
return err
@ -380,7 +368,7 @@ func validateNewGroup(ctx context.Context, transaction Store, accountID string,
}
for _, peerID := range newGroup.Peers {
_, err := transaction.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
_, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID)
if err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
}
@ -389,14 +377,14 @@ func validateNewGroup(ctx context.Context, transaction Store, accountID string,
return nil
}
func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup.Group, userID string) error {
func validateDeleteGroup(ctx context.Context, transaction store.Store, group *nbgroup.Group, userID string) error {
// disable a deleting integration group if the initiator is not an admin service user
if group.Issued == nbgroup.GroupIssuedIntegration {
executingUser, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, userID)
executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
if err != nil {
return err
}
if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser {
if executingUser.Role != types.UserRoleAdmin || !executingUser.IsServiceUser {
return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group")
}
}
@ -429,8 +417,8 @@ func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup.
}
// checkGroupLinkedToSettings verifies if a group is linked to any settings in the account.
func checkGroupLinkedToSettings(ctx context.Context, transaction Store, group *nbgroup.Group) error {
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID)
func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, group *nbgroup.Group) error {
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthShare, group.AccountID)
if err != nil {
return err
}
@ -439,7 +427,7 @@ func checkGroupLinkedToSettings(ctx context.Context, transaction Store, group *n
return &GroupLinkError{"disabled DNS management groups", group.Name}
}
settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, group.AccountID)
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, group.AccountID)
if err != nil {
return err
}
@ -452,8 +440,8 @@ func checkGroupLinkedToSettings(ctx context.Context, transaction Store, group *n
}
// isGroupLinkedToRoute checks if a group is linked to any route in the account.
func isGroupLinkedToRoute(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *route.Route) {
routes, err := transaction.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *route.Route) {
routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err)
return false, nil
@ -469,8 +457,8 @@ func isGroupLinkedToRoute(ctx context.Context, transaction Store, accountID stri
}
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
func isGroupLinkedToPolicy(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *Policy) {
policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
func isGroupLinkedToPolicy(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.Policy) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err)
return false, nil
@ -487,8 +475,8 @@ func isGroupLinkedToPolicy(ctx context.Context, transaction Store, accountID str
}
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
func isGroupLinkedToDns(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) {
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
func isGroupLinkedToDns(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) {
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err)
return false, nil
@ -506,8 +494,8 @@ func isGroupLinkedToDns(ctx context.Context, transaction Store, accountID string
}
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
func isGroupLinkedToSetupKey(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *SetupKey) {
setupKeys, err := transaction.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
func isGroupLinkedToSetupKey(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.SetupKey) {
setupKeys, err := transaction.GetAccountSetupKeys(ctx, store.LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err)
return false, nil
@ -522,8 +510,8 @@ func isGroupLinkedToSetupKey(ctx context.Context, transaction Store, accountID s
}
// isGroupLinkedToUser checks if a group is linked to any user in the account.
func isGroupLinkedToUser(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *User) {
users, err := transaction.GetAccountUsers(ctx, LockingStrengthShare, accountID)
func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.User) {
users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err)
return false, nil
@ -538,12 +526,12 @@ func isGroupLinkedToUser(ctx context.Context, transaction Store, accountID strin
}
// areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers.
func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) {
func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
if len(groupIDs) == 0 {
return false, nil
}
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return false, err
}
@ -566,7 +554,7 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountI
return false, nil
}
func (am *DefaultAccountManager) anyGroupHasPeers(account *Account, groupIDs []string) bool {
func (am *DefaultAccountManager) anyGroupHasPeers(account *types.Account, groupIDs []string) bool {
for _, groupID := range groupIDs {
if group, exists := account.Groups[groupID]; exists && group.HasPeers() {
return true
@ -576,8 +564,8 @@ func (am *DefaultAccountManager) anyGroupHasPeers(account *Account, groupIDs []s
}
// anyGroupHasPeers checks if any of the given groups in the account have peers.
func anyGroupHasPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) {
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, groupIDs)
func anyGroupHasPeers(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupIDs)
if err != nil {
return false, err
}

View File

@ -14,6 +14,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
@ -267,7 +268,7 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) {
}
}
func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *Account, error) {
func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *types.Account, error) {
accountID := "testingAcc"
domain := "example.com"
@ -342,9 +343,9 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A
Groups: []string{groupForNameServerGroups.ID},
}
policy := &Policy{
policy := &types.Policy{
ID: "example policy",
Rules: []*PolicyRule{
Rules: []*types.PolicyRule{
{
ID: "example policy rule",
Destinations: []string{groupForPolicies.ID},
@ -352,12 +353,12 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A
},
}
setupKey := &SetupKey{
setupKey := &types.SetupKey{
Id: "example setup key",
AutoGroups: []string{groupForSetupKeys.ID},
}
user := &User{
user := &types.User{
Id: "example user",
AutoGroups: []string{groupForUsers.ID},
}
@ -500,15 +501,15 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
})
// adding a group to policy
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
Enabled: true,
Rules: []*PolicyRule{
Rules: []*types.PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
Action: types.PolicyTrafficActionAccept,
},
},
})
@ -648,7 +649,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
// Saving a group linked to dns settings should update account peers and send peer update
t.Run("saving group linked to dns settings", func(t *testing.T) {
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &types.DNSSettings{
DisabledManagementGroups: []string{"groupD"},
})
assert.NoError(t, err)

View File

@ -25,6 +25,7 @@ import (
"github.com/netbirdio/netbird/management/server/posture"
internalStatus "github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
)
// GRPCServer an instance of a Management gRPC API server
@ -599,7 +600,7 @@ func toWiretrusteeConfig(config *Config, turnCredentials *Token, relayToken *Tok
}
}
func toPeerConfig(peer *nbpeer.Peer, network *Network, dnsName string) *proto.PeerConfig {
func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string) *proto.PeerConfig {
netmask, _ := network.Net.Mask.Size()
fqdn := peer.FQDN(dnsName)
return &proto.PeerConfig{
@ -609,7 +610,7 @@ func toPeerConfig(peer *nbpeer.Peer, network *Network, dnsName string) *proto.Pe
}
}
func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache) *proto.SyncResponse {
func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache) *proto.SyncResponse {
response := &proto.SyncResponse{
WiretrusteeConfig: toWiretrusteeConfig(config, turnCredentials, relayCredentials),
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName),
@ -661,7 +662,7 @@ func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Em
}
// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization
func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error {
func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *types.NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error {
var err error
var turnToken *Token

View File

@ -1176,6 +1176,105 @@ components:
- id
- network_type
- $ref: '#/components/schemas/RouteRequest'
NetworkRequest:
type: object
properties:
name:
description: Network name
type: string
example: Remote Network 1
description:
description: Network description
type: string
example: A remote network that needs to be accessed
required:
- name
Network:
allOf:
- type: object
properties:
id:
description: Network ID
type: string
example: chacdk86lnnboviihd7g
required:
- id
- $ref: '#/components/schemas/NetworkRequest'
NetworkResourceRequest:
type: object
properties:
name:
description: Network resource name
type: string
example: Remote Resource 1
description:
description: Network resource description
type: string
example: A remote resource inside network 1
address:
description: Network resource address (either a direct host like 1.1.1.1 or 1.1.1.1/32, or a subnet like 192.168.178.0/24, or a domain like example.com)
type: string
example: "1.1.1.1"
required:
- name
- address
NetworkResource:
allOf:
- type: object
properties:
id:
description: Network Resource ID
type: string
example: chacdk86lnnboviihd7g
type:
description: Network resource type based of the address
type: string
enum: [ "host", "subnet", "domain"]
example: host
required:
- id
- type
- $ref: '#/components/schemas/NetworkResourceRequest'
NetworkRouterRequest:
type: object
properties:
peer:
description: Peer Identifier associated with route. This property can not be set together with `peer_groups`
type: string
example: chacbco6lnnbn6cg5s91
peer_groups:
description: Peers Group Identifier associated with route. This property can not be set together with `peer`
type: array
items:
type: string
example: chacbco6lnnbn6cg5s91
metric:
description: Route metric number. Lowest number has higher priority
type: integer
maximum: 9999
minimum: 1
example: 9999
masquerade:
description: Indicate if peer should masquerade traffic to this route's prefix
type: boolean
example: true
required:
# Only one property has to be set
#- peer
#- peer_groups
- metric
- masquerade
NetworkRouter:
allOf:
- type: object
properties:
id:
description: Network Router Id
type: string
example: chacdk86lnnboviihd7g
required:
- id
- $ref: '#/components/schemas/NetworkRouterRequest'
Nameserver:
type: object
properties:
@ -2460,6 +2559,502 @@ paths:
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/networks:
get:
summary: List all Networks
description: Returns a list of all networks
tags: [ Networks ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
responses:
'200':
description: A JSON Array of Networks
content:
application/json:
schema:
type: array
items:
$ref: '#/components/schemas/Network'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
post:
summary: Create a Network
description: Creates a Network
tags: [ Networks ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
requestBody:
description: New Network request
content:
'application/json':
schema:
$ref: '#/components/schemas/NetworkRequest'
responses:
'200':
description: A Network Object
content:
application/json:
schema:
$ref: '#/components/schemas/Network'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/networks/{networkId}:
get:
summary: Retrieve a Network
description: Get information about a Network
tags: [ Networks ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: networkId
required: true
schema:
type: string
description: The unique identifier of a network
responses:
'200':
description: A Network object
content:
application/json:
schema:
$ref: '#/components/schemas/Network'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
put:
summary: Update a Network
description: Update/Replace a Network
tags: [ Networks ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: networkId
required: true
schema:
type: string
description: The unique identifier of a network
requestBody:
description: Update Network request
content:
application/json:
schema:
$ref: '#/components/schemas/NetworkRequest'
responses:
'200':
description: A Network object
content:
application/json:
schema:
$ref: '#/components/schemas/Network'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
delete:
summary: Delete a Network
description: Delete a network
tags: [ Networks ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: networkId
required: true
schema:
type: string
description: The unique identifier of a network
responses:
'200':
description: Delete status code
content: { }
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/networks/{networkId}/resources:
get:
summary: List all Network Resources
description: Returns a list of all resources in a network
tags: [ Networks ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: networkId
required: true
schema:
type: string
description: The unique identifier of a network
responses:
'200':
description: A JSON Array of Resources
content:
application/json:
schema:
type: array
items:
$ref: '#/components/schemas/NetworkResource'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
post:
summary: Create a Network Resource
description: Creates a Network Resource
tags: [ Networks ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: networkId
required: true
schema:
type: string
description: The unique identifier of a network
requestBody:
description: New Network Resource request
content:
'application/json':
schema:
$ref: '#/components/schemas/NetworkResourceRequest'
responses:
'200':
description: A Network Resource Object
content:
application/json:
schema:
$ref: '#/components/schemas/NetworkResource'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/networks/{networkId}/resources/{resourceId}:
get:
summary: Retrieve a Network Resource
description: Get information about a Network Resource
tags: [ Networks ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: networkId
required: true
schema:
type: string
description: The unique identifier of a network
- in: path
name: resourceId
required: true
schema:
type: string
description: The unique identifier of a network resource
responses:
'200':
description: A Network Resource object
content:
application/json:
schema:
$ref: '#/components/schemas/NetworkResource'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
put:
summary: Update a Network Resource
description: Update a Network Resource
tags: [ Networks ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: networkId
required: true
schema:
type: string
description: The unique identifier of a network
- in: path
name: resourceId
required: true
schema:
type: string
description: The unique identifier of a resource
requestBody:
description: Update Network Resource request
content:
'application/json':
schema:
$ref: '#/components/schemas/NetworkResourceRequest'
responses:
'200':
description: A Network Resource object
content:
application/json:
schema:
$ref: '#/components/schemas/NetworkResource'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
delete:
summary: Delete a Network Resource
description: Delete a network resource
tags: [ Networks ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: networkId
required: true
schema:
type: string
description: The unique identifier of a network
- in: path
name: resourceId
required: true
schema:
type: string
description: The unique identifier of a network resource
responses:
'200':
description: Delete status code
content: { }
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/networks/{networkId}/routers:
get:
summary: List all Network Routers
description: Returns a list of all routers in a network
tags: [ Networks ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: networkId
required: true
schema:
type: string
description: The unique identifier of a network
responses:
'200':
description: A JSON Array of Routers
content:
application/json:
schema:
type: array
items:
$ref: '#/components/schemas/NetworkRouter'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
post:
summary: Create a Network Router
description: Creates a Network Router
tags: [ Networks ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: networkId
required: true
schema:
type: string
description: The unique identifier of a network
requestBody:
description: New Network Router request
content:
'application/json':
schema:
$ref: '#/components/schemas/NetworkRouterRequest'
responses:
'200':
description: A Router Object
content:
application/json:
schema:
$ref: '#/components/schemas/NetworkRouter'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/networks/{networkId}/routers/{routerId}:
get:
summary: Retrieve a Network Router
description: Get information about a Network Router
tags: [ Networks ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: networkId
required: true
schema:
type: string
description: The unique identifier of a network
- in: path
name: routerId
required: true
schema:
type: string
description: The unique identifier of a router
responses:
'200':
description: A Router object
content:
application/json:
schema:
$ref: '#/components/schemas/NetworkRouter'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
put:
summary: Update a Network Router
description: Update a Network Router
tags: [ Networks ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: networkId
required: true
schema:
type: string
description: The unique identifier of a network
- in: path
name: routerId
required: true
schema:
type: string
description: The unique identifier of a router
requestBody:
description: Update Network Router request
content:
'application/json':
schema:
$ref: '#/components/schemas/NetworkRouterRequest'
responses:
'200':
description: A Router object
content:
application/json:
schema:
$ref: '#/components/schemas/NetworkRouter'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
delete:
summary: Delete a Network Router
description: Delete a network router
tags: [ Networks ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: networkId
required: true
schema:
type: string
description: The unique identifier of a network
- in: path
name: routerId
required: true
schema:
type: string
description: The unique identifier of a router
responses:
'200':
description: Delete status code
content: { }
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/dns/nameservers:
get:
summary: List all Nameserver Groups

View File

@ -88,6 +88,13 @@ const (
NameserverNsTypeUdp NameserverNsType = "udp"
)
// Defines values for NetworkResourceType.
const (
NetworkResourceTypeDomain NetworkResourceType = "domain"
NetworkResourceTypeHost NetworkResourceType = "host"
NetworkResourceTypeSubnet NetworkResourceType = "subnet"
)
// Defines values for PeerNetworkRangeCheckAction.
const (
PeerNetworkRangeCheckActionAllow PeerNetworkRangeCheckAction = "allow"
@ -494,6 +501,93 @@ type NameserverGroupRequest struct {
SearchDomainsEnabled bool `json:"search_domains_enabled"`
}
// Network defines model for Network.
type Network struct {
// Description Network description
Description *string `json:"description,omitempty"`
// Id Network ID
Id string `json:"id"`
// Name Network name
Name string `json:"name"`
}
// NetworkRequest defines model for NetworkRequest.
type NetworkRequest struct {
// Description Network description
Description *string `json:"description,omitempty"`
// Name Network name
Name string `json:"name"`
}
// NetworkResource defines model for NetworkResource.
type NetworkResource struct {
// Address Network resource address (either a direct host like 1.1.1.1 or 1.1.1.1/32, or a subnet like 192.168.178.0/24, or a domain like example.com)
Address string `json:"address"`
// Description Network resource description
Description *string `json:"description,omitempty"`
// Id Network Resource ID
Id string `json:"id"`
// Name Network resource name
Name string `json:"name"`
// Type Network resource type based of the address
Type NetworkResourceType `json:"type"`
}
// NetworkResourceType Network resource type based of the address
type NetworkResourceType string
// NetworkResourceRequest defines model for NetworkResourceRequest.
type NetworkResourceRequest struct {
// Address Network resource address (either a direct host like 1.1.1.1 or 1.1.1.1/32, or a subnet like 192.168.178.0/24, or a domain like example.com)
Address string `json:"address"`
// Description Network resource description
Description *string `json:"description,omitempty"`
// Name Network resource name
Name string `json:"name"`
}
// NetworkRouter defines model for NetworkRouter.
type NetworkRouter struct {
// Id Network Router Id
Id string `json:"id"`
// Masquerade Indicate if peer should masquerade traffic to this route's prefix
Masquerade bool `json:"masquerade"`
// Metric Route metric number. Lowest number has higher priority
Metric int `json:"metric"`
// Peer Peer Identifier associated with route. This property can not be set together with `peer_groups`
Peer *string `json:"peer,omitempty"`
// PeerGroups Peers Group Identifier associated with route. This property can not be set together with `peer`
PeerGroups *[]string `json:"peer_groups,omitempty"`
}
// NetworkRouterRequest defines model for NetworkRouterRequest.
type NetworkRouterRequest struct {
// Masquerade Indicate if peer should masquerade traffic to this route's prefix
Masquerade bool `json:"masquerade"`
// Metric Route metric number. Lowest number has higher priority
Metric int `json:"metric"`
// Peer Peer Identifier associated with route. This property can not be set together with `peer_groups`
Peer *string `json:"peer,omitempty"`
// PeerGroups Peers Group Identifier associated with route. This property can not be set together with `peer`
PeerGroups *[]string `json:"peer_groups,omitempty"`
}
// OSVersionCheck Posture check for the version of operating system
type OSVersionCheck struct {
// Android Posture check for the version of operating system
@ -1292,6 +1386,24 @@ type PostApiGroupsJSONRequestBody = GroupRequest
// PutApiGroupsGroupIdJSONRequestBody defines body for PutApiGroupsGroupId for application/json ContentType.
type PutApiGroupsGroupIdJSONRequestBody = GroupRequest
// PostApiNetworksJSONRequestBody defines body for PostApiNetworks for application/json ContentType.
type PostApiNetworksJSONRequestBody = NetworkRequest
// PutApiNetworksNetworkIdJSONRequestBody defines body for PutApiNetworksNetworkId for application/json ContentType.
type PutApiNetworksNetworkIdJSONRequestBody = NetworkRequest
// PostApiNetworksNetworkIdResourcesJSONRequestBody defines body for PostApiNetworksNetworkIdResources for application/json ContentType.
type PostApiNetworksNetworkIdResourcesJSONRequestBody = NetworkResourceRequest
// PutApiNetworksNetworkIdResourcesResourceIdJSONRequestBody defines body for PutApiNetworksNetworkIdResourcesResourceId for application/json ContentType.
type PutApiNetworksNetworkIdResourcesResourceIdJSONRequestBody = NetworkResourceRequest
// PostApiNetworksNetworkIdRoutersJSONRequestBody defines body for PostApiNetworksNetworkIdRouters for application/json ContentType.
type PostApiNetworksNetworkIdRoutersJSONRequestBody = NetworkRouterRequest
// PutApiNetworksNetworkIdRoutersRouterIdJSONRequestBody defines body for PutApiNetworksNetworkIdRoutersRouterId for application/json ContentType.
type PutApiNetworksNetworkIdRoutersRouterIdJSONRequestBody = NetworkRouterRequest
// PutApiPeersPeerIdJSONRequestBody defines body for PutApiPeersPeerId for application/json ContentType.
type PutApiPeersPeerIdJSONRequestBody = PeerRequest

View File

@ -17,6 +17,7 @@ import (
"github.com/netbirdio/netbird/management/server/http/handlers/dns"
"github.com/netbirdio/netbird/management/server/http/handlers/events"
"github.com/netbirdio/netbird/management/server/http/handlers/groups"
"github.com/netbirdio/netbird/management/server/http/handlers/networks"
"github.com/netbirdio/netbird/management/server/http/handlers/peers"
"github.com/netbirdio/netbird/management/server/http/handlers/policies"
"github.com/netbirdio/netbird/management/server/http/handlers/routes"
@ -93,6 +94,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
routes.AddEndpoints(api.AccountManager, authCfg, router)
dns.AddEndpoints(api.AccountManager, authCfg, router)
events.AddEndpoints(api.AccountManager, authCfg, router)
networks.AddEndpoints(api.AccountManager.GetNetworksManager(), api.AccountManager.GetAccountIDFromToken, authCfg, router)
return rootRouter, nil
}

View File

@ -14,6 +14,7 @@ import (
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
)
// handler is a handler that handles the server.Account HTTP endpoints
@ -82,7 +83,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
return
}
settings := &server.Settings{
settings := &types.Settings{
PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled,
PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)),
RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked,
@ -138,7 +139,7 @@ func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
func toAccountResponse(accountID string, settings *server.Settings) *api.Account {
func toAccountResponse(accountID string, settings *types.Settings) *api.Account {
jwtAllowGroups := settings.JWTAllowGroups
if jwtAllowGroups == nil {
jwtAllowGroups = []string{}

View File

@ -13,23 +13,23 @@ import (
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
)
func initAccountsTestData(account *server.Account, admin *server.User) *handler {
func initAccountsTestData(account *types.Account, admin *types.User) *handler {
return &handler{
accountManager: &mock_server.MockAccountManager{
GetAccountIDFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return account.Id, admin.Id, nil
},
GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.Settings, error) {
GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.Settings, error) {
return account.Settings, nil
},
UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) {
halfYearLimit := 180 * 24 * time.Hour
if newSettings.PeerLoginExpiration > halfYearLimit {
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
@ -58,19 +58,19 @@ func initAccountsTestData(account *server.Account, admin *server.User) *handler
func TestAccounts_AccountsHandler(t *testing.T) {
accountID := "test_account"
adminUser := server.NewAdminUser("test_user")
adminUser := types.NewAdminUser("test_user")
sr := func(v string) *string { return &v }
br := func(v bool) *bool { return &v }
handler := initAccountsTestData(&server.Account{
handler := initAccountsTestData(&types.Account{
Id: accountID,
Domain: "hotmail.com",
Network: server.NewNetwork(),
Users: map[string]*server.User{
Network: types.NewNetwork(),
Users: map[string]*types.User{
adminUser.Id: adminUser,
},
Settings: &server.Settings{
Settings: &types.Settings{
PeerLoginExpirationEnabled: false,
PeerLoginExpiration: time.Hour,
RegularUsersViewBlocked: true,

View File

@ -12,6 +12,7 @@ import (
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/types"
)
// dnsSettingsHandler is a handler that returns the DNS settings of the account
@ -81,7 +82,7 @@ func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Re
return
}
updateDNSSettings := &server.DNSSettings{
updateDNSSettings := &types.DNSSettings{
DisabledManagementGroups: req.DisabledManagementGroups,
}

View File

@ -13,10 +13,10 @@ import (
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
)
@ -27,15 +27,15 @@ const (
testDNSSettingsUserID = "test_user"
)
var baseExistingDNSSettings = server.DNSSettings{
var baseExistingDNSSettings = types.DNSSettings{
DisabledManagementGroups: []string{testDNSSettingsExistingGroup},
}
var testingDNSSettingsAccount = &server.Account{
var testingDNSSettingsAccount = &types.Account{
Id: testDNSSettingsAccountID,
Domain: "hotmail.com",
Users: map[string]*server.User{
testDNSSettingsUserID: server.NewAdminUser("test_user"),
Users: map[string]*types.User{
testDNSSettingsUserID: types.NewAdminUser("test_user"),
},
DNSSettings: baseExistingDNSSettings,
}
@ -43,10 +43,10 @@ var testingDNSSettingsAccount = &server.Account{
func initDNSSettingsTestData() *dnsSettingsHandler {
return &dnsSettingsHandler{
accountManager: &mock_server.MockAccountManager{
GetDNSSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.DNSSettings, error) {
GetDNSSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) {
return &testingDNSSettingsAccount.DNSSettings, nil
},
SaveDNSSettingsFunc: func(ctx context.Context, accountID string, userID string, dnsSettingsToSave *server.DNSSettings) error {
SaveDNSSettingsFunc: func(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error {
if dnsSettingsToSave != nil {
return nil
}

View File

@ -13,11 +13,11 @@ import (
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/types"
)
func initEventsTestData(account string, events ...*activity.Event) *handler {
@ -32,8 +32,8 @@ func initEventsTestData(account string, events ...*activity.Event) *handler {
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil
},
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) {
return make([]*server.UserInfo, 0), nil
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*types.UserInfo, error) {
return make([]*types.UserInfo, 0), nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(
@ -191,7 +191,7 @@ func TestEvents_GetEvents(t *testing.T) {
},
}
accountID := "test_account"
adminUser := server.NewAdminUser("test_user")
adminUser := types.NewAdminUser("test_user")
events := generateEvents(accountID, adminUser.Id)
handler := initEventsTestData(accountID, events...)

View File

@ -0,0 +1,180 @@
package networks
import (
"context"
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/types"
"github.com/netbirdio/netbird/management/server/status"
)
// handler is a handler that returns networks of the account
type handler struct {
networksManager networks.Manager
extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
claimsExtractor *jwtclaims.ClaimsExtractor
}
func AddEndpoints(networksManager networks.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg, router *mux.Router) {
networksHandler := newHandler(networksManager, extractFromToken, authCfg)
router.HandleFunc("/networks", networksHandler.getAllNetworks).Methods("GET", "OPTIONS")
router.HandleFunc("/networks", networksHandler.createNetwork).Methods("POST", "OPTIONS")
router.HandleFunc("/networks/{networkId}", networksHandler.getNetwork).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}", networksHandler.updateNetwork).Methods("PUT", "OPTIONS")
router.HandleFunc("/networks/{networkId}", networksHandler.deleteNetwork).Methods("DELETE", "OPTIONS")
addRouterEndpoints(networksManager.GetRouterManager(), extractFromToken, authCfg, router)
addResourceEndpoints(networksManager.GetResourceManager(), extractFromToken, authCfg, router)
}
func newHandler(networksManager networks.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg) *handler {
return &handler{
networksManager: networksManager,
extractFromToken: extractFromToken,
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
}
}
func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.extractFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
networks, err := h.networksManager.GetAllNetworks(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var networkResponse []*api.Network
for _, network := range networks {
networkResponse = append(networkResponse, network.ToAPIResponse())
}
util.WriteJSONObject(r.Context(), w, networkResponse)
}
func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.extractFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var req api.NetworkRequest
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
network := &types.Network{}
network.FromAPIRequest(&req)
network.AccountID = accountID
network, err = h.networksManager.CreateNetwork(r.Context(), userID, network)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, network.ToAPIResponse())
}
func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.extractFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
networkID := vars["networkId"]
if len(networkID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid network ID"), w)
return
}
network, err := h.networksManager.GetNetwork(r.Context(), accountID, userID, networkID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, network.ToAPIResponse())
}
func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.extractFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
networkID := vars["networkId"]
if len(networkID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid network ID"), w)
return
}
var req api.NetworkRequest
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
network := &types.Network{}
network.FromAPIRequest(&req)
network.ID = networkID
network.AccountID = accountID
network, err = h.networksManager.UpdateNetwork(r.Context(), userID, network)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, network.ToAPIResponse())
}
func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.extractFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
networkID := vars["networkId"]
if len(networkID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid network ID"), w)
return
}
err = h.networksManager.DeleteNetwork(r.Context(), accountID, userID, networkID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}

View File

@ -0,0 +1,162 @@
package networks
import (
"context"
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/resources/types"
)
type resourceHandler struct {
resourceManager resources.Manager
extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
claimsExtractor *jwtclaims.ClaimsExtractor
}
func addResourceEndpoints(resourcesManager resources.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg, router *mux.Router) {
resourceHandler := newResourceHandler(resourcesManager, extractFromToken, authCfg)
router.HandleFunc("/networks/{networkId}/resources", resourceHandler.getAllResources).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources", resourceHandler.createResource).Methods("POST", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.getResource).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.updateResource).Methods("PUT", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.deleteResource).Methods("DELETE", "OPTIONS")
}
func newResourceHandler(resourceManager resources.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg) *resourceHandler {
return &resourceHandler{
resourceManager: resourceManager,
extractFromToken: extractFromToken,
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
}
}
func (h *resourceHandler) getAllResources(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.extractFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
networkID := mux.Vars(r)["networkId"]
resources, err := h.resourceManager.GetAllResources(r.Context(), accountID, userID, networkID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var resourcesResponse []*api.NetworkResource
for _, resource := range resources {
resourcesResponse = append(resourcesResponse, resource.ToAPIResponse())
}
util.WriteJSONObject(r.Context(), w, resourcesResponse)
}
func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.extractFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var req api.NetworkResourceRequest
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
resource := &types.NetworkResource{}
resource.FromAPIRequest(&req)
resource.NetworkID = mux.Vars(r)["networkId"]
resource.AccountID = accountID
resource, err = h.resourceManager.CreateResource(r.Context(), userID, resource)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse())
}
func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.extractFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
networkID := mux.Vars(r)["networkId"]
resourceID := mux.Vars(r)["resourceId"]
resource, err := h.resourceManager.GetResource(r.Context(), accountID, userID, networkID, resourceID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse())
}
func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.extractFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var req api.NetworkResourceRequest
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
resource := &types.NetworkResource{}
resource.FromAPIRequest(&req)
resource.ID = mux.Vars(r)["resourceId"]
resource.NetworkID = mux.Vars(r)["networkId"]
resource.AccountID = accountID
resource, err = h.resourceManager.UpdateResource(r.Context(), userID, resource)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse())
}
func (h *resourceHandler) deleteResource(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.extractFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
networkID := mux.Vars(r)["networkId"]
resourceID := mux.Vars(r)["resourceId"]
err = h.resourceManager.DeleteResource(r.Context(), accountID, userID, networkID, resourceID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}

View File

@ -0,0 +1,165 @@
package networks
import (
"context"
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/networks/routers"
"github.com/netbirdio/netbird/management/server/networks/routers/types"
)
type routersHandler struct {
routersManager routers.Manager
extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
claimsExtractor *jwtclaims.ClaimsExtractor
}
func addRouterEndpoints(routersManager routers.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg, router *mux.Router) {
routersHandler := newRoutersHandler(routersManager, extractFromToken, authCfg)
router.HandleFunc("/networks/{networkId}/routers", routersHandler.getAllRouters).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers", routersHandler.createRouter).Methods("POST", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.getRouter).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.updateRouter).Methods("PUT", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.deleteRouter).Methods("DELETE", "OPTIONS")
}
func newRoutersHandler(routersManager routers.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg) *routersHandler {
return &routersHandler{
routersManager: routersManager,
extractFromToken: extractFromToken,
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
}
}
func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.extractFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
networkID := mux.Vars(r)["networkId"]
routers, err := h.routersManager.GetAllRouters(r.Context(), accountID, userID, networkID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var routersResponse []*api.NetworkRouter
for _, router := range routers {
routersResponse = append(routersResponse, router.ToAPIResponse())
}
util.WriteJSONObject(r.Context(), w, routersResponse)
}
func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.extractFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
networkID := mux.Vars(r)["networkId"]
var req api.NetworkRouterRequest
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
router := &types.NetworkRouter{}
router.FromAPIRequest(&req)
router.NetworkID = networkID
router.AccountID = accountID
router, err = h.routersManager.CreateRouter(r.Context(), userID, router)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, router.ToAPIResponse())
}
func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.extractFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
routerID := mux.Vars(r)["routerId"]
networkID := mux.Vars(r)["networkId"]
router, err := h.routersManager.GetRouter(r.Context(), accountID, userID, networkID, routerID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, router.ToAPIResponse())
}
func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.extractFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var req api.NetworkRouterRequest
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
router := &types.NetworkRouter{}
router.FromAPIRequest(&req)
router.NetworkID = mux.Vars(r)["networkId"]
router.ID = mux.Vars(r)["routerId"]
router.AccountID = accountID
router, err = h.routersManager.UpdateRouter(r.Context(), userID, router)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, router.ToAPIResponse())
}
func (h *routersHandler) deleteRouter(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.extractFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
routerID := mux.Vars(r)["routerId"]
networkID := mux.Vars(r)["networkId"]
err = h.routersManager.DeleteRouter(r.Context(), accountID, userID, networkID, routerID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, struct{}{})
}

View File

@ -17,6 +17,7 @@ import (
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
)
// Handler is a handler that returns peers of the account
@ -57,7 +58,7 @@ func (h *Handler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) {
return peerToReturn, nil
}
func (h *Handler) getPeer(ctx context.Context, account *server.Account, peerID, userID string, w http.ResponseWriter) {
func (h *Handler) getPeer(ctx context.Context, account *types.Account, peerID, userID string, w http.ResponseWriter) {
peer, err := h.accountManager.GetPeer(ctx, account.Id, peerID, userID)
if err != nil {
util.WriteError(ctx, err, w)
@ -84,7 +85,7 @@ func (h *Handler) getPeer(ctx context.Context, account *server.Account, peerID,
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid))
}
func (h *Handler) updatePeer(ctx context.Context, account *server.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) {
func (h *Handler) updatePeer(ctx context.Context, account *types.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) {
req := &api.PeerRequest{}
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
@ -295,7 +296,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
}
func toAccessiblePeers(netMap *server.NetworkMap, dnsDomain string) []api.AccessiblePeer {
func toAccessiblePeers(netMap *types.NetworkMap, dnsDomain string) []api.AccessiblePeer {
accessiblePeers := make([]api.AccessiblePeer, 0, len(netMap.Peers)+len(netMap.OfflinePeers))
for _, p := range netMap.Peers {
accessiblePeers = append(accessiblePeers, peerToAccessiblePeer(p, dnsDomain))

View File

@ -15,11 +15,11 @@ import (
"github.com/gorilla/mux"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/management/server"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
"github.com/stretchr/testify/assert"
@ -73,18 +73,18 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler {
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil
},
GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) {
GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*types.Account, error) {
peersMap := make(map[string]*nbpeer.Peer)
for _, peer := range peers {
peersMap[peer.ID] = peer.Copy()
}
policy := &server.Policy{
policy := &types.Policy{
ID: "policy",
AccountID: accountID,
Name: "policy",
Enabled: true,
Rules: []*server.PolicyRule{
Rules: []*types.PolicyRule{
{
ID: "rule",
Name: "rule",
@ -99,16 +99,16 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler {
},
}
srvUser := server.NewRegularUser(serviceUser)
srvUser := types.NewRegularUser(serviceUser)
srvUser.IsServiceUser = true
account := &server.Account{
account := &types.Account{
Id: accountID,
Domain: "hotmail.com",
Peers: peersMap,
Users: map[string]*server.User{
adminUser: server.NewAdminUser(adminUser),
regularUser: server.NewRegularUser(regularUser),
Users: map[string]*types.User{
adminUser: types.NewAdminUser(adminUser),
regularUser: types.NewRegularUser(regularUser),
serviceUser: srvUser,
},
Groups: map[string]*nbgroup.Group{
@ -120,12 +120,12 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler {
Peers: maps.Keys(peersMap),
},
},
Settings: &server.Settings{
Settings: &types.Settings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: time.Hour,
},
Policies: []*server.Policy{policy},
Network: &server.Network{
Policies: []*types.Policy{policy},
Network: &types.Network{
Identifier: "ciclqisab2ss43jdn8q0",
Net: net.IPNet{
IP: net.ParseIP("100.67.0.0"),

View File

@ -13,11 +13,11 @@ import (
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
)
@ -46,8 +46,8 @@ func initGeolocationTestData(t *testing.T) *geolocationsHandler {
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil
},
GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) {
return server.NewAdminUser(id), nil
GetUserByIDFunc: func(ctx context.Context, id string) (*types.User, error) {
return types.NewAdminUser(id), nil
},
},
geolocationManager: geo,

View File

@ -15,6 +15,7 @@ import (
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
)
// handler is a handler that returns policy of the account
@ -133,7 +134,7 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
return
}
policy := &server.Policy{
policy := &types.Policy{
ID: policyID,
AccountID: accountID,
Name: req.Name,
@ -146,7 +147,7 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
ruleID = *rule.Id
}
pr := server.PolicyRule{
pr := types.PolicyRule{
ID: ruleID,
PolicyID: policyID,
Name: rule.Name,
@ -162,9 +163,9 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
switch rule.Action {
case api.PolicyRuleUpdateActionAccept:
pr.Action = server.PolicyTrafficActionAccept
pr.Action = types.PolicyTrafficActionAccept
case api.PolicyRuleUpdateActionDrop:
pr.Action = server.PolicyTrafficActionDrop
pr.Action = types.PolicyTrafficActionDrop
default:
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown action type"), w)
return
@ -172,13 +173,13 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
switch rule.Protocol {
case api.PolicyRuleUpdateProtocolAll:
pr.Protocol = server.PolicyRuleProtocolALL
pr.Protocol = types.PolicyRuleProtocolALL
case api.PolicyRuleUpdateProtocolTcp:
pr.Protocol = server.PolicyRuleProtocolTCP
pr.Protocol = types.PolicyRuleProtocolTCP
case api.PolicyRuleUpdateProtocolUdp:
pr.Protocol = server.PolicyRuleProtocolUDP
pr.Protocol = types.PolicyRuleProtocolUDP
case api.PolicyRuleUpdateProtocolIcmp:
pr.Protocol = server.PolicyRuleProtocolICMP
pr.Protocol = types.PolicyRuleProtocolICMP
default:
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown protocol type: %v", rule.Protocol), w)
return
@ -205,7 +206,7 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "valid port value is in 1..65535 range"), w)
return
}
pr.PortRanges = append(pr.PortRanges, server.RulePortRange{
pr.PortRanges = append(pr.PortRanges, types.RulePortRange{
Start: uint16(portRange.Start),
End: uint16(portRange.End),
})
@ -214,7 +215,7 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
// validate policy object
switch pr.Protocol {
case server.PolicyRuleProtocolALL, server.PolicyRuleProtocolICMP:
case types.PolicyRuleProtocolALL, types.PolicyRuleProtocolICMP:
if len(pr.Ports) != 0 || len(pr.PortRanges) != 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol ports is not allowed"), w)
return
@ -223,7 +224,7 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w)
return
}
case server.PolicyRuleProtocolTCP, server.PolicyRuleProtocolUDP:
case types.PolicyRuleProtocolTCP, types.PolicyRuleProtocolUDP:
if !pr.Bidirectional && (len(pr.Ports) == 0 || len(pr.PortRanges) != 0) {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w)
return
@ -319,7 +320,7 @@ func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, resp)
}
func toPolicyResponse(groups []*nbgroup.Group, policy *server.Policy) *api.Policy {
func toPolicyResponse(groups []*nbgroup.Group, policy *types.Policy) *api.Policy {
groupsMap := make(map[string]*nbgroup.Group)
for _, group := range groups {
groupsMap[group.ID] = group

View File

@ -13,6 +13,7 @@ import (
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
"github.com/gorilla/mux"
@ -20,25 +21,24 @@ import (
"github.com/magiconair/properties/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/mock_server"
)
func initPoliciesTestData(policies ...*server.Policy) *handler {
testPolicies := make(map[string]*server.Policy, len(policies))
func initPoliciesTestData(policies ...*types.Policy) *handler {
testPolicies := make(map[string]*types.Policy, len(policies))
for _, policy := range policies {
testPolicies[policy.ID] = policy
}
return &handler{
accountManager: &mock_server.MockAccountManager{
GetPolicyFunc: func(_ context.Context, _, policyID, _ string) (*server.Policy, error) {
GetPolicyFunc: func(_ context.Context, _, policyID, _ string) (*types.Policy, error) {
policy, ok := testPolicies[policyID]
if !ok {
return nil, status.Errorf(status.NotFound, "policy not found")
}
return policy, nil
},
SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) (*server.Policy, error) {
SavePolicyFunc: func(_ context.Context, _, _ string, policy *types.Policy) (*types.Policy, error) {
if !strings.HasPrefix(policy.ID, "id-") {
policy.ID = "id-was-set"
policy.Rules[0].ID = "id-was-set"
@ -51,19 +51,19 @@ func initPoliciesTestData(policies ...*server.Policy) *handler {
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil
},
GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) {
user := server.NewAdminUser(userID)
return &server.Account{
GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*types.Account, error) {
user := types.NewAdminUser(userID)
return &types.Account{
Id: accountID,
Domain: "hotmail.com",
Policies: []*server.Policy{
Policies: []*types.Policy{
{ID: "id-existed"},
},
Groups: map[string]*nbgroup.Group{
"F": {ID: "F"},
"G": {ID: "G"},
},
Users: map[string]*server.User{
Users: map[string]*types.User{
"test_user": user,
},
}, nil
@ -105,10 +105,10 @@ func TestPoliciesGetPolicy(t *testing.T) {
},
}
policy := &server.Policy{
policy := &types.Policy{
ID: "idofthepolicy",
Name: "Rule",
Rules: []*server.PolicyRule{
Rules: []*types.PolicyRule{
{ID: "idoftherule", Name: "Rule"},
},
}
@ -251,10 +251,10 @@ func TestPoliciesWritePolicy(t *testing.T) {
},
}
p := initPoliciesTestData(&server.Policy{
p := initPoliciesTestData(&types.Policy{
ID: "id-existed",
Name: "Default POSTed Rule",
Rules: []*server.PolicyRule{
Rules: []*types.PolicyRule{
{
ID: "id-existed",
Name: "Default POSTed Rule",

View File

@ -16,13 +16,13 @@ import (
"github.com/netbirdio/netbird/management/server/http/api"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
"github.com/gorilla/mux"
"github.com/magiconair/properties/assert"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
)
@ -61,7 +61,7 @@ var baseExistingRoute = &route.Route{
Groups: []string{existingGroupID},
}
var testingAccount = &server.Account{
var testingAccount = &types.Account{
Id: testAccountID,
Domain: "hotmail.com",
Peers: map[string]*nbpeer.Peer{
@ -82,8 +82,8 @@ var testingAccount = &server.Account{
},
},
},
Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"),
Users: map[string]*types.User{
"test_user": types.NewAdminUser("test_user"),
},
}

View File

@ -14,6 +14,7 @@ import (
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
)
// handler is a handler that returns a list of setup keys of the account
@ -63,8 +64,8 @@ func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) {
return
}
if !(server.SetupKeyType(req.Type) == server.SetupKeyReusable ||
server.SetupKeyType(req.Type) == server.SetupKeyOneOff) {
if !(types.SetupKeyType(req.Type) == types.SetupKeyReusable ||
types.SetupKeyType(req.Type) == types.SetupKeyOneOff) {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown setup key type %s", req.Type), w)
return
}
@ -85,7 +86,7 @@ func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) {
ephemeral = *req.Ephemeral
}
setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, server.SetupKeyType(req.Type), expiresIn,
setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, types.SetupKeyType(req.Type), expiresIn,
req.AutoGroups, req.UsageLimit, userID, ephemeral)
if err != nil {
util.WriteError(r.Context(), err, w)
@ -152,7 +153,7 @@ func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) {
return
}
newKey := &server.SetupKey{}
newKey := &types.SetupKey{}
newKey.AutoGroups = req.AutoGroups
newKey.Revoked = req.Revoked
newKey.Id = keyID
@ -212,7 +213,7 @@ func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
func writeSuccess(ctx context.Context, w http.ResponseWriter, key *server.SetupKey) {
func writeSuccess(ctx context.Context, w http.ResponseWriter, key *types.SetupKey) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(200)
err := json.NewEncoder(w).Encode(toResponseBody(key))
@ -222,7 +223,7 @@ func writeSuccess(ctx context.Context, w http.ResponseWriter, key *server.SetupK
}
}
func toResponseBody(key *server.SetupKey) *api.SetupKey {
func toResponseBody(key *types.SetupKey) *api.SetupKey {
var state string
switch {
case key.IsExpired():

View File

@ -14,11 +14,11 @@ import (
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
)
const (
@ -29,17 +29,17 @@ const (
testAccountID = "test_id"
)
func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.SetupKey, updatedSetupKey *server.SetupKey,
user *server.User,
func initSetupKeysTestMetaData(defaultKey *types.SetupKey, newKey *types.SetupKey, updatedSetupKey *types.SetupKey,
user *types.User,
) *handler {
return &handler{
accountManager: &mock_server.MockAccountManager{
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil
},
CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string,
CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ types.SetupKeyType, _ time.Duration, _ []string,
_ int, _ string, ephemeral bool,
) (*server.SetupKey, error) {
) (*types.SetupKey, error) {
if keyName == newKey.Name || typ != newKey.Type {
nk := newKey.Copy()
nk.Ephemeral = ephemeral
@ -47,7 +47,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
}
return nil, fmt.Errorf("failed creating setup key")
},
GetSetupKeyFunc: func(_ context.Context, accountID, userID, keyID string) (*server.SetupKey, error) {
GetSetupKeyFunc: func(_ context.Context, accountID, userID, keyID string) (*types.SetupKey, error) {
switch keyID {
case defaultKey.Id:
return defaultKey, nil
@ -58,15 +58,15 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
}
},
SaveSetupKeyFunc: func(_ context.Context, accountID string, key *server.SetupKey, _ string) (*server.SetupKey, error) {
SaveSetupKeyFunc: func(_ context.Context, accountID string, key *types.SetupKey, _ string) (*types.SetupKey, error) {
if key.Id == updatedSetupKey.Id {
return updatedSetupKey, nil
}
return nil, status.Errorf(status.NotFound, "key %s not found", key.Id)
},
ListSetupKeysFunc: func(_ context.Context, accountID, userID string) ([]*server.SetupKey, error) {
return []*server.SetupKey{defaultKey}, nil
ListSetupKeysFunc: func(_ context.Context, accountID, userID string) ([]*types.SetupKey, error) {
return []*types.SetupKey{defaultKey}, nil
},
DeleteSetupKeyFunc: func(_ context.Context, accountID, userID, keyID string) error {
@ -89,13 +89,13 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
}
func TestSetupKeysHandlers(t *testing.T) {
defaultSetupKey, _ := server.GenerateDefaultSetupKey()
defaultSetupKey, _ := types.GenerateDefaultSetupKey()
defaultSetupKey.Id = existingSetupKeyID
adminUser := server.NewAdminUser("test_user")
adminUser := types.NewAdminUser("test_user")
newSetupKey, plainKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"},
server.SetupKeyUnlimitedUsage, true)
newSetupKey, plainKey := types.GenerateSetupKey(newSetupKeyName, types.SetupKeyReusable, 0, []string{"group-1"},
types.SetupKeyUnlimitedUsage, true)
newSetupKey.Key = plainKey
updatedDefaultSetupKey := defaultSetupKey.Copy()
updatedDefaultSetupKey.AutoGroups = []string{"group-1"}

View File

@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
)
// patHandler is the nameserver group handler of the account
@ -164,7 +165,7 @@ func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
func toPATResponse(pat *server.PersonalAccessToken) *api.PersonalAccessToken {
func toPATResponse(pat *types.PersonalAccessToken) *api.PersonalAccessToken {
var lastUsed *time.Time
if !pat.LastUsed.IsZero() {
lastUsed = &pat.LastUsed
@ -179,7 +180,7 @@ func toPATResponse(pat *server.PersonalAccessToken) *api.PersonalAccessToken {
}
}
func toPATGeneratedResponse(pat *server.PersonalAccessTokenGenerated) *api.PersonalAccessTokenGenerated {
func toPATGeneratedResponse(pat *types.PersonalAccessTokenGenerated) *api.PersonalAccessTokenGenerated {
return &api.PersonalAccessTokenGenerated{
PlainToken: pat.PlainToken,
PersonalAccessToken: *toPATResponse(&pat.PersonalAccessToken),

View File

@ -14,11 +14,11 @@ import (
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
)
const (
@ -31,13 +31,13 @@ const (
testDomain = "hotmail.com"
)
var testAccount = &server.Account{
var testAccount = &types.Account{
Id: existingAccountID,
Domain: testDomain,
Users: map[string]*server.User{
Users: map[string]*types.User{
existingUserID: {
Id: existingUserID,
PATs: map[string]*server.PersonalAccessToken{
PATs: map[string]*types.PersonalAccessToken{
existingTokenID: {
ID: existingTokenID,
Name: "My first token",
@ -64,16 +64,16 @@ var testAccount = &server.Account{
func initPATTestData() *patHandler {
return &patHandler{
accountManager: &mock_server.MockAccountManager{
CreatePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
CreatePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) {
if accountID != existingAccountID {
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
}
if targetUserID != existingUserID {
return nil, status.Errorf(status.NotFound, "user with ID %s not found", targetUserID)
}
return &server.PersonalAccessTokenGenerated{
return &types.PersonalAccessTokenGenerated{
PlainToken: "nbp_z1pvsg2wP3EzmEou4S679KyTNhov632eyrXe",
PersonalAccessToken: server.PersonalAccessToken{},
PersonalAccessToken: types.PersonalAccessToken{},
}, nil
},
@ -92,7 +92,7 @@ func initPATTestData() *patHandler {
}
return nil
},
GetPATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
GetPATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error) {
if accountID != existingAccountID {
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
}
@ -104,14 +104,14 @@ func initPATTestData() *patHandler {
}
return testAccount.Users[existingUserID].PATs[existingTokenID], nil
},
GetAllPATsFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
GetAllPATsFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*types.PersonalAccessToken, error) {
if accountID != existingAccountID {
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
}
if targetUserID != existingUserID {
return nil, status.Errorf(status.NotFound, "user with ID %s not found", targetUserID)
}
return []*server.PersonalAccessToken{testAccount.Users[existingUserID].PATs[existingTokenID], testAccount.Users[existingUserID].PATs["token2"]}, nil
return []*types.PersonalAccessToken{testAccount.Users[existingUserID].PATs[existingTokenID], testAccount.Users[existingUserID].PATs["token2"]}, nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(
@ -217,7 +217,7 @@ func TestTokenHandlers(t *testing.T) {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.NotEmpty(t, got.PlainToken)
assert.Equal(t, server.PATLength, len(got.PlainToken))
assert.Equal(t, types.PATLength, len(got.PlainToken))
case "Get All Tokens":
expectedTokens := []api.PersonalAccessToken{
toTokenResponse(*testAccount.Users[existingUserID].PATs[existingTokenID]),
@ -243,7 +243,7 @@ func TestTokenHandlers(t *testing.T) {
}
}
func toTokenResponse(serverToken server.PersonalAccessToken) api.PersonalAccessToken {
func toTokenResponse(serverToken types.PersonalAccessToken) api.PersonalAccessToken {
return api.PersonalAccessToken{
Id: serverToken.ID,
Name: serverToken.Name,

View File

@ -12,6 +12,7 @@ import (
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/jwtclaims"
@ -83,13 +84,13 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) {
return
}
userRole := server.StrRoleToUserRole(req.Role)
if userRole == server.UserRoleUnknown {
userRole := types.StrRoleToUserRole(req.Role)
if userRole == types.UserRoleUnknown {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user role"), w)
return
}
newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &server.User{
newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &types.User{
Id: targetUserID,
Role: userRole,
AutoGroups: req.AutoGroups,
@ -156,7 +157,7 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request) {
return
}
if server.StrRoleToUserRole(req.Role) == server.UserRoleUnknown {
if types.StrRoleToUserRole(req.Role) == types.UserRoleUnknown {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown user role %s", req.Role), w)
return
}
@ -171,13 +172,13 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request) {
name = *req.Name
}
newUser, err := h.accountManager.CreateUser(r.Context(), accountID, userID, &server.UserInfo{
newUser, err := h.accountManager.CreateUser(r.Context(), accountID, userID, &types.UserInfo{
Email: email,
Name: name,
Role: req.Role,
AutoGroups: req.AutoGroups,
IsServiceUser: req.IsServiceUser,
Issued: server.UserIssuedAPI,
Issued: types.UserIssuedAPI,
})
if err != nil {
util.WriteError(r.Context(), err, w)
@ -264,7 +265,7 @@ func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
func toUserResponse(user *server.UserInfo, currenUserID string) *api.User {
func toUserResponse(user *types.UserInfo, currenUserID string) *api.User {
autoGroups := user.AutoGroups
if autoGroups == nil {
autoGroups = []string{}

View File

@ -13,11 +13,11 @@ import (
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
)
const (
@ -26,37 +26,37 @@ const (
regularUserID = "regularUserID"
)
var usersTestAccount = &server.Account{
var usersTestAccount = &types.Account{
Id: existingAccountID,
Domain: testDomain,
Users: map[string]*server.User{
Users: map[string]*types.User{
existingUserID: {
Id: existingUserID,
Role: "admin",
IsServiceUser: false,
AutoGroups: []string{"group_1"},
Issued: server.UserIssuedAPI,
Issued: types.UserIssuedAPI,
},
regularUserID: {
Id: regularUserID,
Role: "user",
IsServiceUser: false,
AutoGroups: []string{"group_1"},
Issued: server.UserIssuedAPI,
Issued: types.UserIssuedAPI,
},
serviceUserID: {
Id: serviceUserID,
Role: "user",
IsServiceUser: true,
AutoGroups: []string{"group_1"},
Issued: server.UserIssuedAPI,
Issued: types.UserIssuedAPI,
},
nonDeletableServiceUserID: {
Id: serviceUserID,
Role: "admin",
IsServiceUser: true,
NonDeletable: true,
Issued: server.UserIssuedIntegration,
Issued: types.UserIssuedIntegration,
},
},
}
@ -67,13 +67,13 @@ func initUsersTestData() *handler {
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return usersTestAccount.Id, claims.UserId, nil
},
GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) {
GetUserByIDFunc: func(ctx context.Context, id string) (*types.User, error) {
return usersTestAccount.Users[id], nil
},
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) {
users := make([]*server.UserInfo, 0)
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*types.UserInfo, error) {
users := make([]*types.UserInfo, 0)
for _, v := range usersTestAccount.Users {
users = append(users, &server.UserInfo{
users = append(users, &types.UserInfo{
ID: v.Id,
Role: string(v.Role),
Name: "",
@ -85,7 +85,7 @@ func initUsersTestData() *handler {
}
return users, nil
},
CreateUserFunc: func(_ context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) {
CreateUserFunc: func(_ context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error) {
if userID != existingUserID {
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID)
}
@ -100,7 +100,7 @@ func initUsersTestData() *handler {
}
return nil
},
SaveUserFunc: func(_ context.Context, accountID, userID string, update *server.User) (*server.UserInfo, error) {
SaveUserFunc: func(_ context.Context, accountID, userID string, update *types.User) (*types.UserInfo, error) {
if update.Id == notFoundUserID {
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", update.Id)
}
@ -109,7 +109,7 @@ func initUsersTestData() *handler {
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID)
}
info, err := update.Copy().ToUserInfo(nil, &server.Settings{RegularUsersViewBlocked: false})
info, err := update.Copy().ToUserInfo(nil, &types.Settings{RegularUsersViewBlocked: false})
if err != nil {
return nil, err
}
@ -175,7 +175,7 @@ func TestGetUsers(t *testing.T) {
return
}
respBody := []*server.UserInfo{}
respBody := []*types.UserInfo{}
err = json.Unmarshal(content, &respBody)
if err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
@ -342,7 +342,7 @@ func TestCreateUser(t *testing.T) {
requestType string
requestPath string
requestBody io.Reader
expectedResult []*server.User
expectedResult []*types.User
}{
{name: "CreateServiceUser", requestType: http.MethodPost, requestPath: "/api/users", expectedStatus: http.StatusOK, requestBody: bytes.NewBuffer(serviceUserString)},
// right now creation is blocked in AC middleware, will be refactored in the future

View File

@ -7,16 +7,16 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/jwtclaims"
)
// GetUser function defines a function to fetch user from Account by jwtclaims.AuthorizationClaims
type GetUser func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
type GetUser func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error)
// AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
type AccessControl struct {

View File

@ -11,16 +11,16 @@ import (
"github.com/golang-jwt/jwt"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server"
nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
)
// GetAccountFromPATFunc function
type GetAccountFromPATFunc func(ctx context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
type GetAccountFromPATFunc func(ctx context.Context, token string) (*types.Account, *types.User, *types.PersonalAccessToken, error)
// ValidateAndParseTokenFunc function
type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error)

View File

@ -10,9 +10,9 @@ import (
"github.com/golang-jwt/jwt"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/types"
)
const (
@ -28,13 +28,13 @@ const (
wrongToken = "wrongToken"
)
var testAccount = &server.Account{
var testAccount = &types.Account{
Id: accountID,
Domain: domain,
Users: map[string]*server.User{
Users: map[string]*types.User{
userID: {
Id: userID,
PATs: map[string]*server.PersonalAccessToken{
PATs: map[string]*types.PersonalAccessToken{
tokenID: {
ID: tokenID,
Name: "My first token",
@ -49,7 +49,7 @@ var testAccount = &server.Account{
},
}
func mockGetAccountFromPAT(_ context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
func mockGetAccountFromPAT(_ context.Context, token string) (*types.Account, *types.User, *types.PersonalAccessToken, error) {
if token == PAT {
return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil
}

View File

@ -7,6 +7,8 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
// UpdateIntegratedValidatorGroups updates the integrated validator groups for a specified account.
@ -57,9 +59,9 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID
return true, nil
}
err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
for _, groupID := range groupIDs {
_, err := transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
_, err := transaction.GetGroupByID(context.Background(), store.LockingStrengthShare, accountID, groupID)
if err != nil {
return err
}
@ -73,6 +75,6 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID
return true, nil
}
func (am *DefaultAccountManager) GetValidatedPeers(account *Account) (map[string]struct{}, error) {
func (am *DefaultAccountManager) GetValidatedPeers(account *types.Account) (map[string]struct{}, error) {
return am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra)
}

View File

@ -23,7 +23,9 @@ import (
"github.com/netbirdio/netbird/formatter"
mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
)
@ -413,7 +415,7 @@ func startManagementForTest(t *testing.T, testFile string, config *Config) (*grp
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), testFile, t.TempDir())
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), testFile, t.TempDir())
if err != nil {
t.Fatal(err)
}
@ -618,7 +620,7 @@ func testSyncStatusRace(t *testing.T) {
}
time.Sleep(10 * time.Millisecond)
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, peerWithInvalidStatus.PublicKey().String())
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, peerWithInvalidStatus.PublicKey().String())
if err != nil {
t.Fatal(err)
return
@ -705,7 +707,7 @@ func Test_LoginPerformance(t *testing.T) {
return
}
setupKey, err := am.CreateSetupKey(context.Background(), account.Id, fmt.Sprintf("key-%d", j), SetupKeyReusable, time.Hour, nil, 0, fmt.Sprintf("user-%d", j), false)
setupKey, err := am.CreateSetupKey(context.Background(), account.Id, fmt.Sprintf("key-%d", j), types.SetupKeyReusable, time.Hour, nil, 0, fmt.Sprintf("user-%d", j), false)
if err != nil {
t.Logf("error creating setup key: %v", err)
return

View File

@ -25,6 +25,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/util"
)
@ -532,7 +533,7 @@ func startServer(config *server.Config, dataDir string, testFile string) (*grpc.
Expect(err).NotTo(HaveOccurred())
s := grpc.NewServer()
store, _, err := server.NewTestStoreFromSQL(context.Background(), testFile, dataDir)
store, _, err := store.NewTestStoreFromSQL(context.Background(), testFile, dataDir)
if err != nil {
log.Fatalf("failed creating a store: %s: %v", config.Datadir, err)
}

View File

@ -15,7 +15,8 @@ import (
"github.com/hashicorp/go-version"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
nbversion "github.com/netbirdio/netbird/version"
)
@ -47,8 +48,8 @@ type properties map[string]interface{}
// DataSource metric data source
type DataSource interface {
GetAllAccounts(ctx context.Context) []*server.Account
GetStoreEngine() server.StoreEngine
GetAllAccounts(ctx context.Context) []*types.Account
GetStoreEngine() store.Engine
}
// ConnManager peer connection manager that holds state for current active connections

View File

@ -5,10 +5,11 @@ import (
"testing"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
@ -22,12 +23,12 @@ func (mockDatasource) GetAllConnectedPeers() map[string]struct{} {
}
// GetAllAccounts returns a list of *server.Account for use in tests with predefined information
func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account {
return []*server.Account{
func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
return []*types.Account{
{
Id: "1",
Settings: &server.Settings{PeerLoginExpirationEnabled: true},
SetupKeys: map[string]*server.SetupKey{
Settings: &types.Settings{PeerLoginExpirationEnabled: true},
SetupKeys: map[string]*types.SetupKey{
"1": {
Id: "1",
Ephemeral: true,
@ -49,20 +50,20 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account {
Meta: nbpeer.PeerSystemMeta{GoOS: "linux", WtVersion: "0.0.1"},
},
},
Policies: []*server.Policy{
Policies: []*types.Policy{
{
Rules: []*server.PolicyRule{
Rules: []*types.PolicyRule{
{
Bidirectional: true,
Protocol: server.PolicyRuleProtocolTCP,
Protocol: types.PolicyRuleProtocolTCP,
},
},
},
{
Rules: []*server.PolicyRule{
Rules: []*types.PolicyRule{
{
Bidirectional: false,
Protocol: server.PolicyRuleProtocolTCP,
Protocol: types.PolicyRuleProtocolTCP,
},
},
SourcePostureChecks: []string{"1"},
@ -94,16 +95,16 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account {
},
},
},
Users: map[string]*server.User{
Users: map[string]*types.User{
"1": {
IsServiceUser: true,
PATs: map[string]*server.PersonalAccessToken{
PATs: map[string]*types.PersonalAccessToken{
"1": {},
},
},
"2": {
IsServiceUser: false,
PATs: map[string]*server.PersonalAccessToken{
PATs: map[string]*types.PersonalAccessToken{
"1": {},
},
},
@ -111,8 +112,8 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account {
},
{
Id: "2",
Settings: &server.Settings{PeerLoginExpirationEnabled: true},
SetupKeys: map[string]*server.SetupKey{
Settings: &types.Settings{PeerLoginExpirationEnabled: true},
SetupKeys: map[string]*types.SetupKey{
"1": {
Id: "1",
Ephemeral: true,
@ -134,20 +135,20 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account {
Meta: nbpeer.PeerSystemMeta{GoOS: "linux", WtVersion: "0.0.1"},
},
},
Policies: []*server.Policy{
Policies: []*types.Policy{
{
Rules: []*server.PolicyRule{
Rules: []*types.PolicyRule{
{
Bidirectional: true,
Protocol: server.PolicyRuleProtocolTCP,
Protocol: types.PolicyRuleProtocolTCP,
},
},
},
{
Rules: []*server.PolicyRule{
Rules: []*types.PolicyRule{
{
Bidirectional: false,
Protocol: server.PolicyRuleProtocolTCP,
Protocol: types.PolicyRuleProtocolTCP,
},
},
},
@ -158,16 +159,16 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account {
PeerGroups: make([]string, 1),
},
},
Users: map[string]*server.User{
Users: map[string]*types.User{
"1": {
IsServiceUser: true,
PATs: map[string]*server.PersonalAccessToken{
PATs: map[string]*types.PersonalAccessToken{
"1": {},
},
},
"2": {
IsServiceUser: false,
PATs: map[string]*server.PersonalAccessToken{
PATs: map[string]*types.PersonalAccessToken{
"1": {},
},
},
@ -177,8 +178,8 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account {
}
// GetStoreEngine returns FileStoreEngine
func (mockDatasource) GetStoreEngine() server.StoreEngine {
return server.FileStoreEngine
func (mockDatasource) GetStoreEngine() store.Engine {
return store.FileStoreEngine
}
// TestGenerateProperties tests and validate the properties generation by using the mockDatasource for the Worker.generateProperties
@ -267,7 +268,7 @@ func TestGenerateProperties(t *testing.T) {
t.Errorf("expected 2 user_peers, got %d", properties["user_peers"])
}
if properties["store_engine"] != server.FileStoreEngine {
if properties["store_engine"] != store.FileStoreEngine {
t.Errorf("expected JsonFile, got %s", properties["store_engine"])
}

View File

@ -12,9 +12,9 @@ import (
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/migration"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
@ -31,64 +31,64 @@ func setupDatabase(t *testing.T) *gorm.DB {
func TestMigrateFieldFromGobToJSON_EmptyDB(t *testing.T) {
db := setupDatabase(t)
err := migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](context.Background(), db, "network_net")
err := migration.MigrateFieldFromGobToJSON[types.Account, net.IPNet](context.Background(), db, "network_net")
require.NoError(t, err, "Migration should not fail for an empty database")
}
func TestMigrateFieldFromGobToJSON_WithGobData(t *testing.T) {
db := setupDatabase(t)
err := db.AutoMigrate(&server.Account{}, &route.Route{})
err := db.AutoMigrate(&types.Account{}, &route.Route{})
require.NoError(t, err, "Failed to auto-migrate tables")
_, ipnet, err := net.ParseCIDR("10.0.0.0/24")
require.NoError(t, err, "Failed to parse CIDR")
type network struct {
server.Network
types.Network
Net net.IPNet `gorm:"serializer:gob"`
}
type account struct {
server.Account
types.Account
Network *network `gorm:"embedded;embeddedPrefix:network_"`
}
err = db.Save(&account{Account: server.Account{Id: "123"}, Network: &network{Net: *ipnet}}).Error
err = db.Save(&account{Account: types.Account{Id: "123"}, Network: &network{Net: *ipnet}}).Error
require.NoError(t, err, "Failed to insert Gob data")
var gobStr string
err = db.Model(&server.Account{}).Select("network_net").First(&gobStr).Error
err = db.Model(&types.Account{}).Select("network_net").First(&gobStr).Error
assert.NoError(t, err, "Failed to fetch Gob data")
err = gob.NewDecoder(strings.NewReader(gobStr)).Decode(&ipnet)
require.NoError(t, err, "Failed to decode Gob data")
err = migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](context.Background(), db, "network_net")
err = migration.MigrateFieldFromGobToJSON[types.Account, net.IPNet](context.Background(), db, "network_net")
require.NoError(t, err, "Migration should not fail with Gob data")
var jsonStr string
db.Model(&server.Account{}).Select("network_net").First(&jsonStr)
db.Model(&types.Account{}).Select("network_net").First(&jsonStr)
assert.JSONEq(t, `{"IP":"10.0.0.0","Mask":"////AA=="}`, jsonStr, "Data should be migrated")
}
func TestMigrateFieldFromGobToJSON_WithJSONData(t *testing.T) {
db := setupDatabase(t)
err := db.AutoMigrate(&server.Account{}, &route.Route{})
err := db.AutoMigrate(&types.Account{}, &route.Route{})
require.NoError(t, err, "Failed to auto-migrate tables")
_, ipnet, err := net.ParseCIDR("10.0.0.0/24")
require.NoError(t, err, "Failed to parse CIDR")
err = db.Save(&server.Account{Network: &server.Network{Net: *ipnet}}).Error
err = db.Save(&types.Account{Network: &types.Network{Net: *ipnet}}).Error
require.NoError(t, err, "Failed to insert JSON data")
err = migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](context.Background(), db, "network_net")
err = migration.MigrateFieldFromGobToJSON[types.Account, net.IPNet](context.Background(), db, "network_net")
require.NoError(t, err, "Migration should not fail with JSON data")
var jsonStr string
db.Model(&server.Account{}).Select("network_net").First(&jsonStr)
db.Model(&types.Account{}).Select("network_net").First(&jsonStr)
assert.JSONEq(t, `{"IP":"10.0.0.0","Mask":"////AA=="}`, jsonStr, "Data should be unchanged")
}
@ -101,7 +101,7 @@ func TestMigrateNetIPFieldFromBlobToJSON_EmptyDB(t *testing.T) {
func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) {
db := setupDatabase(t)
err := db.AutoMigrate(&server.Account{}, &nbpeer.Peer{})
err := db.AutoMigrate(&types.Account{}, &nbpeer.Peer{})
require.NoError(t, err, "Failed to auto-migrate tables")
type location struct {
@ -115,12 +115,12 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) {
}
type account struct {
server.Account
types.Account
Peers []peer `gorm:"foreignKey:AccountID;references:id"`
}
err = db.Save(&account{
Account: server.Account{Id: "123"},
Account: types.Account{Id: "123"},
Peers: []peer{
{Location: location{ConnectionIP: net.IP{10, 0, 0, 1}}},
}},
@ -142,10 +142,10 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) {
func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) {
db := setupDatabase(t)
err := db.AutoMigrate(&server.Account{}, &nbpeer.Peer{})
err := db.AutoMigrate(&types.Account{}, &nbpeer.Peer{})
require.NoError(t, err, "Failed to auto-migrate tables")
err = db.Save(&server.Account{
err = db.Save(&types.Account{
Id: "1234",
PeersG: []nbpeer.Peer{
{Location: nbpeer.Location{ConnectionIP: net.IP{10, 0, 0, 1}}},
@ -164,20 +164,20 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) {
func TestMigrateSetupKeyToHashedSetupKey_ForPlainKey(t *testing.T) {
db := setupDatabase(t)
err := db.AutoMigrate(&server.SetupKey{})
err := db.AutoMigrate(&types.SetupKey{})
require.NoError(t, err, "Failed to auto-migrate tables")
err = db.Save(&server.SetupKey{
err = db.Save(&types.SetupKey{
Id: "1",
Key: "EEFDAB47-C1A5-4472-8C05-71DE9A1E8382",
}).Error
require.NoError(t, err, "Failed to insert setup key")
err = migration.MigrateSetupKeyToHashedSetupKey[server.SetupKey](context.Background(), db)
err = migration.MigrateSetupKeyToHashedSetupKey[types.SetupKey](context.Background(), db)
require.NoError(t, err, "Migration should not fail to migrate setup key")
var key server.SetupKey
err = db.Model(&server.SetupKey{}).First(&key).Error
var key types.SetupKey
err = db.Model(&types.SetupKey{}).First(&key).Error
assert.NoError(t, err, "Failed to fetch setup key")
assert.Equal(t, "EEFDA****", key.KeySecret, "Key should be secret")
@ -187,21 +187,21 @@ func TestMigrateSetupKeyToHashedSetupKey_ForPlainKey(t *testing.T) {
func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case1(t *testing.T) {
db := setupDatabase(t)
err := db.AutoMigrate(&server.SetupKey{})
err := db.AutoMigrate(&types.SetupKey{})
require.NoError(t, err, "Failed to auto-migrate tables")
err = db.Save(&server.SetupKey{
err = db.Save(&types.SetupKey{
Id: "1",
Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=",
KeySecret: "EEFDA****",
}).Error
require.NoError(t, err, "Failed to insert setup key")
err = migration.MigrateSetupKeyToHashedSetupKey[server.SetupKey](context.Background(), db)
err = migration.MigrateSetupKeyToHashedSetupKey[types.SetupKey](context.Background(), db)
require.NoError(t, err, "Migration should not fail to migrate setup key")
var key server.SetupKey
err = db.Model(&server.SetupKey{}).First(&key).Error
var key types.SetupKey
err = db.Model(&types.SetupKey{}).First(&key).Error
assert.NoError(t, err, "Failed to fetch setup key")
assert.Equal(t, "EEFDA****", key.KeySecret, "Key should be secret")
@ -211,20 +211,20 @@ func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case1(t *testing.
func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case2(t *testing.T) {
db := setupDatabase(t)
err := db.AutoMigrate(&server.SetupKey{})
err := db.AutoMigrate(&types.SetupKey{})
require.NoError(t, err, "Failed to auto-migrate tables")
err = db.Save(&server.SetupKey{
err = db.Save(&types.SetupKey{
Id: "1",
Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=",
}).Error
require.NoError(t, err, "Failed to insert setup key")
err = migration.MigrateSetupKeyToHashedSetupKey[server.SetupKey](context.Background(), db)
err = migration.MigrateSetupKeyToHashedSetupKey[types.SetupKey](context.Background(), db)
require.NoError(t, err, "Migration should not fail to migrate setup key")
var key server.SetupKey
err = db.Model(&server.SetupKey{}).First(&key).Error
var key types.SetupKey
err = db.Model(&types.SetupKey{}).First(&key).Error
assert.NoError(t, err, "Failed to fetch setup key")
assert.Equal(t, "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", key.Key, "Key should be hashed")

View File

@ -16,28 +16,30 @@ import (
"github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/networks"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
type MockAccountManager struct {
GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*server.Account, error)
GetAccountFunc func(ctx context.Context, accountID string) (*server.Account, error)
CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType,
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error)
GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*types.Account, error)
GetAccountFunc func(ctx context.Context, accountID string) (*types.Account, error)
CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType types.SetupKeyType,
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*types.SetupKey, error)
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error)
AccountExistsFunc func(ctx context.Context, accountID string) (bool, error)
GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error)
GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error)
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*server.NetworkMap, error)
GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*server.Network, error)
AddPeerFunc func(ctx context.Context, setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error)
GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error)
AddPeerFunc func(ctx context.Context, setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*group.Group, error)
GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*group.Group, error)
GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*group.Group, error)
@ -48,12 +50,12 @@ type MockAccountManager struct {
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error)
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error)
GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error)
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error)
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error)
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error)
GetAccountFromPATFunc func(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*types.Policy, error)
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*types.UserInfo, error)
GetAccountFromPATFunc func(ctx context.Context, pat string) (*types.Account, *types.User, *types.PersonalAccessToken, error)
MarkPATUsedFunc func(ctx context.Context, pat string) error
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
@ -62,35 +64,35 @@ type MockAccountManager struct {
SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error
DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error
ListRoutesFunc func(ctx context.Context, accountID, userID string) ([]*route.Route, error)
SaveSetupKeyFunc func(ctx context.Context, accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
ListSetupKeysFunc func(ctx context.Context, accountID, userID string) ([]*server.SetupKey, error)
SaveUserFunc func(ctx context.Context, accountID, userID string, user *server.User) (*server.UserInfo, error)
SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error)
SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*server.User, addIfNotExists bool) ([]*server.UserInfo, error)
SaveSetupKeyFunc func(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error)
ListSetupKeysFunc func(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error)
SaveUserFunc func(ctx context.Context, accountID, userID string, user *types.User) (*types.UserInfo, error)
SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *types.User, addIfNotExists bool) (*types.UserInfo, error)
SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*types.User, addIfNotExists bool) ([]*types.UserInfo, error)
DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
DeleteRegularUsersFunc func(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error
CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error)
DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error
GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
GetAllPATsFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error)
GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*types.PersonalAccessToken, error)
GetAllPATsFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string) ([]*types.PersonalAccessToken, error)
GetNameServerGroupFunc func(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
CreateNameServerGroupFunc func(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
SaveNameServerGroupFunc func(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error
ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
CreateUserFunc func(ctx context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
CreateUserFunc func(ctx context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error)
GetAccountIDFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
DeleteAccountFunc func(ctx context.Context, accountID, userID string) error
GetDNSDomainFunc func() string
StoreEventFunc func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
GetEventsFunc func(ctx context.Context, accountID, userID string) ([]*activity.Event, error)
GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*server.DNSSettings, error)
SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *server.DNSSettings) error
GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*types.DNSSettings, error)
SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *types.DNSSettings) error
GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error)
LoginPeerFunc func(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
SyncPeerFunc func(ctx context.Context, sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error)
LoginPeerFunc func(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
SyncPeerFunc func(ctx context.Context, sync server.PeerSync, account *types.Account) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error
GetAllConnectedPeersFunc func() (map[string]struct{}, error)
HasConnectedChannelFunc func(peerID string) bool
@ -105,12 +107,17 @@ type MockAccountManager struct {
SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error)
GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*server.Account, error)
GetUserByIDFunc func(ctx context.Context, id string) (*server.User, error)
GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*server.Settings, error)
GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*types.Account, error)
GetUserByIDFunc func(ctx context.Context, id string) (*types.User, error)
GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*types.Settings, error)
DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error
}
func (am *MockAccountManager) GetNetworksManager() networks.Manager {
// TODO implement me
panic("implement me")
}
func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error {
if am.DeleteSetupKeyFunc != nil {
return am.DeleteSetupKeyFunc(ctx, accountID, userID, keyID)
@ -118,7 +125,7 @@ func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, use
return status.Errorf(codes.Unimplemented, "method DeleteSetupKey is not implemented")
}
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
if am.SyncAndMarkPeerFunc != nil {
return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP)
}
@ -130,7 +137,7 @@ func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID st
panic("implement me")
}
func (am *MockAccountManager) GetValidatedPeers(account *server.Account) (map[string]struct{}, error) {
func (am *MockAccountManager) GetValidatedPeers(account *types.Account) (map[string]struct{}, error) {
approvedPeers := make(map[string]struct{})
for id := range account.Peers {
approvedPeers[id] = struct{}{}
@ -155,7 +162,7 @@ func (am *MockAccountManager) GetAllGroups(ctx context.Context, accountID, userI
}
// GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface
func (am *MockAccountManager) GetUsersFromAccount(ctx context.Context, accountID string, userID string) ([]*server.UserInfo, error) {
func (am *MockAccountManager) GetUsersFromAccount(ctx context.Context, accountID string, userID string) ([]*types.UserInfo, error) {
if am.GetUsersFromAccountFunc != nil {
return am.GetUsersFromAccountFunc(ctx, accountID, userID)
}
@ -173,7 +180,7 @@ func (am *MockAccountManager) DeletePeer(ctx context.Context, accountID, peerID,
// GetOrCreateAccountByUser mock implementation of GetOrCreateAccountByUser from server.AccountManager interface
func (am *MockAccountManager) GetOrCreateAccountByUser(
ctx context.Context, userId, domain string,
) (*server.Account, error) {
) (*types.Account, error) {
if am.GetOrCreateAccountByUserFunc != nil {
return am.GetOrCreateAccountByUserFunc(ctx, userId, domain)
}
@ -188,13 +195,13 @@ func (am *MockAccountManager) CreateSetupKey(
ctx context.Context,
accountID string,
keyName string,
keyType server.SetupKeyType,
keyType types.SetupKeyType,
expiresIn time.Duration,
autoGroups []string,
usageLimit int,
userID string,
ephemeral bool,
) (*server.SetupKey, error) {
) (*types.SetupKey, error) {
if am.CreateSetupKeyFunc != nil {
return am.CreateSetupKeyFunc(ctx, accountID, keyName, keyType, expiresIn, autoGroups, usageLimit, userID, ephemeral)
}
@ -221,7 +228,7 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userId,
}
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, account *server.Account) error {
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, account *types.Account) error {
if am.MarkPeerConnectedFunc != nil {
return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP)
}
@ -229,7 +236,7 @@ func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey str
}
// GetAccountFromPAT mock implementation of GetAccountFromPAT from server.AccountManager interface
func (am *MockAccountManager) GetAccountFromPAT(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
func (am *MockAccountManager) GetAccountFromPAT(ctx context.Context, pat string) (*types.Account, *types.User, *types.PersonalAccessToken, error) {
if am.GetAccountFromPATFunc != nil {
return am.GetAccountFromPATFunc(ctx, pat)
}
@ -253,7 +260,7 @@ func (am *MockAccountManager) MarkPATUsed(ctx context.Context, pat string) error
}
// CreatePAT mock implementation of GetPAT from server.AccountManager interface
func (am *MockAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, name string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
func (am *MockAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, name string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) {
if am.CreatePATFunc != nil {
return am.CreatePATFunc(ctx, accountID, initiatorUserID, targetUserID, name, expiresIn)
}
@ -269,7 +276,7 @@ func (am *MockAccountManager) DeletePAT(ctx context.Context, accountID string, i
}
// GetPAT mock implementation of GetPAT from server.AccountManager interface
func (am *MockAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
func (am *MockAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error) {
if am.GetPATFunc != nil {
return am.GetPATFunc(ctx, accountID, initiatorUserID, targetUserID, tokenID)
}
@ -277,7 +284,7 @@ func (am *MockAccountManager) GetPAT(ctx context.Context, accountID string, init
}
// GetAllPATs mock implementation of GetAllPATs from server.AccountManager interface
func (am *MockAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
func (am *MockAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*types.PersonalAccessToken, error) {
if am.GetAllPATsFunc != nil {
return am.GetAllPATsFunc(ctx, accountID, initiatorUserID, targetUserID)
}
@ -285,7 +292,7 @@ func (am *MockAccountManager) GetAllPATs(ctx context.Context, accountID string,
}
// GetNetworkMap mock implementation of GetNetworkMap from server.AccountManager interface
func (am *MockAccountManager) GetNetworkMap(ctx context.Context, peerKey string) (*server.NetworkMap, error) {
func (am *MockAccountManager) GetNetworkMap(ctx context.Context, peerKey string) (*types.NetworkMap, error) {
if am.GetNetworkMapFunc != nil {
return am.GetNetworkMapFunc(ctx, peerKey)
}
@ -293,7 +300,7 @@ func (am *MockAccountManager) GetNetworkMap(ctx context.Context, peerKey string)
}
// GetPeerNetwork mock implementation of GetPeerNetwork from server.AccountManager interface
func (am *MockAccountManager) GetPeerNetwork(ctx context.Context, peerKey string) (*server.Network, error) {
func (am *MockAccountManager) GetPeerNetwork(ctx context.Context, peerKey string) (*types.Network, error) {
if am.GetPeerNetworkFunc != nil {
return am.GetPeerNetworkFunc(ctx, peerKey)
}
@ -306,7 +313,7 @@ func (am *MockAccountManager) AddPeer(
setupKey string,
userId string,
peer *nbpeer.Peer,
) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
if am.AddPeerFunc != nil {
return am.AddPeerFunc(ctx, setupKey, userId, peer)
}
@ -378,7 +385,7 @@ func (am *MockAccountManager) DeleteRule(ctx context.Context, accountID, ruleID,
}
// GetPolicy mock implementation of GetPolicy from server.AccountManager interface
func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error) {
func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error) {
if am.GetPolicyFunc != nil {
return am.GetPolicyFunc(ctx, accountID, policyID, userID)
}
@ -386,7 +393,7 @@ func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID
}
// SavePolicy mock implementation of SavePolicy from server.AccountManager interface
func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error) {
func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error) {
if am.SavePolicyFunc != nil {
return am.SavePolicyFunc(ctx, accountID, userID, policy)
}
@ -402,7 +409,7 @@ func (am *MockAccountManager) DeletePolicy(ctx context.Context, accountID, polic
}
// ListPolicies mock implementation of ListPolicies from server.AccountManager interface
func (am *MockAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*server.Policy, error) {
func (am *MockAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*types.Policy, error) {
if am.ListPoliciesFunc != nil {
return am.ListPoliciesFunc(ctx, accountID, userID)
}
@ -418,14 +425,14 @@ func (am *MockAccountManager) UpdatePeerMeta(ctx context.Context, peerID string,
}
// GetUser mock implementation of GetUser from server.AccountManager interface
func (am *MockAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) {
func (am *MockAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error) {
if am.GetUserFunc != nil {
return am.GetUserFunc(ctx, claims)
}
return nil, status.Errorf(codes.Unimplemented, "method GetUser is not implemented")
}
func (am *MockAccountManager) ListUsers(ctx context.Context, accountID string) ([]*server.User, error) {
func (am *MockAccountManager) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) {
if am.ListUsersFunc != nil {
return am.ListUsersFunc(ctx, accountID)
}
@ -481,7 +488,7 @@ func (am *MockAccountManager) ListRoutes(ctx context.Context, accountID, userID
}
// SaveSetupKey mocks SaveSetupKey of the AccountManager interface
func (am *MockAccountManager) SaveSetupKey(ctx context.Context, accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) {
func (am *MockAccountManager) SaveSetupKey(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error) {
if am.SaveSetupKeyFunc != nil {
return am.SaveSetupKeyFunc(ctx, accountID, key, userID)
}
@ -490,7 +497,7 @@ func (am *MockAccountManager) SaveSetupKey(ctx context.Context, accountID string
}
// GetSetupKey mocks GetSetupKey of the AccountManager interface
func (am *MockAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error) {
func (am *MockAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) {
if am.GetSetupKeyFunc != nil {
return am.GetSetupKeyFunc(ctx, accountID, userID, keyID)
}
@ -499,7 +506,7 @@ func (am *MockAccountManager) GetSetupKey(ctx context.Context, accountID, userID
}
// ListSetupKeys mocks ListSetupKeys of the AccountManager interface
func (am *MockAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*server.SetupKey, error) {
func (am *MockAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error) {
if am.ListSetupKeysFunc != nil {
return am.ListSetupKeysFunc(ctx, accountID, userID)
}
@ -508,7 +515,7 @@ func (am *MockAccountManager) ListSetupKeys(ctx context.Context, accountID, user
}
// SaveUser mocks SaveUser of the AccountManager interface
func (am *MockAccountManager) SaveUser(ctx context.Context, accountID, userID string, user *server.User) (*server.UserInfo, error) {
func (am *MockAccountManager) SaveUser(ctx context.Context, accountID, userID string, user *types.User) (*types.UserInfo, error) {
if am.SaveUserFunc != nil {
return am.SaveUserFunc(ctx, accountID, userID, user)
}
@ -516,7 +523,7 @@ func (am *MockAccountManager) SaveUser(ctx context.Context, accountID, userID st
}
// SaveOrAddUser mocks SaveOrAddUser of the AccountManager interface
func (am *MockAccountManager) SaveOrAddUser(ctx context.Context, accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error) {
func (am *MockAccountManager) SaveOrAddUser(ctx context.Context, accountID, userID string, user *types.User, addIfNotExists bool) (*types.UserInfo, error) {
if am.SaveOrAddUserFunc != nil {
return am.SaveOrAddUserFunc(ctx, accountID, userID, user, addIfNotExists)
}
@ -524,7 +531,7 @@ func (am *MockAccountManager) SaveOrAddUser(ctx context.Context, accountID, user
}
// SaveOrAddUsers mocks SaveOrAddUsers of the AccountManager interface
func (am *MockAccountManager) SaveOrAddUsers(ctx context.Context, accountID, userID string, users []*server.User, addIfNotExists bool) ([]*server.UserInfo, error) {
func (am *MockAccountManager) SaveOrAddUsers(ctx context.Context, accountID, userID string, users []*types.User, addIfNotExists bool) ([]*types.UserInfo, error) {
if am.SaveOrAddUsersFunc != nil {
return am.SaveOrAddUsersFunc(ctx, accountID, userID, users, addIfNotExists)
}
@ -595,7 +602,7 @@ func (am *MockAccountManager) ListNameServerGroups(ctx context.Context, accountI
}
// CreateUser mocks CreateUser of the AccountManager interface
func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID string, invite *server.UserInfo) (*server.UserInfo, error) {
func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID string, invite *types.UserInfo) (*types.UserInfo, error) {
if am.CreateUserFunc != nil {
return am.CreateUserFunc(ctx, accountID, userID, invite)
}
@ -642,7 +649,7 @@ func (am *MockAccountManager) GetEvents(ctx context.Context, accountID, userID s
}
// GetDNSSettings mocks GetDNSSettings of the AccountManager interface
func (am *MockAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*server.DNSSettings, error) {
func (am *MockAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) {
if am.GetDNSSettingsFunc != nil {
return am.GetDNSSettingsFunc(ctx, accountID, userID)
}
@ -650,7 +657,7 @@ func (am *MockAccountManager) GetDNSSettings(ctx context.Context, accountID stri
}
// SaveDNSSettings mocks SaveDNSSettings of the AccountManager interface
func (am *MockAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *server.DNSSettings) error {
func (am *MockAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error {
if am.SaveDNSSettingsFunc != nil {
return am.SaveDNSSettingsFunc(ctx, accountID, userID, dnsSettingsToSave)
}
@ -666,7 +673,7 @@ func (am *MockAccountManager) GetPeer(ctx context.Context, accountID, peerID, us
}
// UpdateAccountSettings mocks UpdateAccountSettings of the AccountManager interface
func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) {
if am.UpdateAccountSettingsFunc != nil {
return am.UpdateAccountSettingsFunc(ctx, accountID, userID, newSettings)
}
@ -674,7 +681,7 @@ func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, account
}
// LoginPeer mocks LoginPeer of the AccountManager interface
func (am *MockAccountManager) LoginPeer(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
func (am *MockAccountManager) LoginPeer(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
if am.LoginPeerFunc != nil {
return am.LoginPeerFunc(ctx, login)
}
@ -682,7 +689,7 @@ func (am *MockAccountManager) LoginPeer(ctx context.Context, login server.PeerLo
}
// SyncPeer mocks SyncPeer of the AccountManager interface
func (am *MockAccountManager) SyncPeer(ctx context.Context, sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
func (am *MockAccountManager) SyncPeer(ctx context.Context, sync server.PeerSync, account *types.Account) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
if am.SyncPeerFunc != nil {
return am.SyncPeerFunc(ctx, sync, account)
}
@ -803,7 +810,7 @@ func (am *MockAccountManager) GetAccountIDForPeerKey(ctx context.Context, peerKe
}
// GetAccountByID mocks GetAccountByID of the AccountManager interface
func (am *MockAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*server.Account, error) {
func (am *MockAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error) {
if am.GetAccountByIDFunc != nil {
return am.GetAccountByIDFunc(ctx, accountID, userID)
}
@ -811,21 +818,21 @@ func (am *MockAccountManager) GetAccountByID(ctx context.Context, accountID stri
}
// GetUserByID mocks GetUserByID of the AccountManager interface
func (am *MockAccountManager) GetUserByID(ctx context.Context, id string) (*server.User, error) {
func (am *MockAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) {
if am.GetUserByIDFunc != nil {
return am.GetUserByIDFunc(ctx, id)
}
return nil, status.Errorf(codes.Unimplemented, "method GetUserByID is not implemented")
}
func (am *MockAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*server.Settings, error) {
func (am *MockAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) {
if am.GetAccountSettingsFunc != nil {
return am.GetAccountSettingsFunc(ctx, accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetAccountSettings is not implemented")
}
func (am *MockAccountManager) GetAccount(ctx context.Context, accountID string) (*server.Account, error) {
func (am *MockAccountManager) GetAccount(ctx context.Context, accountID string) (*types.Account, error) {
if am.GetAccountFunc != nil {
return am.GetAccountFunc(ctx, accountID)
}

View File

@ -13,13 +13,14 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/store"
)
const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$`
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
if err != nil {
return nil, err
}
@ -32,7 +33,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account
return nil, status.NewAdminPermissionError()
}
return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupID)
return am.Store.GetNameServerGroupByID(ctx, store.LockingStrengthShare, accountID, nsGroupID)
}
// CreateNameServerGroup creates and saves a new nameserver group
@ -40,7 +41,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
if err != nil {
return nil, err
}
@ -64,7 +65,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateNameServerGroup(ctx, transaction, accountID, newNSGroup); err != nil {
return err
}
@ -74,11 +75,11 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, newNSGroup)
return transaction.SaveNameServerGroup(ctx, store.LockingStrengthUpdate, newNSGroup)
})
if err != nil {
return nil, err
@ -102,7 +103,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
return status.Errorf(status.InvalidArgument, "nameserver group provided is nil")
}
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
if err != nil {
return err
}
@ -113,8 +114,8 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupToSave.ID)
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, store.LockingStrengthShare, accountID, nsGroupToSave.ID)
if err != nil {
return err
}
@ -129,11 +130,11 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, nsGroupToSave)
return transaction.SaveNameServerGroup(ctx, store.LockingStrengthUpdate, nsGroupToSave)
})
if err != nil {
return err
@ -153,7 +154,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
if err != nil {
return err
}
@ -165,8 +166,8 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
var nsGroup *nbdns.NameServerGroup
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
nsGroup, err = transaction.GetNameServerGroupByID(ctx, LockingStrengthUpdate, accountID, nsGroupID)
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
nsGroup, err = transaction.GetNameServerGroupByID(ctx, store.LockingStrengthUpdate, accountID, nsGroupID)
if err != nil {
return err
}
@ -176,11 +177,11 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.DeleteNameServerGroup(ctx, LockingStrengthUpdate, accountID, nsGroupID)
return transaction.DeleteNameServerGroup(ctx, store.LockingStrengthUpdate, accountID, nsGroupID)
})
if err != nil {
return err
@ -197,7 +198,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
// ListNameServerGroups returns a list of nameserver groups from account
func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
if err != nil {
return nil, err
}
@ -210,10 +211,10 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou
return nil, status.NewAdminPermissionError()
}
return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
return am.Store.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID)
}
func validateNameServerGroup(ctx context.Context, transaction Store, accountID string, nameserverGroup *nbdns.NameServerGroup) error {
func validateNameServerGroup(ctx context.Context, transaction store.Store, accountID string, nameserverGroup *nbdns.NameServerGroup) error {
err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains, nameserverGroup.SearchDomainsEnabled)
if err != nil {
return err
@ -224,7 +225,7 @@ func validateNameServerGroup(ctx context.Context, transaction Store, accountID s
return err
}
nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return err
}
@ -234,7 +235,7 @@ func validateNameServerGroup(ctx context.Context, transaction Store, accountID s
return err
}
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, nameserverGroup.Groups)
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, nameserverGroup.Groups)
if err != nil {
return err
}
@ -243,7 +244,7 @@ func validateNameServerGroup(ctx context.Context, transaction Store, accountID s
}
// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
func areNameServerGroupChangesAffectPeers(ctx context.Context, transaction Store, newNSGroup, oldNSGroup *nbdns.NameServerGroup) (bool, error) {
func areNameServerGroupChangesAffectPeers(ctx context.Context, transaction store.Store, newNSGroup, oldNSGroup *nbdns.NameServerGroup) (bool, error) {
if !newNSGroup.Enabled && !oldNSGroup.Enabled {
return false, nil
}

View File

@ -13,7 +13,9 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
)
const (
@ -772,10 +774,10 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics)
}
func createNSStore(t *testing.T) (Store, error) {
func createNSStore(t *testing.T) (store.Store, error) {
t.Helper()
dataDir := t.TempDir()
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir)
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", dataDir)
if err != nil {
return nil, err
}
@ -784,7 +786,7 @@ func createNSStore(t *testing.T) (Store, error) {
return store, nil
}
func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) {
func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account, error) {
t.Helper()
peer1 := &nbpeer.Peer{
Key: nsGroupPeer1Key,

View File

@ -0,0 +1,63 @@
package networks
import (
"context"
"errors"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
"github.com/netbirdio/netbird/management/server/networks/types"
"github.com/netbirdio/netbird/management/server/store"
)
type Manager interface {
GetAllNetworks(ctx context.Context, accountID, userID string) ([]*types.Network, error)
CreateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error)
GetNetwork(ctx context.Context, accountID, userID, networkID string) (*types.Network, error)
UpdateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error)
DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error
GetResourceManager() resources.Manager
GetRouterManager() routers.Manager
}
type managerImpl struct {
store store.Store
routersManager routers.Manager
resourcesManager resources.Manager
}
func NewManager(store store.Store) Manager {
return &managerImpl{
store: store,
routersManager: routers.NewManager(store),
resourcesManager: resources.NewManager(store),
}
}
func (m *managerImpl) GetAllNetworks(ctx context.Context, accountID, userID string) ([]*types.Network, error) {
return nil, errors.New("not implemented")
}
func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) {
return nil, errors.New("not implemented")
}
func (m *managerImpl) GetNetwork(ctx context.Context, accountID, userID, networkID string) (*types.Network, error) {
return nil, errors.New("not implemented")
}
func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) {
return nil, errors.New("not implemented")
}
func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error {
return errors.New("not implemented")
}
func (m *managerImpl) GetResourceManager() resources.Manager {
return m.resourcesManager
}
func (m *managerImpl) GetRouterManager() routers.Manager {
return m.routersManager
}

View File

@ -0,0 +1,47 @@
package resources
import (
"context"
"errors"
"github.com/netbirdio/netbird/management/server/networks/resources/types"
"github.com/netbirdio/netbird/management/server/store"
)
type Manager interface {
GetAllResources(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkResource, error)
CreateResource(ctx context.Context, accountID string, resource *types.NetworkResource) (*types.NetworkResource, error)
GetResource(ctx context.Context, accountID, userID, networkID, resourceID string) (*types.NetworkResource, error)
UpdateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error)
DeleteResource(ctx context.Context, accountID, userID, networkID, resourceID string) error
}
type managerImpl struct {
store store.Store
}
func NewManager(store store.Store) Manager {
return &managerImpl{
store: store,
}
}
func (m *managerImpl) GetAllResources(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkResource, error) {
return nil, errors.New("not implemented")
}
func (m *managerImpl) CreateResource(ctx context.Context, accountID string, resource *types.NetworkResource) (*types.NetworkResource, error) {
return nil, errors.New("not implemented")
}
func (m *managerImpl) GetResource(ctx context.Context, accountID, userID, networkID, resourceID string) (*types.NetworkResource, error) {
return nil, errors.New("not implemented")
}
func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) {
return nil, errors.New("not implemented")
}
func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, networkID, resourceID string) error {
return errors.New("not implemented")
}

View File

@ -1,4 +1,4 @@
package networks
package types
import (
"errors"
@ -8,14 +8,16 @@ import (
"strings"
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/server/http/api"
)
type NetworkResourceType string
const (
host NetworkResourceType = "Host"
subnet NetworkResourceType = "Subnet"
domain NetworkResourceType = "Domain"
host NetworkResourceType = "host"
subnet NetworkResourceType = "subnet"
domain NetworkResourceType = "domain"
)
func (p NetworkResourceType) String() string {
@ -49,6 +51,25 @@ func NewNetworkResource(accountID, networkID, name, description, address string)
}, nil
}
func (n *NetworkResource) ToAPIResponse() *api.NetworkResource {
return &api.NetworkResource{
Id: n.ID,
Name: n.Name,
Description: &n.Description,
Type: api.NetworkResourceType(n.Type.String()),
Address: n.Address,
}
}
func (n *NetworkResource) FromAPIRequest(req *api.NetworkResourceRequest) {
n.Name = req.Name
n.Description = ""
if req.Description != nil {
n.Description = *req.Description
}
n.Address = req.Address
}
func (n *NetworkResource) Copy() *NetworkResource {
return &NetworkResource{
ID: n.ID,

View File

@ -1,4 +1,4 @@
package networks
package types
import (
"testing"

View File

@ -0,0 +1,47 @@
package routers
import (
"context"
"errors"
"github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/store"
)
type Manager interface {
GetAllRouters(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkRouter, error)
CreateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error)
GetRouter(ctx context.Context, accountID, userID, networkID, routerID string) (*types.NetworkRouter, error)
UpdateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error)
DeleteRouter(ctx context.Context, accountID, userID, networkID, routerID string) error
}
type managerImpl struct {
store store.Store
}
func NewManager(store store.Store) Manager {
return &managerImpl{
store: store,
}
}
func (m *managerImpl) GetAllRouters(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkRouter, error) {
return nil, errors.New("not implemented")
}
func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error) {
return nil, errors.New("not implemented")
}
func (m *managerImpl) GetRouter(ctx context.Context, accountID, userID, networkID, routerID string) (*types.NetworkRouter, error) {
return nil, errors.New("not implemented")
}
func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error) {
return nil, errors.New("not implemented")
}
func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, networkID, routerID string) error {
return errors.New("not implemented")
}

View File

@ -1,9 +1,11 @@
package networks
package types
import (
"errors"
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/server/http/api"
)
type NetworkRouter struct {
@ -32,6 +34,23 @@ func NewNetworkRouter(accountID string, networkID string, peer string, peerGroup
}, nil
}
func (n *NetworkRouter) ToAPIResponse() *api.NetworkRouter {
return &api.NetworkRouter{
Id: n.ID,
Peer: &n.Peer,
PeerGroups: &n.PeerGroups,
Masquerade: n.Masquerade,
Metric: n.Metric,
}
}
func (n *NetworkRouter) FromAPIRequest(req *api.NetworkRouterRequest) {
n.Peer = *req.Peer
n.PeerGroups = *req.PeerGroups
n.Masquerade = req.Masquerade
n.Metric = req.Metric
}
func (n *NetworkRouter) Copy() *NetworkRouter {
return &NetworkRouter{
ID: n.ID,

View File

@ -1,4 +1,4 @@
package networks
package types
import "testing"

View File

@ -1,6 +1,10 @@
package networks
package types
import "github.com/rs/xid"
import (
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/server/http/api"
)
type Network struct {
ID string `gorm:"index"`
@ -18,6 +22,19 @@ func NewNetwork(accountId, name, description string) *Network {
}
}
func (n *Network) ToAPIResponse() *api.Network {
return &api.Network{
Id: n.ID,
Name: n.Name,
Description: &n.Description,
}
}
func (n *Network) FromAPIRequest(req *api.NetworkRequest) {
n.Name = req.Name
n.Description = *req.Description
}
// Copy returns a copy of a posture checks.
func (n *Network) Copy() *Network {
return &Network{

View File

@ -16,6 +16,8 @@ import (
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity"
@ -92,7 +94,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
// fetch all the peers that have access to the user's peers
for _, peer := range peers {
aclPeers, _ := account.getPeerConnectionResources(ctx, peer.ID, approvedPeersMap)
aclPeers, _ := account.GetPeerConnectionResources(ctx, peer.ID, approvedPeersMap)
for _, p := range aclPeers {
peersMap[p.ID] = p
}
@ -107,7 +109,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
}
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, account *Account) error {
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, account *types.Account) error {
peer, err := account.FindPeerByPubKey(peerPubKey)
if err != nil {
return fmt.Errorf("failed to find peer by pub key: %w", err)
@ -139,7 +141,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
return nil
}
func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, account *Account) (bool, error) {
func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, account *types.Account) (bool, error) {
oldStatus := peer.Status.Copy()
newStatus := oldStatus
newStatus.LastSeen = time.Now().UTC()
@ -213,9 +215,9 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
if peerLabelUpdated {
peer.Name = update.Name
existingLabels := account.getPeerDNSLabels()
existingLabels := account.GetPeerDNSLabels()
newLabel, err := getPeerHostLabel(peer.Name, existingLabels)
newLabel, err := types.GetPeerHostLabel(peer.Name, existingLabels)
if err != nil {
return nil, err
}
@ -278,7 +280,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
}
// deletePeers will delete all specified peers and send updates to the remote peers. Don't call without acquiring account lock
func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Account, peerIDs []string, userID string) error {
func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *types.Account, peerIDs []string, userID string) error {
// the first loop is needed to ensure all peers present under the account before modifying, otherwise
// we might have some inconsistencies
@ -316,7 +318,7 @@ func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Accou
FirewallRulesIsEmpty: true,
},
},
NetworkMap: &NetworkMap{},
NetworkMap: &types.NetworkMap{},
})
am.peersUpdateManager.CloseChannel(ctx, peer.ID)
am.StoreEvent(ctx, userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain()))
@ -358,7 +360,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
}
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)
func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID string) (*NetworkMap, error) {
func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) {
account, err := am.Store.GetAccountByPeerID(ctx, peerID)
if err != nil {
return nil, err
@ -383,7 +385,7 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin
}
// GetPeerNetwork returns the Network for a given peer
func (am *DefaultAccountManager) GetPeerNetwork(ctx context.Context, peerID string) (*Network, error) {
func (am *DefaultAccountManager) GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error) {
account, err := am.Store.GetAccountByPeerID(ctx, peerID)
if err != nil {
return nil, err
@ -399,7 +401,7 @@ func (am *DefaultAccountManager) GetPeerNetwork(ctx context.Context, peerID stri
// to it. We also add the User ID to the peer metadata to identify registrant. If no userID provided, then fail with status.PermissionDenied
// Each new Peer will be assigned a new next net.IP from the Account.Network and Account.Network.LastIP will be updated (IP's are not reused).
// The peer property is just a placeholder for the Peer properties to pass further
func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
if setupKey == "" && userID == "" {
// no auth method provided => reject access
return nil, nil, nil, status.Errorf(status.Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login")
@ -433,7 +435,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
// and the peer disconnects with a timeout and tries to register again.
// We just check if this machine has been registered before and reject the second registration.
// The connecting peer should be able to recover with a retry.
_, err = am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, peer.Key)
_, err = am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthShare, peer.Key)
if err == nil {
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered")
}
@ -446,12 +448,12 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
var newPeer *nbpeer.Peer
var groupsToAdd []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var setupKeyID string
var setupKeyName string
var ephemeral bool
if addedByUser {
user, err := transaction.GetUserByUserID(ctx, LockingStrengthUpdate, userID)
user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthUpdate, userID)
if err != nil {
return fmt.Errorf("failed to get user groups: %w", err)
}
@ -460,7 +462,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
opEvent.Activity = activity.PeerAddedByUser
} else {
// Validate the setup key
sk, err := transaction.GetSetupKeyBySecret(ctx, LockingStrengthUpdate, encodedHashedKey)
sk, err := transaction.GetSetupKeyBySecret(ctx, store.LockingStrengthUpdate, encodedHashedKey)
if err != nil {
return fmt.Errorf("failed to get setup key: %w", err)
}
@ -533,7 +535,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
}
}
settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID)
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return fmt.Errorf("failed to get account settings: %w", err)
}
@ -558,7 +560,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return fmt.Errorf("failed to add peer to account: %w", err)
}
err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID)
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
@ -627,18 +629,18 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return newPeer, networkMap, postureChecks, nil
}
func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, accountID string) (net.IP, error) {
takenIps, err := store.GetTakenIPs(ctx, LockingStrengthShare, accountID)
func (am *DefaultAccountManager) getFreeIP(ctx context.Context, s store.Store, accountID string) (net.IP, error) {
takenIps, err := s.GetTakenIPs(ctx, store.LockingStrengthUpdate, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get taken IPs: %w", err)
}
network, err := store.GetAccountNetwork(ctx, LockingStrengthUpdate, accountID)
network, err := s.GetAccountNetwork(ctx, store.LockingStrengthUpdate, accountID)
if err != nil {
return nil, fmt.Errorf("failed getting network: %w", err)
}
nextIp, err := AllocatePeerIP(network.Net, takenIps)
nextIp, err := types.AllocatePeerIP(network.Net, takenIps)
if err != nil {
return nil, fmt.Errorf("failed to allocate new peer ip: %w", err)
}
@ -647,7 +649,7 @@ func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, acc
}
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, account *types.Account) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey)
if err != nil {
return nil, nil, nil, status.NewPeerNotRegisteredError()
@ -695,7 +697,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
}
if peerNotValid {
emptyMap := &NetworkMap{
emptyMap := &types.NetworkMap{
Network: account.Network.Copy(),
}
return peer, emptyMap, []*posture.Checks{}, nil
@ -710,7 +712,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
}
func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login PeerLogin, err error) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login PeerLogin, err error) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
// we couldn't find this peer by its public key which can mean that peer hasn't been registered yet.
// Try registering it.
@ -730,7 +732,7 @@ func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, lo
// LoginPeer logs in or registers a peer.
// If peer doesn't exist the function checks whether a setup key or a user is present and registers a new peer if so.
func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, login.WireGuardPubKey)
if err != nil {
return am.handlePeerLoginNotFound(ctx, login, err)
@ -755,12 +757,12 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
}
}()
peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, login.WireGuardPubKey)
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, login.WireGuardPubKey)
if err != nil {
return nil, nil, nil, err
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return nil, nil, nil, err
}
@ -785,7 +787,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
}
}
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
groups, err := am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return nil, nil, nil, err
}
@ -849,7 +851,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
// with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired
// and before starting the engine, we do the checks without an account lock to avoid piling up requests.
func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login PeerLogin) error {
peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, login.WireGuardPubKey)
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthShare, login.WireGuardPubKey)
if err != nil {
return err
}
@ -860,7 +862,7 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co
return nil
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return err
}
@ -872,11 +874,11 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co
return nil
}
func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, account *Account, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, account *types.Account, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
var postureChecks []*posture.Checks
if isRequiresApproval {
emptyMap := &NetworkMap{
emptyMap := &types.NetworkMap{
Network: account.Network.Copy(),
}
return peer, emptyMap, nil, nil
@ -896,7 +898,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
}
func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *User, peer *nbpeer.Peer) error {
func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *types.User, peer *nbpeer.Peer) error {
err := checkAuth(ctx, user.Id, peer)
if err != nil {
return err
@ -918,7 +920,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *Us
return nil
}
func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, user *User) error {
func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, user *types.User) error {
if peer.AddedWithSSOLogin() {
if user.IsBlocked() {
return status.Errorf(status.PermissionDenied, "user is blocked")
@ -939,7 +941,7 @@ func checkAuth(ctx context.Context, loginUserID string, peer *nbpeer.Peer) error
return nil
}
func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *Settings) bool {
func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *types.Settings) bool {
expired, expiresIn := peer.LoginExpired(settings.PeerLoginExpiration)
expired = settings.PeerLoginExpirationEnabled && expired
if expired || peer.Status.LoginExpired {
@ -991,7 +993,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
}
for _, p := range userPeers {
aclPeers, _ := account.getPeerConnectionResources(ctx, p.ID, approvedPeersMap)
aclPeers, _ := account.GetPeerConnectionResources(ctx, p.ID, approvedPeersMap)
for _, aclPeer := range aclPeers {
if aclPeer.ID == peerID {
return peer, nil
@ -1069,7 +1071,7 @@ func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
// in an active DNS, route, or ACL configuration.
func (am *DefaultAccountManager) isPeerInActiveGroup(ctx context.Context, account *Account, peerID string) (bool, error) {
func (am *DefaultAccountManager) isPeerInActiveGroup(ctx context.Context, account *types.Account, peerID string) (bool, error) {
peerGroupIDs := make([]string, 0)
for _, group := range account.Groups {
if slices.Contains(group.Peers, peerID) {

View File

@ -27,7 +27,9 @@ import (
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
nbroute "github.com/netbirdio/netbird/route"
)
@ -37,13 +39,13 @@ func TestPeer_LoginExpired(t *testing.T) {
expirationEnabled bool
lastLogin time.Time
expected bool
accountSettings *Settings
accountSettings *types.Settings
}{
{
name: "Peer Login Expiration Disabled. Peer Login Should Not Expire",
expirationEnabled: false,
lastLogin: time.Now().UTC().Add(-25 * time.Hour),
accountSettings: &Settings{
accountSettings: &types.Settings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: time.Hour,
},
@ -53,7 +55,7 @@ func TestPeer_LoginExpired(t *testing.T) {
name: "Peer Login Should Expire",
expirationEnabled: true,
lastLogin: time.Now().UTC().Add(-25 * time.Hour),
accountSettings: &Settings{
accountSettings: &types.Settings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: time.Hour,
},
@ -63,7 +65,7 @@ func TestPeer_LoginExpired(t *testing.T) {
name: "Peer Login Should Not Expire",
expirationEnabled: true,
lastLogin: time.Now().UTC(),
accountSettings: &Settings{
accountSettings: &types.Settings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: time.Hour,
},
@ -92,14 +94,14 @@ func TestPeer_SessionExpired(t *testing.T) {
lastLogin time.Time
connected bool
expected bool
accountSettings *Settings
accountSettings *types.Settings
}{
{
name: "Peer Inactivity Expiration Disabled. Peer Inactivity Should Not Expire",
expirationEnabled: false,
connected: false,
lastLogin: time.Now().UTC().Add(-1 * time.Second),
accountSettings: &Settings{
accountSettings: &types.Settings{
PeerInactivityExpirationEnabled: true,
PeerInactivityExpiration: time.Hour,
},
@ -110,7 +112,7 @@ func TestPeer_SessionExpired(t *testing.T) {
expirationEnabled: true,
connected: false,
lastLogin: time.Now().UTC().Add(-1 * time.Second),
accountSettings: &Settings{
accountSettings: &types.Settings{
PeerInactivityExpirationEnabled: true,
PeerInactivityExpiration: time.Second,
},
@ -121,7 +123,7 @@ func TestPeer_SessionExpired(t *testing.T) {
expirationEnabled: true,
connected: true,
lastLogin: time.Now().UTC(),
accountSettings: &Settings{
accountSettings: &types.Settings{
PeerInactivityExpirationEnabled: true,
PeerInactivityExpiration: time.Second,
},
@ -161,7 +163,7 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
t.Fatal(err)
}
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userId, false)
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", types.SetupKeyReusable, time.Hour, nil, 999, userId, false)
if err != nil {
t.Fatal("error creating setup key")
return
@ -233,9 +235,9 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
t.Fatal(err)
}
var setupKey *SetupKey
var setupKey *types.SetupKey
for _, key := range account.SetupKeys {
if key.Type == SetupKeyReusable {
if key.Type == types.SetupKeyReusable {
setupKey = key
}
}
@ -303,16 +305,16 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
return
}
policy := &Policy{
policy := &types.Policy{
Name: "test",
Enabled: true,
Rules: []*PolicyRule{
Rules: []*types.PolicyRule{
{
Enabled: true,
Sources: []string{group1.ID},
Destinations: []string{group2.ID},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
Action: types.PolicyTrafficActionAccept,
},
},
}
@ -410,7 +412,7 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) {
t.Fatal(err)
}
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userId, false)
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", types.SetupKeyReusable, time.Hour, nil, 999, userId, false)
if err != nil {
t.Fatal("error creating setup key")
return
@ -469,9 +471,9 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
adminUser := "account_creator"
someUser := "some_user"
account := newAccountWithId(context.Background(), accountID, adminUser, "")
account.Users[someUser] = &User{
account.Users[someUser] = &types.User{
Id: someUser,
Role: UserRoleUser,
Role: types.UserRoleUser,
}
account.Settings.RegularUsersViewBlocked = false
@ -482,7 +484,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
}
// two peers one added by a regular user and one with a setup key
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, adminUser, false)
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", types.SetupKeyReusable, time.Hour, nil, 999, adminUser, false)
if err != nil {
t.Fatal("error creating setup key")
return
@ -567,77 +569,77 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
func TestDefaultAccountManager_GetPeers(t *testing.T) {
testCases := []struct {
name string
role UserRole
role types.UserRole
limitedViewSettings bool
isServiceUser bool
expectedPeerCount int
}{
{
name: "Regular user, no limited view settings, not a service user",
role: UserRoleUser,
role: types.UserRoleUser,
limitedViewSettings: false,
isServiceUser: false,
expectedPeerCount: 1,
},
{
name: "Service user, no limited view settings",
role: UserRoleUser,
role: types.UserRoleUser,
limitedViewSettings: false,
isServiceUser: true,
expectedPeerCount: 2,
},
{
name: "Regular user, limited view settings",
role: UserRoleUser,
role: types.UserRoleUser,
limitedViewSettings: true,
isServiceUser: false,
expectedPeerCount: 0,
},
{
name: "Service user, limited view settings",
role: UserRoleUser,
role: types.UserRoleUser,
limitedViewSettings: true,
isServiceUser: true,
expectedPeerCount: 2,
},
{
name: "Admin, no limited view settings, not a service user",
role: UserRoleAdmin,
role: types.UserRoleAdmin,
limitedViewSettings: false,
isServiceUser: false,
expectedPeerCount: 2,
},
{
name: "Admin service user, no limited view settings",
role: UserRoleAdmin,
role: types.UserRoleAdmin,
limitedViewSettings: false,
isServiceUser: true,
expectedPeerCount: 2,
},
{
name: "Admin, limited view settings",
role: UserRoleAdmin,
role: types.UserRoleAdmin,
limitedViewSettings: true,
isServiceUser: false,
expectedPeerCount: 2,
},
{
name: "Admin Service user, limited view settings",
role: UserRoleAdmin,
role: types.UserRoleAdmin,
limitedViewSettings: true,
isServiceUser: true,
expectedPeerCount: 2,
},
{
name: "Owner, no limited view settings",
role: UserRoleOwner,
role: types.UserRoleOwner,
limitedViewSettings: true,
isServiceUser: false,
expectedPeerCount: 2,
},
{
name: "Owner, limited view settings",
role: UserRoleOwner,
role: types.UserRoleOwner,
limitedViewSettings: true,
isServiceUser: false,
expectedPeerCount: 2,
@ -656,12 +658,12 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
adminUser := "account_creator"
someUser := "some_user"
account := newAccountWithId(context.Background(), accountID, adminUser, "")
account.Users[someUser] = &User{
account.Users[someUser] = &types.User{
Id: someUser,
Role: testCase.role,
IsServiceUser: testCase.isServiceUser,
}
account.Policies = []*Policy{}
account.Policies = []*types.Policy{}
account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings
err = manager.Store.SaveAccount(context.Background(), account)
@ -726,9 +728,9 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou
regularUser := "regular_user"
account := newAccountWithId(context.Background(), accountID, adminUser, "")
account.Users[regularUser] = &User{
account.Users[regularUser] = &types.User{
Id: regularUser,
Role: UserRoleUser,
Role: types.UserRoleUser,
}
// Create peers
@ -746,7 +748,7 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou
}
// Create groups and policies
account.Policies = make([]*Policy, 0, groups)
account.Policies = make([]*types.Policy, 0, groups)
for i := 0; i < groups; i++ {
groupID := fmt.Sprintf("group-%d", i)
group := &nbgroup.Group{
@ -760,11 +762,11 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou
account.Groups[groupID] = group
// Create a policy for this group
policy := &Policy{
policy := &types.Policy{
ID: fmt.Sprintf("policy-%d", i),
Name: fmt.Sprintf("Policy for Group %d", i),
Enabled: true,
Rules: []*PolicyRule{
Rules: []*types.PolicyRule{
{
ID: fmt.Sprintf("rule-%d", i),
Name: fmt.Sprintf("Rule for Group %d", i),
@ -772,8 +774,8 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou
Sources: []string{groupID},
Destinations: []string{groupID},
Bidirectional: true,
Protocol: PolicyRuleProtocolALL,
Action: PolicyTrafficActionAccept,
Protocol: types.PolicyRuleProtocolALL,
Action: types.PolicyTrafficActionAccept,
},
},
}
@ -939,8 +941,8 @@ func TestToSyncResponse(t *testing.T) {
Payload: "turn-user",
Signature: "turn-pass",
}
networkMap := &NetworkMap{
Network: &Network{Net: *ipnet, Serial: 1000},
networkMap := &types.NetworkMap{
Network: &types.Network{Net: *ipnet, Serial: 1000},
Peers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.2"), Key: "peer2-key", DNSLabel: "peer2", SSHEnabled: true, SSHKey: "peer2-ssh-key"}},
OfflinePeers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.3"), Key: "peer3-key", DNSLabel: "peer3", SSHEnabled: true, SSHKey: "peer3-ssh-key"}},
Routes: []*nbroute.Route{
@ -987,8 +989,8 @@ func TestToSyncResponse(t *testing.T) {
},
CustomZones: []nbdns.CustomZone{{Domain: "example.com", Records: []nbdns.SimpleRecord{{Name: "example.com", Type: 1, Class: "IN", TTL: 60, RData: "100.64.0.1"}}}},
},
FirewallRules: []*FirewallRule{
{PeerIP: "192.168.1.2", Direction: firewallRuleDirectionIN, Action: string(PolicyTrafficActionAccept), Protocol: string(PolicyRuleProtocolTCP), Port: "80"},
FirewallRules: []*types.FirewallRule{
{PeerIP: "192.168.1.2", Direction: types.FirewallRuleDirectionIN, Action: string(types.PolicyTrafficActionAccept), Protocol: string(types.PolicyRuleProtocolTCP), Port: "80"},
},
}
dnsName := "example.com"
@ -1088,7 +1090,7 @@ func Test_RegisterPeerByUser(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
if err != nil {
t.Fatal(err)
}
@ -1099,13 +1101,13 @@ func Test_RegisterPeerByUser(t *testing.T) {
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
assert.NoError(t, err)
am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
existingUserID := "edafee4e-63fb-11ec-90d6-0242ac120003"
_, err = store.GetAccount(context.Background(), existingAccountID)
_, err = s.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
newPeer := &nbpeer.Peer{
@ -1128,12 +1130,12 @@ func Test_RegisterPeerByUser(t *testing.T) {
addedPeer, _, _, err := am.AddPeer(context.Background(), "", existingUserID, newPeer)
require.NoError(t, err)
peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, addedPeer.Key)
peer, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, addedPeer.Key)
require.NoError(t, err)
assert.Equal(t, peer.AccountID, existingAccountID)
assert.Equal(t, peer.UserID, existingUserID)
account, err := store.GetAccount(context.Background(), existingAccountID)
account, err := s.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
assert.Contains(t, account.Peers, addedPeer.ID)
assert.Equal(t, peer.Meta.Hostname, newPeer.Meta.Hostname)
@ -1152,7 +1154,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
if err != nil {
t.Fatal(err)
}
@ -1163,13 +1165,13 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
assert.NoError(t, err)
am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
existingSetupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
_, err = store.GetAccount(context.Background(), existingAccountID)
_, err = s.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
newPeer := &nbpeer.Peer{
@ -1192,11 +1194,11 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
require.NoError(t, err)
peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key)
peer, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, newPeer.Key)
require.NoError(t, err)
assert.Equal(t, peer.AccountID, existingAccountID)
account, err := store.GetAccount(context.Background(), existingAccountID)
account, err := s.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
assert.Contains(t, account.Peers, addedPeer.ID)
assert.Contains(t, account.Groups["cfefqs706sqkneg59g2g"].Peers, addedPeer.ID)
@ -1219,7 +1221,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
if err != nil {
t.Fatal(err)
}
@ -1230,13 +1232,13 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
assert.NoError(t, err)
am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
faultyKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBC"
_, err = store.GetAccount(context.Background(), existingAccountID)
_, err = s.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
newPeer := &nbpeer.Peer{
@ -1258,10 +1260,10 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
_, _, _, err = am.AddPeer(context.Background(), faultyKey, "", newPeer)
require.Error(t, err)
_, err = store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key)
_, err = s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, newPeer.Key)
require.Error(t, err)
account, err := store.GetAccount(context.Background(), existingAccountID)
account, err := s.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
assert.NotContains(t, account.Peers, newPeer.ID)
assert.NotContains(t, account.Groups["cfefqs706sqkneg59g3g"].Peers, newPeer.ID)
@ -1304,26 +1306,26 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
require.NoError(t, err)
// create a user with auto groups
_, err = manager.SaveOrAddUsers(context.Background(), account.Id, userID, []*User{
_, err = manager.SaveOrAddUsers(context.Background(), account.Id, userID, []*types.User{
{
Id: "regularUser1",
AccountID: account.Id,
Role: UserRoleAdmin,
Issued: UserIssuedAPI,
Role: types.UserRoleAdmin,
Issued: types.UserIssuedAPI,
AutoGroups: []string{"groupA"},
},
{
Id: "regularUser2",
AccountID: account.Id,
Role: UserRoleAdmin,
Issued: UserIssuedAPI,
Role: types.UserRoleAdmin,
Issued: types.UserIssuedAPI,
AutoGroups: []string{"groupB"},
},
{
Id: "regularUser3",
AccountID: account.Id,
Role: UserRoleAdmin,
Issued: UserIssuedAPI,
Role: types.UserRoleAdmin,
Issued: types.UserIssuedAPI,
AutoGroups: []string{"groupC"},
},
}, true)
@ -1464,15 +1466,15 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
// Adding peer to group linked with policy should update account peers and send peer update
t.Run("adding peer to group linked with policy", func(t *testing.T) {
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
Enabled: true,
Rules: []*PolicyRule{
Rules: []*types.PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
Action: types.PolicyTrafficActionAccept,
},
},
})

View File

@ -3,344 +3,22 @@ package server
import (
"context"
_ "embed"
"strconv"
"strings"
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/proto"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
)
// PolicyUpdateOperationType operation type
type PolicyUpdateOperationType int
// PolicyTrafficActionType action type for the firewall
type PolicyTrafficActionType string
// PolicyRuleProtocolType type of traffic
type PolicyRuleProtocolType string
// PolicyRuleDirection direction of traffic
type PolicyRuleDirection string
const (
// PolicyTrafficActionAccept indicates that the traffic is accepted
PolicyTrafficActionAccept = PolicyTrafficActionType("accept")
// PolicyTrafficActionDrop indicates that the traffic is dropped
PolicyTrafficActionDrop = PolicyTrafficActionType("drop")
)
const (
// PolicyRuleProtocolALL type of traffic
PolicyRuleProtocolALL = PolicyRuleProtocolType("all")
// PolicyRuleProtocolTCP type of traffic
PolicyRuleProtocolTCP = PolicyRuleProtocolType("tcp")
// PolicyRuleProtocolUDP type of traffic
PolicyRuleProtocolUDP = PolicyRuleProtocolType("udp")
// PolicyRuleProtocolICMP type of traffic
PolicyRuleProtocolICMP = PolicyRuleProtocolType("icmp")
)
const (
// PolicyRuleFlowDirect allows traffic from source to destination
PolicyRuleFlowDirect = PolicyRuleDirection("direct")
// PolicyRuleFlowBidirect allows traffic to both directions
PolicyRuleFlowBidirect = PolicyRuleDirection("bidirect")
)
const (
// DefaultRuleName is a name for the Default rule that is created for every account
DefaultRuleName = "Default"
// DefaultRuleDescription is a description for the Default rule that is created for every account
DefaultRuleDescription = "This is a default rule that allows connections between all the resources"
// DefaultPolicyName is a name for the Default policy that is created for every account
DefaultPolicyName = "Default"
// DefaultPolicyDescription is a description for the Default policy that is created for every account
DefaultPolicyDescription = "This is a default policy that allows connections between all the resources"
)
const (
firewallRuleDirectionIN = 0
firewallRuleDirectionOUT = 1
)
// PolicyUpdateOperation operation object with type and values to be applied
type PolicyUpdateOperation struct {
Type PolicyUpdateOperationType
Values []string
}
// RulePortRange represents a range of ports for a firewall rule.
type RulePortRange struct {
Start uint16
End uint16
}
// PolicyRule is the metadata of the policy
type PolicyRule struct {
// ID of the policy rule
ID string `gorm:"primaryKey"`
// PolicyID is a reference to Policy that this object belongs
PolicyID string `json:"-" gorm:"index"`
// Name of the rule visible in the UI
Name string
// Description of the rule visible in the UI
Description string
// Enabled status of rule in the system
Enabled bool
// Action policy accept or drops packets
Action PolicyTrafficActionType
// Destinations policy destination groups
Destinations []string `gorm:"serializer:json"`
// Sources policy source groups
Sources []string `gorm:"serializer:json"`
// Bidirectional define if the rule is applicable in both directions, sources, and destinations
Bidirectional bool
// Protocol type of the traffic
Protocol PolicyRuleProtocolType
// Ports or it ranges list
Ports []string `gorm:"serializer:json"`
// PortRanges a list of port ranges.
PortRanges []RulePortRange `gorm:"serializer:json"`
}
// Copy returns a copy of a policy rule
func (pm *PolicyRule) Copy() *PolicyRule {
rule := &PolicyRule{
ID: pm.ID,
PolicyID: pm.PolicyID,
Name: pm.Name,
Description: pm.Description,
Enabled: pm.Enabled,
Action: pm.Action,
Destinations: make([]string, len(pm.Destinations)),
Sources: make([]string, len(pm.Sources)),
Bidirectional: pm.Bidirectional,
Protocol: pm.Protocol,
Ports: make([]string, len(pm.Ports)),
PortRanges: make([]RulePortRange, len(pm.PortRanges)),
}
copy(rule.Destinations, pm.Destinations)
copy(rule.Sources, pm.Sources)
copy(rule.Ports, pm.Ports)
copy(rule.PortRanges, pm.PortRanges)
return rule
}
// Policy of the Rego query
type Policy struct {
// ID of the policy'
ID string `gorm:"primaryKey"`
// AccountID is a reference to Account that this object belongs
AccountID string `json:"-" gorm:"index"`
// Name of the Policy
Name string
// Description of the policy visible in the UI
Description string
// Enabled status of the policy
Enabled bool
// Rules of the policy
Rules []*PolicyRule `gorm:"foreignKey:PolicyID;references:id;constraint:OnDelete:CASCADE;"`
// SourcePostureChecks are ID references to Posture checks for policy source groups
SourcePostureChecks []string `gorm:"serializer:json"`
}
// Copy returns a copy of the policy.
func (p *Policy) Copy() *Policy {
c := &Policy{
ID: p.ID,
AccountID: p.AccountID,
Name: p.Name,
Description: p.Description,
Enabled: p.Enabled,
Rules: make([]*PolicyRule, len(p.Rules)),
SourcePostureChecks: make([]string, len(p.SourcePostureChecks)),
}
for i, r := range p.Rules {
c.Rules[i] = r.Copy()
}
copy(c.SourcePostureChecks, p.SourcePostureChecks)
return c
}
// EventMeta returns activity event meta related to this policy
func (p *Policy) EventMeta() map[string]any {
return map[string]any{"name": p.Name}
}
// UpgradeAndFix different version of policies to latest version
func (p *Policy) UpgradeAndFix() {
for _, r := range p.Rules {
// start migrate from version v0.20.3
if r.Protocol == "" {
r.Protocol = PolicyRuleProtocolALL
}
if r.Protocol == PolicyRuleProtocolALL && !r.Bidirectional {
r.Bidirectional = true
}
// -- v0.20.4
}
}
// ruleGroups returns a list of all groups referenced in the policy's rules,
// including sources and destinations.
func (p *Policy) ruleGroups() []string {
groups := make([]string, 0)
for _, rule := range p.Rules {
groups = append(groups, rule.Sources...)
groups = append(groups, rule.Destinations...)
}
return groups
}
// FirewallRule is a rule of the firewall.
type FirewallRule struct {
// PeerIP of the peer
PeerIP string
// Direction of the traffic
Direction int
// Action of the traffic
Action string
// Protocol of the traffic
Protocol string
// Port of the traffic
Port string
}
// getPeerConnectionResources for a given peer
//
// This function returns the list of peers and firewall rules that are applicable to a given peer.
func (a *Account) getPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx)
for _, policy := range a.Policies {
if !policy.Enabled {
continue
}
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}
sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap)
destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap)
if rule.Bidirectional {
if peerInSources {
generateResources(rule, destinationPeers, firewallRuleDirectionIN)
}
if peerInDestinations {
generateResources(rule, sourcePeers, firewallRuleDirectionOUT)
}
}
if peerInSources {
generateResources(rule, destinationPeers, firewallRuleDirectionOUT)
}
if peerInDestinations {
generateResources(rule, sourcePeers, firewallRuleDirectionIN)
}
}
}
return getAccumulatedResources()
}
// connResourcesGenerator returns generator and accumulator function which returns the result of generator calls
//
// The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer.
// It safe to call the generator function multiple times for same peer and different rules no duplicates will be
// generated. The accumulator function returns the result of all the generator calls.
func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) {
rulesExists := make(map[string]struct{})
peersExists := make(map[string]struct{})
rules := make([]*FirewallRule, 0)
peers := make([]*nbpeer.Peer, 0)
all, err := a.GetGroupAll()
if err != nil {
log.WithContext(ctx).Errorf("failed to get group all: %v", err)
all = &nbgroup.Group{}
}
return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) {
isAll := (len(all.Peers) - 1) == len(groupPeers)
for _, peer := range groupPeers {
if peer == nil {
continue
}
if _, ok := peersExists[peer.ID]; !ok {
peers = append(peers, peer)
peersExists[peer.ID] = struct{}{}
}
fr := FirewallRule{
PeerIP: peer.IP.String(),
Direction: direction,
Action: string(rule.Action),
Protocol: string(rule.Protocol),
}
if isAll {
fr.PeerIP = "0.0.0.0"
}
ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) +
fr.Protocol + fr.Action + strings.Join(rule.Ports, ",")
if _, ok := rulesExists[ruleID]; ok {
continue
}
rulesExists[ruleID] = struct{}{}
if len(rule.Ports) == 0 {
rules = append(rules, &fr)
continue
}
for _, port := range rule.Ports {
pr := fr // clone rule and add set new port
pr.Port = port
rules = append(rules, &pr)
}
}
}, func() ([]*nbpeer.Peer, []*FirewallRule) {
return peers, rules
}
}
// GetPolicy from the store
func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error) {
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
if err != nil {
return nil, err
}
@ -353,15 +31,15 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic
return nil, status.NewAdminPermissionError()
}
return am.Store.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID)
return am.Store.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policyID)
}
// SavePolicy in the store
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error) {
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
if err != nil {
return nil, err
}
@ -378,7 +56,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
var updateAccountPeers bool
var action = activity.PolicyAdded
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validatePolicy(ctx, transaction, accountID, policy); err != nil {
return err
}
@ -388,7 +66,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return err
}
@ -398,7 +76,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
saveFunc = transaction.SavePolicy
}
return saveFunc(ctx, LockingStrengthUpdate, policy)
return saveFunc(ctx, store.LockingStrengthUpdate, policy)
})
if err != nil {
return nil, err
@ -418,7 +96,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
if err != nil {
return err
}
@ -431,11 +109,11 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
return status.NewAdminPermissionError()
}
var policy *Policy
var policy *types.Policy
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
policy, err = transaction.GetPolicyByID(ctx, LockingStrengthUpdate, accountID, policyID)
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
policy, err = transaction.GetPolicyByID(ctx, store.LockingStrengthUpdate, accountID, policyID)
if err != nil {
return err
}
@ -445,11 +123,11 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.DeletePolicy(ctx, LockingStrengthUpdate, accountID, policyID)
return transaction.DeletePolicy(ctx, store.LockingStrengthUpdate, accountID, policyID)
})
if err != nil {
return err
@ -465,8 +143,8 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
}
// ListPolicies from the store.
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*types.Policy, error) {
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
if err != nil {
return nil, err
}
@ -479,13 +157,13 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us
return nil, status.NewAdminPermissionError()
}
return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
return am.Store.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID)
}
// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers.
func arePolicyChangesAffectPeers(ctx context.Context, transaction Store, accountID string, policy *Policy, isUpdate bool) (bool, error) {
func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy, isUpdate bool) (bool, error) {
if isUpdate {
existingPolicy, err := transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID)
existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policy.ID)
if err != nil {
return false, err
}
@ -494,7 +172,7 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction Store, account
return false, nil
}
hasPeers, err := anyGroupHasPeers(ctx, transaction, policy.AccountID, existingPolicy.ruleGroups())
hasPeers, err := anyGroupHasPeers(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups())
if err != nil {
return false, err
}
@ -504,13 +182,13 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction Store, account
}
}
return anyGroupHasPeers(ctx, transaction, policy.AccountID, policy.ruleGroups())
return anyGroupHasPeers(ctx, transaction, policy.AccountID, policy.RuleGroups())
}
// validatePolicy validates the policy and its rules.
func validatePolicy(ctx context.Context, transaction Store, accountID string, policy *Policy) error {
func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error {
if policy.ID != "" {
_, err := transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID)
_, err := transaction.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policy.ID)
if err != nil {
return err
}
@ -519,12 +197,12 @@ func validatePolicy(ctx context.Context, transaction Store, accountID string, po
policy.AccountID = accountID
}
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, policy.ruleGroups())
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, policy.RuleGroups())
if err != nil {
return err
}
postureChecks, err := transaction.GetPostureChecksByIDs(ctx, LockingStrengthShare, accountID, policy.SourcePostureChecks)
postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthShare, accountID, policy.SourcePostureChecks)
if err != nil {
return err
}
@ -548,84 +226,6 @@ func validatePolicy(ctx context.Context, transaction Store, accountID string, po
return nil
}
// getAllPeersFromGroups for given peer ID and list of groups
//
// Returns a list of peers from specified groups that pass specified posture checks
// and a boolean indicating if the supplied peer ID exists within these groups.
//
// Important: Posture checks are applicable only to source group peers,
// for destination group peers, call this method with an empty list of sourcePostureChecksIDs
func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) {
peerInGroups := false
filteredPeers := make([]*nbpeer.Peer, 0, len(groups))
for _, g := range groups {
group, ok := a.Groups[g]
if !ok {
continue
}
for _, p := range group.Peers {
peer, ok := a.Peers[p]
if !ok || peer == nil {
continue
}
// validate the peer based on policy posture checks applied
isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
if !isValid {
continue
}
if _, ok := validatedPeersMap[peer.ID]; !ok {
continue
}
if peer.ID == peerID {
peerInGroups = true
continue
}
filteredPeers = append(filteredPeers, peer)
}
}
return filteredPeers, peerInGroups
}
// validatePostureChecksOnPeer validates the posture checks on a peer
func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePostureChecksID []string, peerID string) bool {
peer, ok := a.Peers[peerID]
if !ok && peer == nil {
return false
}
for _, postureChecksID := range sourcePostureChecksID {
postureChecks := a.getPostureChecks(postureChecksID)
if postureChecks == nil {
continue
}
for _, check := range postureChecks.GetChecks() {
isValid, err := check.Check(ctx, *peer)
if err != nil {
log.WithContext(ctx).Debugf("an error occurred check %s: on peer: %s :%s", check.Name(), peer.ID, err.Error())
}
if !isValid {
return false
}
}
}
return true
}
func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks {
for _, postureChecks := range a.PostureChecks {
if postureChecks.ID == postureChecksID {
return postureChecks
}
}
return nil
}
// getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list.
func getValidPostureCheckIDs(postureChecks map[string]*posture.Checks, postureChecksIds []string) []string {
validIDs := make([]string, 0, len(postureChecksIds))
@ -651,7 +251,7 @@ func getValidGroupIDs(groups map[string]*nbgroup.Group, groupIDs []string) []str
}
// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules.
func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {
func toProtocolFirewallRules(rules []*types.FirewallRule) []*proto.FirewallRule {
result := make([]*proto.FirewallRule, len(rules))
for i := range rules {
rule := rules[i]

View File

@ -13,10 +13,11 @@ import (
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/types"
)
func TestAccount_getPeersByPolicy(t *testing.T) {
account := &Account{
account := &types.Account{
Peers: map[string]*nbpeer.Peer{
"peerA": {
ID: "peerA",
@ -87,21 +88,21 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
},
},
},
Policies: []*Policy{
Policies: []*types.Policy{
{
ID: "RuleDefault",
Name: "Default",
Description: "This is a default rule that allows connections between all the resources",
Enabled: true,
Rules: []*PolicyRule{
Rules: []*types.PolicyRule{
{
ID: "RuleDefault",
Name: "Default",
Description: "This is a default rule that allows connections between all the resources",
Bidirectional: true,
Enabled: true,
Protocol: PolicyRuleProtocolALL,
Action: PolicyTrafficActionAccept,
Protocol: types.PolicyRuleProtocolALL,
Action: types.PolicyTrafficActionAccept,
Sources: []string{
"GroupAll",
},
@ -116,15 +117,15 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
Name: "Swarm",
Description: "No description",
Enabled: true,
Rules: []*PolicyRule{
Rules: []*types.PolicyRule{
{
ID: "RuleSwarm",
Name: "Swarm",
Description: "No description",
Bidirectional: true,
Enabled: true,
Protocol: PolicyRuleProtocolALL,
Action: PolicyTrafficActionAccept,
Protocol: types.PolicyRuleProtocolALL,
Action: types.PolicyTrafficActionAccept,
Sources: []string{
"GroupSwarm",
"GroupAll",
@ -145,14 +146,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
t.Run("check that all peers get map", func(t *testing.T) {
for _, p := range account.Peers {
peers, firewallRules := account.getPeerConnectionResources(context.Background(), p.ID, validatedPeers)
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), p.ID, validatedPeers)
assert.GreaterOrEqual(t, len(peers), 2, "minimum number peers should present")
assert.GreaterOrEqual(t, len(firewallRules), 2, "minimum number of firewall rules should present")
}
})
t.Run("check first peer map details", func(t *testing.T) {
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", validatedPeers)
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", validatedPeers)
assert.Len(t, peers, 7)
assert.Contains(t, peers, account.Peers["peerA"])
assert.Contains(t, peers, account.Peers["peerC"])
@ -160,45 +161,45 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
assert.Contains(t, peers, account.Peers["peerE"])
assert.Contains(t, peers, account.Peers["peerF"])
epectedFirewallRules := []*FirewallRule{
epectedFirewallRules := []*types.FirewallRule{
{
PeerIP: "0.0.0.0",
Direction: firewallRuleDirectionIN,
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "all",
Port: "",
},
{
PeerIP: "0.0.0.0",
Direction: firewallRuleDirectionOUT,
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "all",
Port: "",
},
{
PeerIP: "100.65.14.88",
Direction: firewallRuleDirectionIN,
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "all",
Port: "",
},
{
PeerIP: "100.65.14.88",
Direction: firewallRuleDirectionOUT,
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "all",
Port: "",
},
{
PeerIP: "100.65.254.139",
Direction: firewallRuleDirectionOUT,
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "all",
Port: "",
},
{
PeerIP: "100.65.254.139",
Direction: firewallRuleDirectionIN,
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "all",
Port: "",
@ -206,14 +207,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
{
PeerIP: "100.65.62.5",
Direction: firewallRuleDirectionOUT,
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "all",
Port: "",
},
{
PeerIP: "100.65.62.5",
Direction: firewallRuleDirectionIN,
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "all",
Port: "",
@ -221,14 +222,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
{
PeerIP: "100.65.32.206",
Direction: firewallRuleDirectionOUT,
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "all",
Port: "",
},
{
PeerIP: "100.65.32.206",
Direction: firewallRuleDirectionIN,
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "all",
Port: "",
@ -236,14 +237,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
{
PeerIP: "100.65.250.202",
Direction: firewallRuleDirectionOUT,
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "all",
Port: "",
},
{
PeerIP: "100.65.250.202",
Direction: firewallRuleDirectionIN,
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "all",
Port: "",
@ -251,14 +252,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
{
PeerIP: "100.65.13.186",
Direction: firewallRuleDirectionOUT,
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "all",
Port: "",
},
{
PeerIP: "100.65.13.186",
Direction: firewallRuleDirectionIN,
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "all",
Port: "",
@ -266,14 +267,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
{
PeerIP: "100.65.29.55",
Direction: firewallRuleDirectionOUT,
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "all",
Port: "",
},
{
PeerIP: "100.65.29.55",
Direction: firewallRuleDirectionIN,
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "all",
Port: "",
@ -289,7 +290,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
}
func TestAccount_getPeersByPolicyDirect(t *testing.T) {
account := &Account{
account := &types.Account{
Peers: map[string]*nbpeer.Peer{
"peerA": {
ID: "peerA",
@ -332,21 +333,21 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
},
},
},
Policies: []*Policy{
Policies: []*types.Policy{
{
ID: "RuleDefault",
Name: "Default",
Description: "This is a default rule that allows connections between all the resources",
Enabled: false,
Rules: []*PolicyRule{
Rules: []*types.PolicyRule{
{
ID: "RuleDefault",
Name: "Default",
Description: "This is a default rule that allows connections between all the resources",
Bidirectional: true,
Enabled: false,
Protocol: PolicyRuleProtocolALL,
Action: PolicyTrafficActionAccept,
Protocol: types.PolicyRuleProtocolALL,
Action: types.PolicyTrafficActionAccept,
Sources: []string{
"GroupAll",
},
@ -361,15 +362,15 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
Name: "Swarm",
Description: "No description",
Enabled: true,
Rules: []*PolicyRule{
Rules: []*types.PolicyRule{
{
ID: "RuleSwarm",
Name: "Swarm",
Description: "No description",
Bidirectional: true,
Enabled: true,
Protocol: PolicyRuleProtocolALL,
Action: PolicyTrafficActionAccept,
Protocol: types.PolicyRuleProtocolALL,
Action: types.PolicyTrafficActionAccept,
Sources: []string{
"GroupSwarm",
},
@ -388,20 +389,20 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
}
t.Run("check first peer map", func(t *testing.T) {
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers)
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers)
assert.Contains(t, peers, account.Peers["peerC"])
epectedFirewallRules := []*FirewallRule{
epectedFirewallRules := []*types.FirewallRule{
{
PeerIP: "100.65.254.139",
Direction: firewallRuleDirectionIN,
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "all",
Port: "",
},
{
PeerIP: "100.65.254.139",
Direction: firewallRuleDirectionOUT,
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "all",
Port: "",
@ -416,20 +417,20 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
})
t.Run("check second peer map", func(t *testing.T) {
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers)
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers)
assert.Contains(t, peers, account.Peers["peerB"])
epectedFirewallRules := []*FirewallRule{
epectedFirewallRules := []*types.FirewallRule{
{
PeerIP: "100.65.80.39",
Direction: firewallRuleDirectionIN,
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "all",
Port: "",
},
{
PeerIP: "100.65.80.39",
Direction: firewallRuleDirectionOUT,
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "all",
Port: "",
@ -446,13 +447,13 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
account.Policies[1].Rules[0].Bidirectional = false
t.Run("check first peer map directional only", func(t *testing.T) {
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers)
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers)
assert.Contains(t, peers, account.Peers["peerC"])
epectedFirewallRules := []*FirewallRule{
epectedFirewallRules := []*types.FirewallRule{
{
PeerIP: "100.65.254.139",
Direction: firewallRuleDirectionOUT,
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "all",
Port: "",
@ -467,13 +468,13 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
})
t.Run("check second peer map directional only", func(t *testing.T) {
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers)
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers)
assert.Contains(t, peers, account.Peers["peerB"])
epectedFirewallRules := []*FirewallRule{
epectedFirewallRules := []*types.FirewallRule{
{
PeerIP: "100.65.80.39",
Direction: firewallRuleDirectionIN,
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "all",
Port: "",
@ -489,7 +490,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
}
func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
account := &Account{
account := &types.Account{
Peers: map[string]*nbpeer.Peer{
"peerA": {
ID: "peerA",
@ -630,17 +631,17 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
},
}
account.Policies = append(account.Policies, &Policy{
account.Policies = append(account.Policies, &types.Policy{
ID: "PolicyPostureChecks",
Name: "",
Description: "This is the policy with posture checks applied",
Enabled: true,
Rules: []*PolicyRule{
Rules: []*types.PolicyRule{
{
ID: "RuleSwarm",
Name: "Swarm",
Enabled: true,
Action: PolicyTrafficActionAccept,
Action: types.PolicyTrafficActionAccept,
Destinations: []string{
"GroupSwarm",
},
@ -648,7 +649,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
"GroupAll",
},
Bidirectional: false,
Protocol: PolicyRuleProtocolTCP,
Protocol: types.PolicyRuleProtocolTCP,
Ports: []string{"80"},
},
},
@ -664,7 +665,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
t.Run("verify peer's network map with default group peer list", func(t *testing.T) {
// peerB doesn't fulfill the NB posture check but is included in the destination group Swarm,
// will establish a connection with all source peers satisfying the NB posture check.
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers)
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers)
assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"])
@ -674,13 +675,13 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
// We expect a single permissive firewall rule which all outgoing connections
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers)
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers)
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
assert.Len(t, firewallRules, 1)
expectedFirewallRules := []*FirewallRule{
expectedFirewallRules := []*types.FirewallRule{
{
PeerIP: "0.0.0.0",
Direction: firewallRuleDirectionOUT,
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "tcp",
Port: "80",
@ -690,7 +691,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerE", approvedPeers)
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers)
assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"])
@ -700,7 +701,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerI", approvedPeers)
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers)
assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"])
@ -715,19 +716,19 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's
// no connection should be established to any peer of destination group
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers)
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers)
assert.Len(t, peers, 0)
assert.Len(t, firewallRules, 0)
// peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's
// no connection should be established to any peer of destination group
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerI", approvedPeers)
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers)
assert.Len(t, peers, 0)
assert.Len(t, firewallRules, 0)
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
// We expect a single permissive firewall rule which all outgoing connections
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers)
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers)
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers))
@ -742,14 +743,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerE", approvedPeers)
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers)
assert.Len(t, peers, 3)
assert.Len(t, firewallRules, 3)
assert.Contains(t, peers, account.Peers["peerA"])
assert.Contains(t, peers, account.Peers["peerC"])
assert.Contains(t, peers, account.Peers["peerD"])
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerA", approvedPeers)
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerA", approvedPeers)
assert.Len(t, peers, 5)
// assert peers from Group Swarm
assert.Contains(t, peers, account.Peers["peerD"])
@ -760,45 +761,45 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// assert peers from Group All
assert.Contains(t, peers, account.Peers["peerC"])
expectedFirewallRules := []*FirewallRule{
expectedFirewallRules := []*types.FirewallRule{
{
PeerIP: "100.65.62.5",
Direction: firewallRuleDirectionOUT,
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "tcp",
Port: "80",
},
{
PeerIP: "100.65.32.206",
Direction: firewallRuleDirectionOUT,
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "tcp",
Port: "80",
},
{
PeerIP: "100.65.13.186",
Direction: firewallRuleDirectionOUT,
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "tcp",
Port: "80",
},
{
PeerIP: "100.65.29.55",
Direction: firewallRuleDirectionOUT,
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "tcp",
Port: "80",
},
{
PeerIP: "100.65.254.139",
Direction: firewallRuleDirectionIN,
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "tcp",
Port: "80",
},
{
PeerIP: "100.65.62.5",
Direction: firewallRuleDirectionIN,
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "tcp",
Port: "80",
@ -809,8 +810,8 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
})
}
func sortFunc() func(a *FirewallRule, b *FirewallRule) int {
return func(a, b *FirewallRule) int {
func sortFunc() func(a *types.FirewallRule, b *types.FirewallRule) int {
return func(a, b *types.FirewallRule) int {
// Concatenate PeerIP and Direction as string for comparison
aStr := a.PeerIP + fmt.Sprintf("%d", a.Direction)
bStr := b.PeerIP + fmt.Sprintf("%d", b.Direction)
@ -858,9 +859,9 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
})
var policyWithGroupRulesNoPeers *Policy
var policyWithDestinationPeersOnly *Policy
var policyWithSourceAndDestinationPeers *Policy
var policyWithGroupRulesNoPeers *types.Policy
var policyWithDestinationPeersOnly *types.Policy
var policyWithSourceAndDestinationPeers *types.Policy
// Saving policy with rule groups with no peers should not update account's peers and not send peer update
t.Run("saving policy with rule groups with no peers", func(t *testing.T) {
@ -870,16 +871,16 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
close(done)
}()
policyWithGroupRulesNoPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
policyWithGroupRulesNoPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{
Rules: []*types.PolicyRule{
{
Enabled: true,
Sources: []string{"groupB"},
Destinations: []string{"groupC"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
Action: types.PolicyTrafficActionAccept,
},
},
})
@ -901,17 +902,17 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
close(done)
}()
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{
Rules: []*types.PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupB"},
Protocol: PolicyRuleProtocolTCP,
Protocol: types.PolicyRuleProtocolTCP,
Bidirectional: true,
Action: PolicyTrafficActionAccept,
Action: types.PolicyTrafficActionAccept,
},
},
})
@ -933,17 +934,17 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
close(done)
}()
policyWithDestinationPeersOnly, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
policyWithDestinationPeersOnly, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{
Rules: []*types.PolicyRule{
{
Enabled: true,
Sources: []string{"groupC"},
Destinations: []string{"groupD"},
Bidirectional: true,
Protocol: PolicyRuleProtocolTCP,
Action: PolicyTrafficActionAccept,
Protocol: types.PolicyRuleProtocolTCP,
Action: types.PolicyTrafficActionAccept,
},
},
})
@ -965,16 +966,16 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
close(done)
}()
policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
AccountID: account.Id,
Enabled: true,
Rules: []*PolicyRule{
Rules: []*types.PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupD"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
Action: types.PolicyTrafficActionAccept,
},
},
})

View File

@ -12,10 +12,12 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
if err != nil {
return nil, err
}
@ -28,7 +30,7 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID
return nil, status.NewAdminPermissionError()
}
return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID)
return am.Store.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecksID)
}
// SavePostureChecks saves a posture check.
@ -36,7 +38,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
if err != nil {
return nil, err
}
@ -53,7 +55,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
var isUpdate = postureChecks.ID != ""
var action = activity.PostureCheckCreated
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validatePostureChecks(ctx, transaction, accountID, postureChecks); err != nil {
return err
}
@ -64,7 +66,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return err
}
@ -72,7 +74,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
}
postureChecks.AccountID = accountID
return transaction.SavePostureChecks(ctx, LockingStrengthUpdate, postureChecks)
return transaction.SavePostureChecks(ctx, store.LockingStrengthUpdate, postureChecks)
})
if err != nil {
return nil, err
@ -92,7 +94,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
if err != nil {
return err
}
@ -107,8 +109,8 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
var postureChecks *posture.Checks
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
postureChecks, err = transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID)
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
postureChecks, err = transaction.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecksID)
if err != nil {
return err
}
@ -117,11 +119,11 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
return err
}
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.DeletePostureChecks(ctx, LockingStrengthUpdate, accountID, postureChecksID)
return transaction.DeletePostureChecks(ctx, store.LockingStrengthUpdate, accountID, postureChecksID)
})
if err != nil {
return err
@ -134,7 +136,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
// ListPostureChecks returns a list of posture checks.
func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
if err != nil {
return nil, err
}
@ -147,11 +149,11 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI
return nil, status.NewAdminPermissionError()
}
return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthShare, accountID)
}
// getPeerPostureChecks returns the posture checks applied for a given peer.
func (am *DefaultAccountManager) getPeerPostureChecks(account *Account, peerID string) ([]*posture.Checks, error) {
func (am *DefaultAccountManager) getPeerPostureChecks(account *types.Account, peerID string) ([]*posture.Checks, error) {
peerPostureChecks := make(map[string]*posture.Checks)
if len(account.PostureChecks) == 0 {
@ -172,15 +174,15 @@ func (am *DefaultAccountManager) getPeerPostureChecks(account *Account, peerID s
}
// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers.
func arePostureCheckChangesAffectPeers(ctx context.Context, transaction Store, accountID, postureCheckID string) (bool, error) {
policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (bool, error) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return false, err
}
for _, policy := range policies {
if slices.Contains(policy.SourcePostureChecks, postureCheckID) {
hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, policy.ruleGroups())
hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, policy.RuleGroups())
if err != nil {
return false, err
}
@ -195,21 +197,21 @@ func arePostureCheckChangesAffectPeers(ctx context.Context, transaction Store, a
}
// validatePostureChecks validates the posture checks.
func validatePostureChecks(ctx context.Context, transaction Store, accountID string, postureChecks *posture.Checks) error {
func validatePostureChecks(ctx context.Context, transaction store.Store, accountID string, postureChecks *posture.Checks) error {
if err := postureChecks.Validate(); err != nil {
return status.Errorf(status.InvalidArgument, err.Error()) //nolint
}
// If the posture check already has an ID, verify its existence in the store.
if postureChecks.ID != "" {
if _, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecks.ID); err != nil {
if _, err := transaction.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecks.ID); err != nil {
return err
}
return nil
}
// For new posture checks, ensure no duplicates by name.
checks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
checks, err := transaction.GetAccountPostureChecks(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return err
}
@ -226,7 +228,7 @@ func validatePostureChecks(ctx context.Context, transaction Store, accountID str
}
// addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups.
func addPolicyPostureChecks(account *Account, peerID string, policy *Policy, peerPostureChecks map[string]*posture.Checks) error {
func addPolicyPostureChecks(account *types.Account, peerID string, policy *types.Policy, peerPostureChecks map[string]*posture.Checks) error {
isInGroup, err := isPeerInPolicySourceGroups(account, peerID, policy)
if err != nil {
return err
@ -237,7 +239,7 @@ func addPolicyPostureChecks(account *Account, peerID string, policy *Policy, pee
}
for _, sourcePostureCheckID := range policy.SourcePostureChecks {
postureCheck := account.getPostureChecks(sourcePostureCheckID)
postureCheck := account.GetPostureChecks(sourcePostureCheckID)
if postureCheck == nil {
return errors.New("failed to add policy posture checks: posture checks not found")
}
@ -248,7 +250,7 @@ func addPolicyPostureChecks(account *Account, peerID string, policy *Policy, pee
}
// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups.
func isPeerInPolicySourceGroups(account *Account, peerID string, policy *Policy) (bool, error) {
func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *types.Policy) (bool, error) {
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
@ -270,8 +272,8 @@ func isPeerInPolicySourceGroups(account *Account, peerID string, policy *Policy)
}
// isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy.
func isPostureCheckLinkedToPolicy(ctx context.Context, transaction Store, postureChecksID, accountID string) error {
policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
func isPostureCheckLinkedToPolicy(ctx context.Context, transaction store.Store, postureChecksID, accountID string) error {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return err
}

View File

@ -9,6 +9,8 @@ import (
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/posture"
)
@ -92,17 +94,17 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
})
}
func initTestPostureChecksAccount(am *DefaultAccountManager) (*Account, error) {
func initTestPostureChecksAccount(am *DefaultAccountManager) (*types.Account, error) {
accountID := "testingAccount"
domain := "example.com"
admin := &User{
admin := &types.User{
Id: adminUserID,
Role: UserRoleAdmin,
Role: types.UserRoleAdmin,
}
user := &User{
user := &types.User{
Id: regularUserID,
Role: UserRoleUser,
Role: types.UserRoleUser,
}
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain)
@ -209,15 +211,15 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
}
})
policy := &Policy{
policy := &types.Policy{
Enabled: true,
Rules: []*PolicyRule{
Rules: []*types.PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
Action: types.PolicyTrafficActionAccept,
},
},
SourcePostureChecks: []string{postureCheckB.ID},
@ -312,15 +314,15 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
// Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update
t.Run("updating linked posture check to policy with no peers", func(t *testing.T) {
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
Enabled: true,
Rules: []*PolicyRule{
Rules: []*types.PolicyRule{
{
Enabled: true,
Sources: []string{"groupB"},
Destinations: []string{"groupC"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
Action: types.PolicyTrafficActionAccept,
},
},
SourcePostureChecks: []string{postureCheckB.ID},
@ -356,15 +358,15 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID)
})
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
Enabled: true,
Rules: []*PolicyRule{
Rules: []*types.PolicyRule{
{
Enabled: true,
Sources: []string{"groupB"},
Destinations: []string{"groupA"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
Action: types.PolicyTrafficActionAccept,
},
},
SourcePostureChecks: []string{postureCheckB.ID},
@ -395,15 +397,15 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
// Updating linked client posture check to policy where source has peers but destination does not,
// should trigger account peers update and send peer update
t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) {
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{
Enabled: true,
Rules: []*PolicyRule{
Rules: []*types.PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupB"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
Action: types.PolicyTrafficActionAccept,
},
},
SourcePostureChecks: []string{postureCheckB.ID},
@ -454,7 +456,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
AccountID: account.Id,
Peers: []string{},
}
err = manager.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*group.Group{groupA, groupB})
err = manager.Store.SaveGroups(context.Background(), store.LockingStrengthUpdate, []*group.Group{groupA, groupB})
require.NoError(t, err, "failed to save groups")
postureCheckA := &posture.Checks{
@ -477,9 +479,9 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckB)
require.NoError(t, err, "failed to save postureCheckB")
policy := &Policy{
policy := &types.Policy{
AccountID: account.Id,
Rules: []*PolicyRule{
Rules: []*types.PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
@ -534,7 +536,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) {
groupA.Peers = []string{}
err = manager.Store.SaveGroup(context.Background(), LockingStrengthUpdate, groupA)
err = manager.Store.SaveGroup(context.Background(), store.LockingStrengthUpdate, groupA)
require.NoError(t, err, "failed to save groups")
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)

View File

@ -4,15 +4,12 @@ import (
"context"
"fmt"
"net/netip"
"slices"
"strconv"
"strings"
"unicode/utf8"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/proto"
@ -21,33 +18,9 @@ import (
"github.com/netbirdio/netbird/route"
)
// RouteFirewallRule a firewall rule applicable for a routed network.
type RouteFirewallRule struct {
// SourceRanges IP ranges of the routing peers.
SourceRanges []string
// Action of the traffic when the rule is applicable
Action string
// Destination a network prefix for the routed traffic
Destination string
// Protocol of the traffic
Protocol string
// Port of the traffic
Port uint16
// PortRange represents the range of ports for a firewall rule
PortRange RulePortRange
// isDynamic indicates whether the rule is for DNS routing
IsDynamic bool
}
// GetRoute gets a route object from account and route IDs
func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
if err != nil {
return nil, err
}
@ -56,11 +29,11 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string,
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
}
return am.Store.GetRouteByID(ctx, LockingStrengthShare, string(routeID), accountID)
return am.Store.GetRouteByID(ctx, store.LockingStrengthShare, string(routeID), accountID)
}
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error {
func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *types.Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error {
// routes can have both peer and peer_groups
routesWithPrefix := account.GetRoutesByPrefixOrDomains(prefix, domains)
@ -364,7 +337,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
// ListRoutes returns a list of routes from account
func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
if err != nil {
return nil, err
}
@ -373,7 +346,7 @@ func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, user
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
}
return am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
return am.Store.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID)
}
func toProtocolRoute(route *route.Route) *proto.Route {
@ -404,244 +377,7 @@ func getPlaceholderIP() netip.Prefix {
return netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32)
}
// getPeerRoutesFirewallRules gets the routes firewall rules associated with a routing peer ID for the account.
func (a *Account) getPeerRoutesFirewallRules(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule {
routesFirewallRules := make([]*RouteFirewallRule, 0, len(a.Routes))
enabledRoutes, _ := a.getRoutingPeerRoutes(ctx, peerID)
for _, route := range enabledRoutes {
// If no access control groups are specified, accept all traffic.
if len(route.AccessControlGroups) == 0 {
defaultPermit := getDefaultPermit(route)
routesFirewallRules = append(routesFirewallRules, defaultPermit...)
continue
}
distributionPeers := a.getDistributionGroupsPeers(route)
for _, accessGroup := range route.AccessControlGroups {
policies := getAllRoutePoliciesFromGroups(a, []string{accessGroup})
rules := a.getRouteFirewallRules(ctx, peerID, policies, route, validatedPeersMap, distributionPeers)
routesFirewallRules = append(routesFirewallRules, rules...)
}
}
return routesFirewallRules
}
func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, distributionPeers map[string]struct{}) []*RouteFirewallRule {
var fwRules []*RouteFirewallRule
for _, policy := range policies {
if !policy.Enabled {
continue
}
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}
rulePeers := a.getRulePeers(rule, peerID, distributionPeers, validatedPeersMap)
rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, firewallRuleDirectionIN)
fwRules = append(fwRules, rules...)
}
}
return fwRules
}
func (a *Account) getRulePeers(rule *PolicyRule, peerID string, distributionPeers map[string]struct{}, validatedPeersMap map[string]struct{}) []*nbpeer.Peer {
distPeersWithPolicy := make(map[string]struct{})
for _, id := range rule.Sources {
group := a.Groups[id]
if group == nil {
continue
}
for _, pID := range group.Peers {
if pID == peerID {
continue
}
_, distPeer := distributionPeers[pID]
_, valid := validatedPeersMap[pID]
if distPeer && valid {
distPeersWithPolicy[pID] = struct{}{}
}
}
}
distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy))
for pID := range distPeersWithPolicy {
peer := a.Peers[pID]
if peer == nil {
continue
}
distributionGroupPeers = append(distributionGroupPeers, peer)
}
return distributionGroupPeers
}
func (a *Account) getDistributionGroupsPeers(route *route.Route) map[string]struct{} {
distPeers := make(map[string]struct{})
for _, id := range route.Groups {
group := a.Groups[id]
if group == nil {
continue
}
for _, pID := range group.Peers {
distPeers[pID] = struct{}{}
}
}
return distPeers
}
func getDefaultPermit(route *route.Route) []*RouteFirewallRule {
var rules []*RouteFirewallRule
sources := []string{"0.0.0.0/0"}
if route.Network.Addr().Is6() {
sources = []string{"::/0"}
}
rule := RouteFirewallRule{
SourceRanges: sources,
Action: string(PolicyTrafficActionAccept),
Destination: route.Network.String(),
Protocol: string(PolicyRuleProtocolALL),
IsDynamic: route.IsDynamic(),
}
rules = append(rules, &rule)
// dynamic routes always contain an IPv4 placeholder as destination, hence we must add IPv6 rules additionally
if route.IsDynamic() {
ruleV6 := rule
ruleV6.SourceRanges = []string{"::/0"}
rules = append(rules, &ruleV6)
}
return rules
}
// getAllRoutePoliciesFromGroups retrieves route policies associated with the specified access control groups
// and returns a list of policies that have rules with destinations matching the specified groups.
func getAllRoutePoliciesFromGroups(account *Account, accessControlGroups []string) []*Policy {
routePolicies := make([]*Policy, 0)
for _, groupID := range accessControlGroups {
group, ok := account.Groups[groupID]
if !ok {
continue
}
for _, policy := range account.Policies {
for _, rule := range policy.Rules {
exist := slices.ContainsFunc(rule.Destinations, func(groupID string) bool {
return groupID == group.ID
})
if exist {
routePolicies = append(routePolicies, policy)
continue
}
}
}
}
return routePolicies
}
// generateRouteFirewallRules generates a list of firewall rules for a given route.
func generateRouteFirewallRules(ctx context.Context, route *route.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) []*RouteFirewallRule {
rulesExists := make(map[string]struct{})
rules := make([]*RouteFirewallRule, 0)
sourceRanges := make([]string, 0, len(groupPeers))
for _, peer := range groupPeers {
if peer == nil {
continue
}
sourceRanges = append(sourceRanges, fmt.Sprintf(AllowedIPsFormat, peer.IP))
}
baseRule := RouteFirewallRule{
SourceRanges: sourceRanges,
Action: string(rule.Action),
Destination: route.Network.String(),
Protocol: string(rule.Protocol),
IsDynamic: route.IsDynamic(),
}
// generate rule for port range
if len(rule.Ports) == 0 {
rules = append(rules, generateRulesWithPortRanges(baseRule, rule, rulesExists)...)
} else {
rules = append(rules, generateRulesWithPorts(ctx, baseRule, rule, rulesExists)...)
}
// TODO: generate IPv6 rules for dynamic routes
return rules
}
// generateRuleIDBase generates the base rule ID for checking duplicates.
func generateRuleIDBase(rule *PolicyRule, baseRule RouteFirewallRule) string {
return rule.ID + strings.Join(baseRule.SourceRanges, ",") + strconv.Itoa(firewallRuleDirectionIN) + baseRule.Protocol + baseRule.Action
}
// generateRulesForPeer generates rules for a given peer based on ports and port ranges.
func generateRulesWithPortRanges(baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule {
rules := make([]*RouteFirewallRule, 0)
ruleIDBase := generateRuleIDBase(rule, baseRule)
if len(rule.Ports) == 0 {
if len(rule.PortRanges) == 0 {
if _, ok := rulesExists[ruleIDBase]; !ok {
rulesExists[ruleIDBase] = struct{}{}
rules = append(rules, &baseRule)
}
} else {
for _, portRange := range rule.PortRanges {
ruleID := fmt.Sprintf("%s%d-%d", ruleIDBase, portRange.Start, portRange.End)
if _, ok := rulesExists[ruleID]; !ok {
rulesExists[ruleID] = struct{}{}
pr := baseRule
pr.PortRange = portRange
rules = append(rules, &pr)
}
}
}
return rules
}
return rules
}
// generateRulesWithPorts generates rules when specific ports are provided.
func generateRulesWithPorts(ctx context.Context, baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule {
rules := make([]*RouteFirewallRule, 0)
ruleIDBase := generateRuleIDBase(rule, baseRule)
for _, port := range rule.Ports {
ruleID := ruleIDBase + port
if _, ok := rulesExists[ruleID]; ok {
continue
}
rulesExists[ruleID] = struct{}{}
pr := baseRule
p, err := strconv.ParseUint(port, 10, 16)
if err != nil {
log.WithContext(ctx).Errorf("failed to parse port %s for rule: %s", port, rule.ID)
continue
}
pr.Port = uint16(p)
rules = append(rules, &pr)
}
return rules
}
func toProtocolRoutesFirewallRules(rules []*RouteFirewallRule) []*proto.RouteFirewallRule {
func toProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.RouteFirewallRule {
result := make([]*proto.RouteFirewallRule, len(rules))
for i := range rules {
rule := rules[i]
@ -660,7 +396,7 @@ func toProtocolRoutesFirewallRules(rules []*RouteFirewallRule) []*proto.RouteFir
// getProtoDirection converts the direction to proto.RuleDirection.
func getProtoDirection(direction int) proto.RuleDirection {
if direction == firewallRuleDirectionOUT {
if direction == types.FirewallRuleDirectionOUT {
return proto.RuleDirection_OUT
}
return proto.RuleDirection_IN
@ -668,7 +404,7 @@ func getProtoDirection(direction int) proto.RuleDirection {
// getProtoAction converts the action to proto.RuleAction.
func getProtoAction(action string) proto.RuleAction {
if action == string(PolicyTrafficActionDrop) {
if action == string(types.PolicyTrafficActionDrop) {
return proto.RuleAction_DROP
}
return proto.RuleAction_ACCEPT
@ -676,14 +412,14 @@ func getProtoAction(action string) proto.RuleAction {
// getProtoProtocol converts the protocol to proto.RuleProtocol.
func getProtoProtocol(protocol string) proto.RuleProtocol {
switch PolicyRuleProtocolType(protocol) {
case PolicyRuleProtocolALL:
switch types.PolicyRuleProtocolType(protocol) {
case types.PolicyRuleProtocolALL:
return proto.RuleProtocol_ALL
case PolicyRuleProtocolTCP:
case types.PolicyRuleProtocolTCP:
return proto.RuleProtocol_TCP
case PolicyRuleProtocolUDP:
case types.PolicyRuleProtocolUDP:
return proto.RuleProtocol_UDP
case PolicyRuleProtocolICMP:
case types.PolicyRuleProtocolICMP:
return proto.RuleProtocol_ICMP
default:
return proto.RuleProtocol_UNKNOWN
@ -691,7 +427,7 @@ func getProtoProtocol(protocol string) proto.RuleProtocol {
}
// getProtoPortInfo converts the port info to proto.PortInfo.
func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo {
func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo {
var portInfo proto.PortInfo
if rule.Port != 0 {
portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)}
@ -708,6 +444,6 @@ func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo {
// isRouteChangeAffectPeers checks if a given route affects peers by determining
// if it has a routing peer, distribution, or peer groups that include peers
func (am *DefaultAccountManager) isRouteChangeAffectPeers(account *Account, route *route.Route) bool {
func (am *DefaultAccountManager) isRouteChangeAffectPeers(account *types.Account, route *route.Route) bool {
return am.anyGroupHasPeers(account, route.Groups) || am.anyGroupHasPeers(account, route.PeerGroups) || route.Peer != ""
}

View File

@ -17,7 +17,9 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
@ -1092,7 +1094,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
require.NoError(t, err)
assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route")
groups, err := am.Store.GetAccountGroups(context.Background(), LockingStrengthShare, account.Id)
groups, err := am.Store.GetAccountGroups(context.Background(), store.LockingStrengthShare, account.Id)
require.NoError(t, err)
var groupHA1, groupHA2 *nbgroup.Group
for _, group := range groups {
@ -1255,10 +1257,10 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics)
}
func createRouterStore(t *testing.T) (Store, error) {
func createRouterStore(t *testing.T) (store.Store, error) {
t.Helper()
dataDir := t.TempDir()
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir)
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", dataDir)
if err != nil {
return nil, err
}
@ -1267,7 +1269,7 @@ func createRouterStore(t *testing.T) (Store, error) {
return store, nil
}
func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) {
func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*types.Account, error) {
t.Helper()
accountID := "testingAcc"
@ -1279,8 +1281,8 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
return nil, err
}
ips := account.getTakenIPs()
peer1IP, err := AllocatePeerIP(account.Network.Net, ips)
ips := account.GetTakenIPs()
peer1IP, err := types.AllocatePeerIP(account.Network.Net, ips)
if err != nil {
return nil, err
}
@ -1306,8 +1308,8 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
}
account.Peers[peer1.ID] = peer1
ips = account.getTakenIPs()
peer2IP, err := AllocatePeerIP(account.Network.Net, ips)
ips = account.GetTakenIPs()
peer2IP, err := types.AllocatePeerIP(account.Network.Net, ips)
if err != nil {
return nil, err
}
@ -1333,8 +1335,8 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
}
account.Peers[peer2.ID] = peer2
ips = account.getTakenIPs()
peer3IP, err := AllocatePeerIP(account.Network.Net, ips)
ips = account.GetTakenIPs()
peer3IP, err := types.AllocatePeerIP(account.Network.Net, ips)
if err != nil {
return nil, err
}
@ -1360,8 +1362,8 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
}
account.Peers[peer3.ID] = peer3
ips = account.getTakenIPs()
peer4IP, err := AllocatePeerIP(account.Network.Net, ips)
ips = account.GetTakenIPs()
peer4IP, err := types.AllocatePeerIP(account.Network.Net, ips)
if err != nil {
return nil, err
}
@ -1387,8 +1389,8 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
}
account.Peers[peer4.ID] = peer4
ips = account.getTakenIPs()
peer5IP, err := AllocatePeerIP(account.Network.Net, ips)
ips = account.GetTakenIPs()
peer5IP, err := types.AllocatePeerIP(account.Network.Net, ips)
if err != nil {
return nil, err
}
@ -1491,7 +1493,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
peerKIp = "100.65.29.66"
)
account := &Account{
account := &types.Account{
Peers: map[string]*nbpeer.Peer{
"peerA": {
ID: "peerA",
@ -1685,19 +1687,19 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
AccessControlGroups: []string{"route4"},
},
},
Policies: []*Policy{
Policies: []*types.Policy{
{
ID: "RuleRoute1",
Name: "Route1",
Enabled: true,
Rules: []*PolicyRule{
Rules: []*types.PolicyRule{
{
ID: "RuleRoute1",
Name: "ruleRoute1",
Bidirectional: true,
Enabled: true,
Protocol: PolicyRuleProtocolALL,
Action: PolicyTrafficActionAccept,
Protocol: types.PolicyRuleProtocolALL,
Action: types.PolicyTrafficActionAccept,
Ports: []string{"80", "320"},
Sources: []string{
"dev",
@ -1712,15 +1714,15 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
ID: "RuleRoute2",
Name: "Route2",
Enabled: true,
Rules: []*PolicyRule{
Rules: []*types.PolicyRule{
{
ID: "RuleRoute2",
Name: "ruleRoute2",
Bidirectional: true,
Enabled: true,
Protocol: PolicyRuleProtocolTCP,
Action: PolicyTrafficActionAccept,
PortRanges: []RulePortRange{
Protocol: types.PolicyRuleProtocolTCP,
Action: types.PolicyTrafficActionAccept,
PortRanges: []types.RulePortRange{
{
Start: 80,
End: 350,
@ -1742,14 +1744,14 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
ID: "RuleRoute4",
Name: "RuleRoute4",
Enabled: true,
Rules: []*PolicyRule{
Rules: []*types.PolicyRule{
{
ID: "RuleRoute4",
Name: "RuleRoute4",
Bidirectional: true,
Enabled: true,
Protocol: PolicyRuleProtocolTCP,
Action: PolicyTrafficActionAccept,
Protocol: types.PolicyRuleProtocolTCP,
Action: types.PolicyTrafficActionAccept,
Ports: []string{"80"},
Sources: []string{
"restrictQA",
@ -1764,14 +1766,14 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
ID: "RuleRoute5",
Name: "RuleRoute5",
Enabled: true,
Rules: []*PolicyRule{
Rules: []*types.PolicyRule{
{
ID: "RuleRoute5",
Name: "RuleRoute5",
Bidirectional: true,
Enabled: true,
Protocol: PolicyRuleProtocolALL,
Action: PolicyTrafficActionAccept,
Protocol: types.PolicyRuleProtocolALL,
Action: types.PolicyTrafficActionAccept,
Sources: []string{
"unrestrictedQA",
},
@ -1791,28 +1793,28 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
t.Run("check applied policies for the route", func(t *testing.T) {
route1 := account.Routes["route1"]
policies := getAllRoutePoliciesFromGroups(account, route1.AccessControlGroups)
policies := types.GetAllRoutePoliciesFromGroups(account, route1.AccessControlGroups)
assert.Len(t, policies, 1)
route2 := account.Routes["route2"]
policies = getAllRoutePoliciesFromGroups(account, route2.AccessControlGroups)
policies = types.GetAllRoutePoliciesFromGroups(account, route2.AccessControlGroups)
assert.Len(t, policies, 1)
route3 := account.Routes["route3"]
policies = getAllRoutePoliciesFromGroups(account, route3.AccessControlGroups)
policies = types.GetAllRoutePoliciesFromGroups(account, route3.AccessControlGroups)
assert.Len(t, policies, 0)
})
t.Run("check peer routes firewall rules", func(t *testing.T) {
routesFirewallRules := account.getPeerRoutesFirewallRules(context.Background(), "peerA", validatedPeers)
routesFirewallRules := account.GetPeerRoutesFirewallRules(context.Background(), "peerA", validatedPeers)
assert.Len(t, routesFirewallRules, 4)
expectedRoutesFirewallRules := []*RouteFirewallRule{
expectedRoutesFirewallRules := []*types.RouteFirewallRule{
{
SourceRanges: []string{
fmt.Sprintf(AllowedIPsFormat, peerCIp),
fmt.Sprintf(AllowedIPsFormat, peerHIp),
fmt.Sprintf(AllowedIPsFormat, peerBIp),
fmt.Sprintf(types.AllowedIPsFormat, peerCIp),
fmt.Sprintf(types.AllowedIPsFormat, peerHIp),
fmt.Sprintf(types.AllowedIPsFormat, peerBIp),
},
Action: "accept",
Destination: "192.168.0.0/16",
@ -1821,9 +1823,9 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
},
{
SourceRanges: []string{
fmt.Sprintf(AllowedIPsFormat, peerCIp),
fmt.Sprintf(AllowedIPsFormat, peerHIp),
fmt.Sprintf(AllowedIPsFormat, peerBIp),
fmt.Sprintf(types.AllowedIPsFormat, peerCIp),
fmt.Sprintf(types.AllowedIPsFormat, peerHIp),
fmt.Sprintf(types.AllowedIPsFormat, peerBIp),
},
Action: "accept",
Destination: "192.168.0.0/16",
@ -1831,10 +1833,10 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
Port: 320,
},
}
additionalFirewallRule := []*RouteFirewallRule{
additionalFirewallRule := []*types.RouteFirewallRule{
{
SourceRanges: []string{
fmt.Sprintf(AllowedIPsFormat, peerJIp),
fmt.Sprintf(types.AllowedIPsFormat, peerJIp),
},
Action: "accept",
Destination: "192.168.10.0/16",
@ -1843,7 +1845,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
},
{
SourceRanges: []string{
fmt.Sprintf(AllowedIPsFormat, peerKIp),
fmt.Sprintf(types.AllowedIPsFormat, peerKIp),
},
Action: "accept",
Destination: "192.168.10.0/16",
@ -1854,21 +1856,21 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(append(expectedRoutesFirewallRules, additionalFirewallRule...)))
// peerD is also the routing peer for route1, should contain same routes firewall rules as peerA
routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers)
routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers)
assert.Len(t, routesFirewallRules, 2)
assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules))
// peerE is a single routing peer for route 2 and route 3
routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerE", validatedPeers)
routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerE", validatedPeers)
assert.Len(t, routesFirewallRules, 3)
expectedRoutesFirewallRules = []*RouteFirewallRule{
expectedRoutesFirewallRules = []*types.RouteFirewallRule{
{
SourceRanges: []string{"100.65.250.202/32", "100.65.13.186/32"},
Action: "accept",
Destination: existingNetwork.String(),
Protocol: "tcp",
PortRange: RulePortRange{Start: 80, End: 350},
PortRange: types.RulePortRange{Start: 80, End: 350},
},
{
SourceRanges: []string{"0.0.0.0/0"},
@ -1888,14 +1890,14 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules))
// peerC is part of route1 distribution groups but should not receive the routes firewall rules
routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers)
routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers)
assert.Len(t, routesFirewallRules, 0)
})
}
// orderList is a helper function to sort a list of strings
func orderRuleSourceRanges(ruleList []*RouteFirewallRule) []*RouteFirewallRule {
func orderRuleSourceRanges(ruleList []*types.RouteFirewallRule) []*types.RouteFirewallRule {
for _, rule := range ruleList {
sort.Strings(rule.SourceRanges)
}

View File

@ -2,34 +2,16 @@ package server
import (
"context"
"crypto/sha256"
b64 "encoding/base64"
"hash/fnv"
"slices"
"strconv"
"strings"
"time"
"unicode/utf8"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/status"
)
const (
// SetupKeyReusable is a multi-use key (can be used for multiple machines)
SetupKeyReusable SetupKeyType = "reusable"
// SetupKeyOneOff is a single use key (can be used only once)
SetupKeyOneOff SetupKeyType = "one-off"
// DefaultSetupKeyDuration = 1 month
DefaultSetupKeyDuration = 24 * 30 * time.Hour
// DefaultSetupKeyName is a default name of the default setup key
DefaultSetupKeyName = "Default key"
// SetupKeyUnlimitedUsage indicates an unlimited usage of a setup key
SetupKeyUnlimitedUsage = 0
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
)
const (
@ -67,169 +49,14 @@ type SetupKeyUpdateOperation struct {
Values []string
}
// SetupKeyType is the type of setup key
type SetupKeyType string
// SetupKey represents a pre-authorized key used to register machines (peers)
type SetupKey struct {
Id string
// AccountID is a reference to Account that this object belongs
AccountID string `json:"-" gorm:"index"`
Key string
KeySecret string
Name string
Type SetupKeyType
CreatedAt time.Time
ExpiresAt time.Time
UpdatedAt time.Time `gorm:"autoUpdateTime:false"`
// Revoked indicates whether the key was revoked or not (we don't remove them for tracking purposes)
Revoked bool
// UsedTimes indicates how many times the key was used
UsedTimes int
// LastUsed last time the key was used for peer registration
LastUsed time.Time
// AutoGroups is a list of Group IDs that are auto assigned to a Peer when it uses this key to register
AutoGroups []string `gorm:"serializer:json"`
// UsageLimit indicates the number of times this key can be used to enroll a machine.
// The value of 0 indicates the unlimited usage.
UsageLimit int
// Ephemeral indicate if the peers will be ephemeral or not
Ephemeral bool
}
// Copy copies SetupKey to a new object
func (key *SetupKey) Copy() *SetupKey {
autoGroups := make([]string, len(key.AutoGroups))
copy(autoGroups, key.AutoGroups)
if key.UpdatedAt.IsZero() {
key.UpdatedAt = key.CreatedAt
}
return &SetupKey{
Id: key.Id,
AccountID: key.AccountID,
Key: key.Key,
KeySecret: key.KeySecret,
Name: key.Name,
Type: key.Type,
CreatedAt: key.CreatedAt,
ExpiresAt: key.ExpiresAt,
UpdatedAt: key.UpdatedAt,
Revoked: key.Revoked,
UsedTimes: key.UsedTimes,
LastUsed: key.LastUsed,
AutoGroups: autoGroups,
UsageLimit: key.UsageLimit,
Ephemeral: key.Ephemeral,
}
}
// EventMeta returns activity event meta related to the setup key
func (key *SetupKey) EventMeta() map[string]any {
return map[string]any{"name": key.Name, "type": key.Type, "key": key.KeySecret}
}
// hiddenKey returns the Key value hidden with "*" and a 5 character prefix.
// E.g., "831F6*******************************"
func hiddenKey(key string, length int) string {
prefix := key[0:5]
if length > utf8.RuneCountInString(key) {
length = utf8.RuneCountInString(key) - len(prefix)
}
return prefix + strings.Repeat("*", length)
}
// IncrementUsage makes a copy of a key, increments the UsedTimes by 1 and sets LastUsed to now
func (key *SetupKey) IncrementUsage() *SetupKey {
c := key.Copy()
c.UsedTimes++
c.LastUsed = time.Now().UTC()
return c
}
// IsValid is true if the key was not revoked, is not expired and used not more than it was supposed to
func (key *SetupKey) IsValid() bool {
return !key.IsRevoked() && !key.IsExpired() && !key.IsOverUsed()
}
// IsRevoked if key was revoked
func (key *SetupKey) IsRevoked() bool {
return key.Revoked
}
// IsExpired if key was expired
func (key *SetupKey) IsExpired() bool {
if key.ExpiresAt.IsZero() {
return false
}
return time.Now().After(key.ExpiresAt)
}
// IsOverUsed if the key was used too many times. SetupKey.UsageLimit == 0 indicates the unlimited usage.
func (key *SetupKey) IsOverUsed() bool {
limit := key.UsageLimit
if key.Type == SetupKeyOneOff {
limit = 1
}
return limit > 0 && key.UsedTimes >= limit
}
// GenerateSetupKey generates a new setup key
func GenerateSetupKey(name string, t SetupKeyType, validFor time.Duration, autoGroups []string,
usageLimit int, ephemeral bool) (*SetupKey, string) {
key := strings.ToUpper(uuid.New().String())
limit := usageLimit
if t == SetupKeyOneOff {
limit = 1
}
expiresAt := time.Time{}
if validFor != 0 {
expiresAt = time.Now().UTC().Add(validFor)
}
hashedKey := sha256.Sum256([]byte(key))
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
return &SetupKey{
Id: strconv.Itoa(int(Hash(key))),
Key: encodedHashedKey,
KeySecret: hiddenKey(key, 4),
Name: name,
Type: t,
CreatedAt: time.Now().UTC(),
ExpiresAt: expiresAt,
UpdatedAt: time.Now().UTC(),
Revoked: false,
UsedTimes: 0,
AutoGroups: autoGroups,
UsageLimit: limit,
Ephemeral: ephemeral,
}, key
}
// GenerateDefaultSetupKey generates a default reusable setup key with an unlimited usage and 30 days expiration
func GenerateDefaultSetupKey() (*SetupKey, string) {
return GenerateSetupKey(DefaultSetupKeyName, SetupKeyReusable, DefaultSetupKeyDuration, []string{},
SetupKeyUnlimitedUsage, false)
}
func Hash(s string) uint32 {
h := fnv.New32a()
_, err := h.Write([]byte(s))
if err != nil {
panic(err)
}
return h.Sum32()
}
// CreateSetupKey generates a new setup key with a given name, type, list of groups IDs to auto-assign to peers registered with this key,
// and adds it to the specified account. A list of autoGroups IDs can be empty.
func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType,
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) {
func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType types.SetupKeyType,
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*types.SetupKey, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
if err != nil {
return nil, err
}
@ -242,22 +69,22 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
return nil, status.NewAdminPermissionError()
}
var setupKey *SetupKey
var setupKey *types.SetupKey
var plainKey string
var eventsToStore []func()
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, autoGroups); err != nil {
return err
}
setupKey, plainKey = GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral)
setupKey, plainKey = types.GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral)
setupKey.AccountID = accountID
events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, autoGroups, nil, setupKey)
eventsToStore = append(eventsToStore, events...)
return transaction.SaveSetupKey(ctx, LockingStrengthUpdate, setupKey)
return transaction.SaveSetupKey(ctx, store.LockingStrengthUpdate, setupKey)
})
if err != nil {
return nil, err
@ -278,7 +105,7 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
// Due to the unique nature of a SetupKey certain properties must not be overwritten
// (e.g. the key itself, creation date, ID, etc).
// These properties are overwritten: AutoGroups, Revoked (only from false to true), and the UpdatedAt. The rest is copied from the existing key.
func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) {
func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *types.SetupKey, userID string) (*types.SetupKey, error) {
if keyToSave == nil {
return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil")
}
@ -286,7 +113,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
if err != nil {
return nil, err
}
@ -299,16 +126,16 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
return nil, status.NewAdminPermissionError()
}
var oldKey *SetupKey
var newKey *SetupKey
var oldKey *types.SetupKey
var newKey *types.SetupKey
var eventsToStore []func()
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, keyToSave.AutoGroups); err != nil {
return err
}
oldKey, err = transaction.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyToSave.Id)
oldKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyToSave.Id)
if err != nil {
return err
}
@ -323,13 +150,13 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
newKey.Revoked = keyToSave.Revoked
newKey.UpdatedAt = time.Now().UTC()
addedGroups := difference(newKey.AutoGroups, oldKey.AutoGroups)
removedGroups := difference(oldKey.AutoGroups, newKey.AutoGroups)
addedGroups := util.Difference(newKey.AutoGroups, oldKey.AutoGroups)
removedGroups := util.Difference(oldKey.AutoGroups, newKey.AutoGroups)
events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups, oldKey)
eventsToStore = append(eventsToStore, events...)
return transaction.SaveSetupKey(ctx, LockingStrengthUpdate, newKey)
return transaction.SaveSetupKey(ctx, store.LockingStrengthUpdate, newKey)
})
if err != nil {
return nil, err
@ -347,8 +174,8 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
}
// ListSetupKeys returns a list of all setup keys of the account
func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error) {
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
if err != nil {
return nil, err
}
@ -361,12 +188,12 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
return nil, status.NewAdminPermissionError()
}
return am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
return am.Store.GetAccountSetupKeys(ctx, store.LockingStrengthShare, accountID)
}
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) {
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
if err != nil {
return nil, err
}
@ -379,7 +206,7 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
return nil, status.NewAdminPermissionError()
}
setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyID)
setupKey, err := am.Store.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyID)
if err != nil {
return nil, err
}
@ -394,7 +221,7 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
// DeleteSetupKey removes the setup key from the account
func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
if err != nil {
return err
}
@ -407,15 +234,15 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID,
return status.NewAdminPermissionError()
}
var deletedSetupKey *SetupKey
var deletedSetupKey *types.SetupKey
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyID)
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyID)
if err != nil {
return err
}
return transaction.DeleteSetupKey(ctx, LockingStrengthUpdate, accountID, keyID)
return transaction.DeleteSetupKey(ctx, store.LockingStrengthUpdate, accountID, keyID)
})
if err != nil {
return err
@ -426,8 +253,8 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID,
return nil
}
func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountID string, autoGroupIDs []string) error {
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, autoGroupIDs)
func validateSetupKeyAutoGroups(ctx context.Context, transaction store.Store, accountID string, autoGroupIDs []string) error {
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, autoGroupIDs)
if err != nil {
return err
}
@ -447,11 +274,11 @@ func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountI
}
// prepareSetupKeyEvents prepares a list of event functions to be stored.
func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string, key *SetupKey) []func() {
func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, transaction store.Store, accountID, userID string, addedGroups, removedGroups []string, key *types.SetupKey) []func() {
var eventsToStore []func()
modifiedGroups := slices.Concat(addedGroups, removedGroups)
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups)
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, modifiedGroups)
if err != nil {
log.WithContext(ctx).Debugf("failed to get groups for setup key events: %v", err)
return nil

View File

@ -16,6 +16,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/types"
)
func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
@ -49,15 +50,15 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
expiresIn := time.Hour
keyName := "my-test-key"
key, err := manager.CreateSetupKey(context.Background(), account.Id, keyName, SetupKeyReusable, expiresIn, []string{},
SetupKeyUnlimitedUsage, userID, false)
key, err := manager.CreateSetupKey(context.Background(), account.Id, keyName, types.SetupKeyReusable, expiresIn, []string{},
types.SetupKeyUnlimitedUsage, userID, false)
if err != nil {
t.Fatal(err)
}
autoGroups := []string{"group_1", "group_2"}
revoked := true
newKey, err := manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{
newKey, err := manager.SaveSetupKey(context.Background(), account.Id, &types.SetupKey{
Id: key.Id,
Revoked: revoked,
AutoGroups: autoGroups,
@ -85,7 +86,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
// saving setup key with All group assigned to auto groups should return error
autoGroups = append(autoGroups, groupAll.ID)
_, err = manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{
_, err = manager.SaveSetupKey(context.Background(), account.Id, &types.SetupKey{
Id: key.Id,
Revoked: revoked,
AutoGroups: autoGroups,
@ -167,8 +168,8 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
for _, tCase := range []testCase{testCase1, testCase2, testCase3} {
t.Run(tCase.name, func(t *testing.T) {
key, err := manager.CreateSetupKey(context.Background(), account.Id, tCase.expectedKeyName, SetupKeyReusable, expiresIn,
tCase.expectedGroups, SetupKeyUnlimitedUsage, userID, false)
key, err := manager.CreateSetupKey(context.Background(), account.Id, tCase.expectedKeyName, types.SetupKeyReusable, expiresIn,
tCase.expectedGroups, types.SetupKeyUnlimitedUsage, userID, false)
if tCase.expectedFailure {
if err == nil {
@ -182,7 +183,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
}
assertKey(t, key, tCase.expectedKeyName, false, tCase.expectedType, tCase.expectedUsedTimes,
tCase.expectedCreatedAt, tCase.expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))),
tCase.expectedCreatedAt, tCase.expectedExpiresAt, strconv.Itoa(int(types.Hash(key.Key))),
tCase.expectedUpdatedAt, tCase.expectedGroups, false)
// check the corresponding events that should have been generated
@ -210,7 +211,7 @@ func TestGetSetupKeys(t *testing.T) {
t.Fatal(err)
}
plainKey, err := manager.CreateSetupKey(context.Background(), account.Id, "key1", SetupKeyReusable, time.Hour, nil, SetupKeyUnlimitedUsage, userID, false)
plainKey, err := manager.CreateSetupKey(context.Background(), account.Id, "key1", types.SetupKeyReusable, time.Hour, nil, types.SetupKeyUnlimitedUsage, userID, false)
if err != nil {
t.Fatal(err)
}
@ -258,10 +259,10 @@ func TestGenerateDefaultSetupKey(t *testing.T) {
expectedExpiresAt := time.Now().UTC().Add(24 * 30 * time.Hour)
var expectedAutoGroups []string
key, plainKey := GenerateDefaultSetupKey()
key, plainKey := types.GenerateDefaultSetupKey()
assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt,
expectedExpiresAt, strconv.Itoa(int(Hash(plainKey))), expectedUpdatedAt, expectedAutoGroups, true)
expectedExpiresAt, strconv.Itoa(int(types.Hash(plainKey))), expectedUpdatedAt, expectedAutoGroups, true)
}
@ -275,48 +276,48 @@ func TestGenerateSetupKey(t *testing.T) {
expectedUpdatedAt := time.Now().UTC()
var expectedAutoGroups []string
key, plain := GenerateSetupKey(expectedName, SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false)
key, plain := types.GenerateSetupKey(expectedName, types.SetupKeyOneOff, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false)
assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt,
expectedExpiresAt, strconv.Itoa(int(Hash(plain))), expectedUpdatedAt, expectedAutoGroups, true)
expectedExpiresAt, strconv.Itoa(int(types.Hash(plain))), expectedUpdatedAt, expectedAutoGroups, true)
}
func TestSetupKey_IsValid(t *testing.T) {
validKey, _ := GenerateSetupKey("valid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false)
validKey, _ := types.GenerateSetupKey("valid key", types.SetupKeyOneOff, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false)
if !validKey.IsValid() {
t.Errorf("expected key to be valid, got invalid %v", validKey)
}
// expired
expiredKey, _ := GenerateSetupKey("invalid key", SetupKeyOneOff, -time.Hour, []string{}, SetupKeyUnlimitedUsage, false)
expiredKey, _ := types.GenerateSetupKey("invalid key", types.SetupKeyOneOff, -time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false)
if expiredKey.IsValid() {
t.Errorf("expected key to be invalid due to expiration, got valid %v", expiredKey)
}
// revoked
revokedKey, _ := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false)
revokedKey, _ := types.GenerateSetupKey("invalid key", types.SetupKeyOneOff, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false)
revokedKey.Revoked = true
if revokedKey.IsValid() {
t.Errorf("expected revoked key to be invalid, got valid %v", revokedKey)
}
// overused
overUsedKey, _ := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false)
overUsedKey, _ := types.GenerateSetupKey("invalid key", types.SetupKeyOneOff, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false)
overUsedKey.UsedTimes = 1
if overUsedKey.IsValid() {
t.Errorf("expected overused key to be invalid, got valid %v", overUsedKey)
}
// overused
reusableKey, _ := GenerateSetupKey("valid key", SetupKeyReusable, time.Hour, []string{}, SetupKeyUnlimitedUsage, false)
reusableKey, _ := types.GenerateSetupKey("valid key", types.SetupKeyReusable, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false)
reusableKey.UsedTimes = 99
if !reusableKey.IsValid() {
t.Errorf("expected reusable key to be valid when used many times, got valid %v", reusableKey)
}
}
func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke bool, expectedType string,
func assertKey(t *testing.T, key *types.SetupKey, expectedName string, expectedRevoke bool, expectedType string,
expectedUsedTimes int, expectedCreatedAt time.Time, expectedExpiresAt time.Time, expectedID string,
expectedUpdatedAt time.Time, expectedAutoGroups []string, expectHashedKey bool) {
t.Helper()
@ -388,7 +389,7 @@ func isValidBase64SHA256(encodedKey string) bool {
func TestSetupKey_Copy(t *testing.T) {
key, _ := GenerateSetupKey("key name", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false)
key, _ := types.GenerateSetupKey("key name", types.SetupKeyOneOff, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false)
keyCopy := key.Copy()
assertKey(t, keyCopy, key.Name, key.Revoked, string(key.Type), key.UsedTimes, key.CreatedAt, key.ExpiresAt, key.Id,
@ -406,15 +407,15 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
})
assert.NoError(t, err)
policy := &Policy{
policy := &types.Policy{
Enabled: true,
Rules: []*PolicyRule{
Rules: []*types.PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"group"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
Action: types.PolicyTrafficActionAccept,
},
},
}
@ -426,7 +427,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
})
var setupKey *SetupKey
var setupKey *types.SetupKey
// Creating setup key should not update account peers and not send peer update
t.Run("creating setup key", func(t *testing.T) {
@ -436,7 +437,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
close(done)
}()
setupKey, err = manager.CreateSetupKey(context.Background(), account.Id, "key1", SetupKeyReusable, time.Hour, nil, 999, userID, false)
setupKey, err = manager.CreateSetupKey(context.Background(), account.Id, "key1", types.SetupKeyReusable, time.Hour, nil, 999, userID, false)
assert.NoError(t, err)
select {
@ -477,7 +478,7 @@ func TestDefaultAccountManager_CreateSetupKey_ShouldNotAllowToUpdateRevokedKey(t
t.Fatal(err)
}
key, err := manager.CreateSetupKey(context.Background(), account.Id, "testName", SetupKeyReusable, time.Hour, nil, SetupKeyUnlimitedUsage, userID, false)
key, err := manager.CreateSetupKey(context.Background(), account.Id, "testName", types.SetupKeyReusable, time.Hour, nil, types.SetupKeyUnlimitedUsage, userID, false)
assert.NoError(t, err)
// revoke the key

View File

@ -1,4 +1,4 @@
package server
package store
import (
"context"
@ -14,6 +14,7 @@ import (
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
)
@ -22,7 +23,7 @@ const storeFileName = "store.json"
// FileStore represents an account storage backed by a file persisted to disk
type FileStore struct {
Accounts map[string]*Account
Accounts map[string]*types.Account
SetupKeyID2AccountID map[string]string `json:"-"`
PeerKeyID2AccountID map[string]string `json:"-"`
PeerID2AccountID map[string]string `json:"-"`
@ -55,7 +56,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
if _, err := os.Stat(file); os.IsNotExist(err) {
// create a new FileStore if previously didn't exist (e.g. first run)
s := &FileStore{
Accounts: make(map[string]*Account),
Accounts: make(map[string]*types.Account),
mux: sync.Mutex{},
SetupKeyID2AccountID: make(map[string]string),
PeerKeyID2AccountID: make(map[string]string),
@ -92,12 +93,12 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
for accountID, account := range store.Accounts {
if account.Settings == nil {
account.Settings = &Settings{
account.Settings = &types.Settings{
PeerLoginExpirationEnabled: false,
PeerLoginExpiration: DefaultPeerLoginExpiration,
PeerLoginExpiration: types.DefaultPeerLoginExpiration,
PeerInactivityExpirationEnabled: false,
PeerInactivityExpiration: DefaultPeerInactivityExpiration,
PeerInactivityExpiration: types.DefaultPeerInactivityExpiration,
}
}
@ -112,7 +113,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
for _, user := range account.Users {
store.UserID2AccountID[user.Id] = accountID
if user.Issued == "" {
user.Issued = UserIssuedAPI
user.Issued = types.UserIssuedAPI
account.Users[user.Id] = user
}
@ -122,7 +123,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
}
}
if account.Domain != "" && account.DomainCategory == PrivateCategory &&
if account.Domain != "" && account.DomainCategory == types.PrivateCategory &&
account.IsDomainPrimaryAccount {
store.PrivateDomain2AccountID[account.Domain] = accountID
}
@ -134,13 +135,13 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
policy.UpgradeAndFix()
}
if account.Policies == nil {
account.Policies = make([]*Policy, 0)
account.Policies = make([]*types.Policy, 0)
}
// for data migration. Can be removed once most base will be with labels
existingLabels := account.getPeerDNSLabels()
existingLabels := account.GetPeerDNSLabels()
if len(existingLabels) != len(account.Peers) {
addPeerLabelsToAccount(ctx, account, existingLabels)
types.AddPeerLabelsToAccount(ctx, account, existingLabels)
}
// TODO: delete this block after migration
@ -236,7 +237,7 @@ func (s *FileStore) persist(ctx context.Context, file string) error {
}
// GetAllAccounts returns all accounts
func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) {
func (s *FileStore) GetAllAccounts(_ context.Context) (all []*types.Account) {
s.mux.Lock()
defer s.mux.Unlock()
for _, a := range s.Accounts {
@ -257,6 +258,6 @@ func (s *FileStore) Close(ctx context.Context) error {
}
// GetStoreEngine returns FileStoreEngine
func (s *FileStore) GetStoreEngine() StoreEngine {
func (s *FileStore) GetStoreEngine() Engine {
return FileStoreEngine
}

View File

@ -1,4 +1,4 @@
package server
package store
import (
"context"
@ -15,7 +15,6 @@ import (
"sync"
"time"
"github.com/netbirdio/netbird/management/server/networks"
log "github.com/sirupsen/logrus"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
@ -26,10 +25,14 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/account"
nbgroup "github.com/netbirdio/netbird/management/server/group"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
@ -50,7 +53,7 @@ type SqlStore struct {
globalAccountLock sync.Mutex
metrics telemetry.AppMetrics
installationPK int
storeEngine StoreEngine
storeEngine Engine
}
type installation struct {
@ -61,7 +64,7 @@ type installation struct {
type migrationFunc func(*gorm.DB) error
// NewSqlStore creates a new SqlStore instance.
func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine StoreEngine, metrics telemetry.AppMetrics) (*SqlStore, error) {
func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine Engine, metrics telemetry.AppMetrics) (*SqlStore, error) {
sql, err := db.DB()
if err != nil {
return nil, err
@ -87,10 +90,10 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine StoreEngine, metr
return nil, fmt.Errorf("migrate: %w", err)
}
err = db.AutoMigrate(
&SetupKey{}, &nbpeer.Peer{}, &User{}, &PersonalAccessToken{}, &nbgroup.Group{},
&Account{}, &Policy{}, &PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &nbgroup.Group{},
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
&installation{}, &account.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
&networks.Network{}, &networks.NetworkRouter{}, &networks.NetworkResource{},
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{},
)
if err != nil {
return nil, fmt.Errorf("auto migrate: %w", err)
@ -153,7 +156,7 @@ func (s *SqlStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (u
return unlock
}
func (s *SqlStore) SaveAccount(ctx context.Context, account *Account) error {
func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) error {
start := time.Now()
defer func() {
elapsed := time.Since(start)
@ -203,7 +206,7 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *Account) error {
}
// generateAccountSQLTypes generates the GORM compatible types for the account
func generateAccountSQLTypes(account *Account) {
func generateAccountSQLTypes(account *types.Account) {
for _, key := range account.SetupKeys {
account.SetupKeysG = append(account.SetupKeysG, *key)
}
@ -240,7 +243,7 @@ func generateAccountSQLTypes(account *Account) {
// checkAccountDomainBeforeSave temporary method to troubleshoot an issue with domains getting blank
func (s *SqlStore) checkAccountDomainBeforeSave(ctx context.Context, accountID, newDomain string) {
var acc Account
var acc types.Account
var domain string
result := s.db.Model(&acc).Select("domain").Where(idQueryCondition, accountID).First(&domain)
if result.Error != nil {
@ -254,7 +257,7 @@ func (s *SqlStore) checkAccountDomainBeforeSave(ctx context.Context, accountID,
}
}
func (s *SqlStore) DeleteAccount(ctx context.Context, account *Account) error {
func (s *SqlStore) DeleteAccount(ctx context.Context, account *types.Account) error {
start := time.Now()
err := s.db.Transaction(func(tx *gorm.DB) error {
@ -335,14 +338,14 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.
}
func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error {
accountCopy := Account{
accountCopy := types.Account{
Domain: domain,
DomainCategory: category,
IsDomainPrimaryAccount: isPrimaryDomain,
}
fieldsToUpdate := []string{"domain", "domain_category", "is_domain_primary_account"}
result := s.db.Model(&Account{}).
result := s.db.Model(&types.Account{}).
Select(fieldsToUpdate).
Where(idQueryCondition, accountID).
Updates(&accountCopy)
@ -404,8 +407,8 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P
// SaveUsers saves the given list of users to the database.
// It updates existing users if a conflict occurs.
func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error {
usersToSave := make([]User, 0, len(users))
func (s *SqlStore) SaveUsers(accountID string, users map[string]*types.User) error {
usersToSave := make([]types.User, 0, len(users))
for _, user := range users {
user.AccountID = accountID
for id, pat := range user.PATs {
@ -425,7 +428,7 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error {
}
// SaveUser saves the given user to the database.
func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error {
func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *types.User) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user)
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save user to store: %v", result.Error)
@ -456,7 +459,7 @@ func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error {
return nil
}
func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) {
func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*types.Account, error) {
accountID, err := s.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain)
if err != nil {
return nil, err
@ -468,9 +471,9 @@ func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string)
func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) {
var accountID string
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("id").
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).Select("id").
Where("domain = ? and is_domain_primary_account = ? and domain_category = ?",
strings.ToLower(domain), true, PrivateCategory,
strings.ToLower(domain), true, types.PrivateCategory,
).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@ -483,8 +486,8 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength
return accountID, nil
}
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) {
var key SetupKey
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*types.Account, error) {
var key types.SetupKey
result := s.db.Select("account_id").First(&key, keyQueryCondition, setupKey)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@ -502,7 +505,7 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*
}
func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken string) (string, error) {
var token PersonalAccessToken
var token types.PersonalAccessToken
result := s.db.First(&token, "hashed_token = ?", hashedToken)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@ -515,8 +518,8 @@ func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken stri
return token.ID, nil
}
func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) {
var token PersonalAccessToken
func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*types.User, error) {
var token types.PersonalAccessToken
result := s.db.First(&token, idQueryCondition, tokenID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@ -530,13 +533,13 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
var user User
var user types.User
result = s.db.Preload("PATsG").First(&user, idQueryCondition, token.UserID)
if result.Error != nil {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
user.PATs = make(map[string]*PersonalAccessToken, len(user.PATsG))
user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATsG))
for _, pat := range user.PATsG {
user.PATs[pat.ID] = pat.Copy()
}
@ -544,8 +547,8 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
return &user, nil
}
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) {
var user User
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) {
var user types.User
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Preload(clause.Associations).First(&user, idQueryCondition, userID)
if result.Error != nil {
@ -558,8 +561,8 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
return &user, nil
}
func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) {
var users []*User
func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error) {
var users []*types.User
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&users, accountIDCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@ -586,8 +589,8 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr
return groups, nil
}
func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*Account) {
var accounts []Account
func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*types.Account) {
var accounts []types.Account
result := s.db.Find(&accounts)
if result.Error != nil {
return all
@ -602,7 +605,7 @@ func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*Account) {
return all
}
func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, error) {
func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Account, error) {
start := time.Now()
defer func() {
elapsed := time.Since(start)
@ -611,7 +614,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
}
}()
var account Account
var account types.Account
result := s.db.Model(&account).
Preload("UsersG.PATsG"). // have to be specifies as this is nester reference
Preload(clause.Associations).
@ -626,15 +629,15 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
// we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us
for i, policy := range account.Policies {
var rules []*PolicyRule
err := s.db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
var rules []*types.PolicyRule
err := s.db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
if err != nil {
return nil, status.Errorf(status.NotFound, "rule not found")
}
account.Policies[i].Rules = rules
}
account.SetupKeys = make(map[string]*SetupKey, len(account.SetupKeysG))
account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG))
for _, key := range account.SetupKeysG {
account.SetupKeys[key.Key] = key.Copy()
}
@ -646,9 +649,9 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
}
account.PeersG = nil
account.Users = make(map[string]*User, len(account.UsersG))
account.Users = make(map[string]*types.User, len(account.UsersG))
for _, user := range account.UsersG {
user.PATs = make(map[string]*PersonalAccessToken, len(user.PATs))
user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATs))
for _, pat := range user.PATsG {
user.PATs[pat.ID] = pat.Copy()
}
@ -677,8 +680,8 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
return &account, nil
}
func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Account, error) {
var user User
func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*types.Account, error) {
var user types.User
result := s.db.Select("account_id").First(&user, idQueryCondition, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@ -694,7 +697,7 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Accoun
return s.GetAccount(ctx, user.AccountID)
}
func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) {
func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) {
var peer nbpeer.Peer
result := s.db.Select("account_id").First(&peer, idQueryCondition, peerID)
if result.Error != nil {
@ -711,7 +714,7 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acco
return s.GetAccount(ctx, peer.AccountID)
}
func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) {
func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*types.Account, error) {
var peer nbpeer.Peer
result := s.db.Select("account_id").First(&peer, keyQueryCondition, peerKey)
if result.Error != nil {
@ -744,7 +747,7 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
var accountID string
result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID)
result := s.db.Model(&types.User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
@ -757,7 +760,7 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) {
var accountID string
result := s.db.Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID)
result := s.db.Model(&types.SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.NewSetupKeyNotFoundError(setupKey)
@ -817,9 +820,9 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
return labels, nil
}
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) {
var accountNetwork AccountNetwork
if err := s.db.Model(&Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil {
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) {
var accountNetwork types.AccountNetwork
if err := s.db.Model(&types.Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID)
}
@ -841,9 +844,9 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking
return &peer, nil
}
func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) {
var accountSettings AccountSettings
if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Settings, error) {
var accountSettings types.AccountSettings
if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "settings not found")
}
@ -854,7 +857,7 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS
// SaveUserLastLogin stores the last login time for a user in DB.
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
var user User
var user types.User
result := s.db.First(&user, accountAndIDQueryCondition, accountID, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@ -892,7 +895,7 @@ func (s *SqlStore) Close(_ context.Context) error {
}
// GetStoreEngine returns underlying store engine
func (s *SqlStore) GetStoreEngine() StoreEngine {
func (s *SqlStore) GetStoreEngine() Engine {
return s.storeEngine
}
@ -984,8 +987,8 @@ func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore,
return store, nil
}
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) {
var setupKey SetupKey
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) {
var setupKey types.SetupKey
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&setupKey, keyQueryCondition, key)
if result.Error != nil {
@ -999,7 +1002,7 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking
}
func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
result := s.db.Model(&SetupKey{}).
result := s.db.Model(&types.SetupKey{}).
Where(idQueryCondition, setupKeyID).
Updates(map[string]interface{}{
"used_times": gorm.Expr("used_times + 1"),
@ -1116,7 +1119,7 @@ func (s *SqlStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStreng
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error)
return status.Errorf(status.Internal, "failed to increment network serial count in store")
@ -1158,9 +1161,9 @@ func (s *SqlStore) GetDB() *gorm.DB {
return s.db
}
func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) {
var accountDNSSettings AccountDNSSettings
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.DNSSettings, error) {
var accountDNSSettings types.AccountDNSSettings
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).
First(&accountDNSSettings, idQueryCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@ -1175,7 +1178,7 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki
// AccountExists checks whether an account exists by the given ID.
func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) {
var accountID string
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).
Select("id").First(&accountID, idQueryCondition, id)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@ -1189,8 +1192,8 @@ func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStreng
// GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID.
func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) {
var account Account
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("domain", "domain_category").
var account types.Account
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).Select("domain", "domain_category").
Where(idQueryCondition, accountID).First(&account)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@ -1297,8 +1300,8 @@ func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, a
}
// GetAccountPolicies retrieves policies for an account.
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
var policies []*Policy
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Policy, error) {
var policies []*types.Policy
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Preload(clause.Associations).Find(&policies, accountIDCondition, accountID)
if err := result.Error; err != nil {
@ -1310,8 +1313,8 @@ func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingS
}
// GetPolicyByID retrieves a policy by its ID and account ID.
func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error) {
var policy *Policy
func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*types.Policy, error) {
var policy *types.Policy
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations).
First(&policy, accountAndIDQueryCondition, accountID, policyID)
if err := result.Error; err != nil {
@ -1325,7 +1328,7 @@ func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStreng
return policy, nil
}
func (s *SqlStore) CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error {
func (s *SqlStore) CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *types.Policy) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Create(policy)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to create policy in store: %s", result.Error)
@ -1336,7 +1339,7 @@ func (s *SqlStore) CreatePolicy(ctx context.Context, lockStrength LockingStrengt
}
// SavePolicy saves a policy to the database.
func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error {
func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *types.Policy) error {
result := s.db.Session(&gorm.Session{FullSaveAssociations: true}).
Clauses(clause.Locking{Strength: string(lockStrength)}).Save(policy)
if err := result.Error; err != nil {
@ -1348,7 +1351,7 @@ func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength,
func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&Policy{}, accountAndIDQueryCondition, accountID, policyID)
Delete(&types.Policy{}, accountAndIDQueryCondition, accountID, policyID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete policy from store: %s", err)
return status.Errorf(status.Internal, "failed to delete policy from store")
@ -1444,8 +1447,8 @@ func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrengt
}
// GetAccountSetupKeys retrieves setup keys for an account.
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) {
var setupKeys []*SetupKey
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.SetupKey, error) {
var setupKeys []*types.SetupKey
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Find(&setupKeys, accountIDCondition, accountID)
if err := result.Error; err != nil {
@ -1457,8 +1460,8 @@ func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength Locking
}
// GetSetupKeyByID retrieves a setup key by its ID and account ID.
func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error) {
var setupKey *SetupKey
func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*types.SetupKey, error) {
var setupKey *types.SetupKey
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&setupKey, accountAndIDQueryCondition, accountID, setupKeyID)
if err := result.Error; err != nil {
@ -1473,7 +1476,7 @@ func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStre
}
// SaveSetupKey saves a setup key to the database.
func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error {
func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *types.SetupKey) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(setupKey)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save setup key to store: %s", result.Error)
@ -1485,7 +1488,7 @@ func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrengt
// DeleteSetupKey deletes a setup key from the database.
func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&SetupKey{}, accountAndIDQueryCondition, accountID, keyID)
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&types.SetupKey{}, accountAndIDQueryCondition, accountID, keyID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete setup key from store: %s", result.Error)
return status.Errorf(status.Internal, "failed to delete setup key from store")
@ -1585,9 +1588,9 @@ func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, a
}
// SaveDNSSettings saves the DNS settings to the store.
func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
Where(idQueryCondition, accountID).Updates(&AccountDNSSettings{DNSSettings: *settings})
func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).
Where(idQueryCondition, accountID).Updates(&types.AccountDNSSettings{DNSSettings: *settings})
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save dns settings to store: %v", result.Error)
return status.Errorf(status.Internal, "failed to save dns settings to store")
@ -1600,8 +1603,8 @@ func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStre
return nil
}
func (s *SqlStore) GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networks.Network, error) {
var networks []*networks.Network
func (s *SqlStore) GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error) {
var networks []*networkTypes.Network
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&networks, accountIDCondition, accountID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get networks from the store: %s", result.Error)
@ -1611,8 +1614,8 @@ func (s *SqlStore) GetAccountNetworks(ctx context.Context, lockStrength LockingS
return networks, nil
}
func (s *SqlStore) GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networks.Network, error) {
var network *networks.Network
func (s *SqlStore) GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networkTypes.Network, error) {
var network *networkTypes.Network
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&network, accountAndIDQueryCondition, accountID, networkID)
if result.Error != nil {
@ -1627,7 +1630,7 @@ func (s *SqlStore) GetNetworkByID(ctx context.Context, lockStrength LockingStren
return network, nil
}
func (s *SqlStore) SaveNetwork(ctx context.Context, lockStrength LockingStrength, network *networks.Network) error {
func (s *SqlStore) SaveNetwork(ctx context.Context, lockStrength LockingStrength, network *networkTypes.Network) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(network)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save network to store: %v", result.Error)
@ -1639,7 +1642,7 @@ func (s *SqlStore) SaveNetwork(ctx context.Context, lockStrength LockingStrength
func (s *SqlStore) DeleteNetwork(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&networks.Network{}, accountAndIDQueryCondition, accountID, networkID)
Delete(&networkTypes.Network{}, accountAndIDQueryCondition, accountID, networkID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete network from store: %v", result.Error)
return status.Errorf(status.Internal, "failed to delete network from store")
@ -1652,8 +1655,8 @@ func (s *SqlStore) DeleteNetwork(ctx context.Context, lockStrength LockingStreng
return nil
}
func (s *SqlStore) GetNetworkRoutersByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*networks.NetworkRouter, error) {
var netRouters []*networks.NetworkRouter
func (s *SqlStore) GetNetworkRoutersByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*routerTypes.NetworkRouter, error) {
var netRouters []*routerTypes.NetworkRouter
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Find(&netRouters, "account_id = ? AND network_id = ?", accountID, netID)
if result.Error != nil {
@ -1664,8 +1667,8 @@ func (s *SqlStore) GetNetworkRoutersByNetID(ctx context.Context, lockStrength Lo
return netRouters, nil
}
func (s *SqlStore) GetNetworkRouterByID(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) (*networks.NetworkRouter, error) {
var netRouter *networks.NetworkRouter
func (s *SqlStore) GetNetworkRouterByID(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) (*routerTypes.NetworkRouter, error) {
var netRouter *routerTypes.NetworkRouter
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&netRouter, accountAndIDQueryCondition, accountID, routerID)
if result.Error != nil {
@ -1679,7 +1682,7 @@ func (s *SqlStore) GetNetworkRouterByID(ctx context.Context, lockStrength Lockin
return netRouter, nil
}
func (s *SqlStore) SaveNetworkRouter(ctx context.Context, lockStrength LockingStrength, router *networks.NetworkRouter) error {
func (s *SqlStore) SaveNetworkRouter(ctx context.Context, lockStrength LockingStrength, router *routerTypes.NetworkRouter) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(router)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save network router to store: %v", result.Error)
@ -1691,7 +1694,7 @@ func (s *SqlStore) SaveNetworkRouter(ctx context.Context, lockStrength LockingSt
func (s *SqlStore) DeleteNetworkRouter(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&networks.NetworkRouter{}, accountAndIDQueryCondition, accountID, routerID)
Delete(&routerTypes.NetworkRouter{}, accountAndIDQueryCondition, accountID, routerID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete network router from store: %v", result.Error)
return status.Errorf(status.Internal, "failed to delete network router from store")
@ -1704,8 +1707,8 @@ func (s *SqlStore) DeleteNetworkRouter(ctx context.Context, lockStrength Locking
return nil
}
func (s *SqlStore) GetNetworkResourcesByNetID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) ([]*networks.NetworkResource, error) {
var netResources []*networks.NetworkResource
func (s *SqlStore) GetNetworkResourcesByNetID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) ([]*resourceTypes.NetworkResource, error) {
var netResources []*resourceTypes.NetworkResource
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Find(&netResources, "account_id = ? AND network_id = ?", accountID, networkID)
if result.Error != nil {
@ -1716,8 +1719,8 @@ func (s *SqlStore) GetNetworkResourcesByNetID(ctx context.Context, lockStrength
return netResources, nil
}
func (s *SqlStore) GetNetworkResourceByID(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) (*networks.NetworkResource, error) {
var netResources *networks.NetworkResource
func (s *SqlStore) GetNetworkResourceByID(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) (*resourceTypes.NetworkResource, error) {
var netResources *resourceTypes.NetworkResource
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&netResources, accountAndIDQueryCondition, accountID, resourceID)
if result.Error != nil {
@ -1731,7 +1734,7 @@ func (s *SqlStore) GetNetworkResourceByID(ctx context.Context, lockStrength Lock
return netResources, nil
}
func (s *SqlStore) SaveNetworkResource(ctx context.Context, lockStrength LockingStrength, resource *networks.NetworkResource) error {
func (s *SqlStore) SaveNetworkResource(ctx context.Context, lockStrength LockingStrength, resource *resourceTypes.NetworkResource) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(resource)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save network resource to store: %v", result.Error)
@ -1743,7 +1746,7 @@ func (s *SqlStore) SaveNetworkResource(ctx context.Context, lockStrength Locking
func (s *SqlStore) DeleteNetworkResource(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&networks.NetworkResource{}, accountAndIDQueryCondition, accountID, resourceID)
Delete(&resourceTypes.NetworkResource{}, accountAndIDQueryCondition, accountID, resourceID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete network resource from store: %v", result.Error)
return status.Errorf(status.Internal, "failed to delete network resource from store")

View File

@ -1,4 +1,4 @@
package server
package store
import (
"context"
@ -14,18 +14,25 @@ import (
"time"
"github.com/google/uuid"
nbdns "github.com/netbirdio/netbird/dns"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
nbgroup "github.com/netbirdio/netbird/management/server/group"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/types"
route2 "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/management/server/status"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
nbroute "github.com/netbirdio/netbird/route"
)
func TestSqlite_NewStore(t *testing.T) {
@ -74,7 +81,7 @@ func runLargeTest(t *testing.T, store Store) {
if err != nil {
t.Fatal(err)
}
setupKey, _ := GenerateDefaultSetupKey()
setupKey, _ := types.GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
const numPerAccount = 6000
for n := 0; n < numPerAccount; n++ {
@ -87,14 +94,14 @@ func runLargeTest(t *testing.T, store Store) {
IP: netIP,
Name: peerID,
DNSLabel: peerID,
UserID: userID,
UserID: "testuser",
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
SSHEnabled: false,
}
account.Peers[peerID] = peer
group, _ := account.GetGroupAll()
group.Peers = append(group.Peers, peerID)
user := &User{
user := &types.User{
Id: fmt.Sprintf("%s-user-%d", account.Id, n),
AccountID: account.Id,
}
@ -135,7 +142,7 @@ func runLargeTest(t *testing.T, store Store) {
}
account.NameServerGroups[nameserver.ID] = nameserver
setupKey, _ := GenerateDefaultSetupKey()
setupKey, _ := types.GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
}
@ -217,7 +224,7 @@ func TestSqlite_SaveAccount(t *testing.T) {
assert.NoError(t, err)
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
setupKey, _ := GenerateDefaultSetupKey()
setupKey, _ := types.GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
account.Peers["testpeer"] = &nbpeer.Peer{
Key: "peerkey",
@ -231,7 +238,7 @@ func TestSqlite_SaveAccount(t *testing.T) {
require.NoError(t, err)
account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "")
setupKey, _ = GenerateDefaultSetupKey()
setupKey, _ = types.GenerateDefaultSetupKey()
account2.SetupKeys[setupKey.Key] = setupKey
account2.Peers["testpeer2"] = &nbpeer.Peer{
Key: "peerkey2",
@ -290,14 +297,14 @@ func TestSqlite_DeleteAccount(t *testing.T) {
assert.NoError(t, err)
testUserID := "testuser"
user := NewAdminUser(testUserID)
user.PATs = map[string]*PersonalAccessToken{"testtoken": {
user := types.NewAdminUser(testUserID)
user.PATs = map[string]*types.PersonalAccessToken{"testtoken": {
ID: "testtoken",
Name: "test token",
}}
account := newAccountWithId(context.Background(), "account_id", testUserID, "")
setupKey, _ := GenerateDefaultSetupKey()
setupKey, _ := types.GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
account.Peers["testpeer"] = &nbpeer.Peer{
Key: "peerkey",
@ -307,7 +314,7 @@ func TestSqlite_DeleteAccount(t *testing.T) {
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
account.Users[testUserID] = user
account.Networks = []*networks.Network{
account.Networks = []*networkTypes.Network{
{
ID: "network_id",
AccountID: account.Id,
@ -315,7 +322,7 @@ func TestSqlite_DeleteAccount(t *testing.T) {
Description: "network description",
},
}
account.NetworkRouters = []*networks.NetworkRouter{
account.NetworkRouters = []*routerTypes.NetworkRouter{
{
ID: "router_id",
NetworkID: account.Networks[0].ID,
@ -325,7 +332,7 @@ func TestSqlite_DeleteAccount(t *testing.T) {
Metric: 1,
},
}
account.NetworkResources = []*networks.NetworkResource{
account.NetworkResources = []*resourceTypes.NetworkResource{
{
ID: "resource_id",
NetworkID: account.Networks[0].ID,
@ -367,16 +374,16 @@ func TestSqlite_DeleteAccount(t *testing.T) {
require.Error(t, err, "expecting error after removing DeleteAccount when getting account by id")
for _, policy := range account.Policies {
var rules []*PolicyRule
err = store.(*SqlStore).db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
var rules []*types.PolicyRule
err = store.(*SqlStore).db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for policy rules")
require.Len(t, rules, 0, "expecting no policy rules to be found after removing DeleteAccount")
}
for _, accountUser := range account.Users {
var pats []*PersonalAccessToken
err = store.(*SqlStore).db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error
var pats []*types.PersonalAccessToken
err = store.(*SqlStore).db.Model(&types.PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error
require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for personal access token")
require.Len(t, pats, 0, "expecting no personal access token to be found after removing DeleteAccount")
@ -399,7 +406,7 @@ func TestSqlite_GetAccount(t *testing.T) {
}
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
@ -422,7 +429,7 @@ func TestSqlite_SavePeer(t *testing.T) {
}
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
@ -472,7 +479,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) {
}
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
@ -527,7 +534,7 @@ func TestSqlite_SavePeerLocation(t *testing.T) {
}
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
@ -581,7 +588,7 @@ func TestSqlite_TestGetAccountByPrivateDomain(t *testing.T) {
}
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
@ -604,7 +611,7 @@ func TestSqlite_GetTokenIDByHashedToken(t *testing.T) {
}
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
@ -628,7 +635,7 @@ func TestSqlite_GetUserByTokenID(t *testing.T) {
}
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
@ -664,7 +671,7 @@ func TestMigrate(t *testing.T) {
require.NoError(t, err, "Failed to parse CIDR")
type network struct {
Network
types.Network
Net net.IPNet `gorm:"serializer:gob"`
}
@ -679,7 +686,7 @@ func TestMigrate(t *testing.T) {
}
type account struct {
Account
types.Account
Network *network `gorm:"embedded;embeddedPrefix:network_"`
Peers []peer `gorm:"foreignKey:AccountID;references:id"`
}
@ -739,23 +746,10 @@ func TestMigrate(t *testing.T) {
}
func newSqliteStore(t *testing.T) *SqlStore {
t.Helper()
store, err := NewSqliteStore(context.Background(), t.TempDir(), nil)
t.Cleanup(func() {
store.Close(context.Background())
})
require.NoError(t, err)
require.NotNil(t, store)
return store
}
func newAccount(store Store, id int) error {
str := fmt.Sprintf("%s-%d", uuid.New().String(), id)
account := newAccountWithId(context.Background(), str, str+"-testuser", "example.com")
setupKey, _ := GenerateDefaultSetupKey()
setupKey, _ := types.GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
account.Peers["p"+str] = &nbpeer.Peer{
Key: "peerkey" + str,
@ -794,7 +788,7 @@ func TestPostgresql_SaveAccount(t *testing.T) {
assert.NoError(t, err)
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
setupKey, _ := GenerateDefaultSetupKey()
setupKey, _ := types.GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
account.Peers["testpeer"] = &nbpeer.Peer{
Key: "peerkey",
@ -808,7 +802,7 @@ func TestPostgresql_SaveAccount(t *testing.T) {
require.NoError(t, err)
account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "")
setupKey, _ = GenerateDefaultSetupKey()
setupKey, _ = types.GenerateDefaultSetupKey()
account2.SetupKeys[setupKey.Key] = setupKey
account2.Peers["testpeer2"] = &nbpeer.Peer{
Key: "peerkey2",
@ -867,14 +861,14 @@ func TestPostgresql_DeleteAccount(t *testing.T) {
assert.NoError(t, err)
testUserID := "testuser"
user := NewAdminUser(testUserID)
user.PATs = map[string]*PersonalAccessToken{"testtoken": {
user := types.NewAdminUser(testUserID)
user.PATs = map[string]*types.PersonalAccessToken{"testtoken": {
ID: "testtoken",
Name: "test token",
}}
account := newAccountWithId(context.Background(), "account_id", testUserID, "")
setupKey, _ := GenerateDefaultSetupKey()
setupKey, _ := types.GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
account.Peers["testpeer"] = &nbpeer.Peer{
Key: "peerkey",
@ -915,16 +909,16 @@ func TestPostgresql_DeleteAccount(t *testing.T) {
require.Error(t, err, "expecting error after removing DeleteAccount when getting account by id")
for _, policy := range account.Policies {
var rules []*PolicyRule
err = store.(*SqlStore).db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
var rules []*types.PolicyRule
err = store.(*SqlStore).db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for policy rules")
require.Len(t, rules, 0, "expecting no policy rules to be found after removing DeleteAccount")
}
for _, accountUser := range account.Users {
var pats []*PersonalAccessToken
err = store.(*SqlStore).db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error
var pats []*types.PersonalAccessToken
err = store.(*SqlStore).db.Model(&types.PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error
require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for personal access token")
require.Len(t, pats, 0, "expecting no personal access token to be found after removing DeleteAccount")
@ -938,7 +932,7 @@ func TestPostgresql_SavePeerStatus(t *testing.T) {
}
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
@ -979,7 +973,7 @@ func TestPostgresql_TestGetAccountByPrivateDomain(t *testing.T) {
}
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
@ -999,7 +993,7 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) {
}
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
@ -1017,7 +1011,7 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) {
}
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine))
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
@ -1030,7 +1024,7 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) {
func TestSqlite_GetTakenIPs(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
defer cleanup()
if err != nil {
t.Fatal(err)
@ -1075,7 +1069,7 @@ func TestSqlite_GetTakenIPs(t *testing.T) {
func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
if err != nil {
return
}
@ -1117,7 +1111,7 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
func TestSqlite_GetAccountNetwork(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
if err != nil {
t.Fatal(err)
@ -1140,7 +1134,7 @@ func TestSqlite_GetAccountNetwork(t *testing.T) {
func TestSqlite_GetSetupKeyBySecret(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
if err != nil {
t.Fatal(err)
@ -1158,14 +1152,14 @@ func TestSqlite_GetSetupKeyBySecret(t *testing.T) {
setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, encodedHashedKey)
require.NoError(t, err)
assert.Equal(t, encodedHashedKey, setupKey.Key)
assert.Equal(t, hiddenKey(plainKey, 4), setupKey.KeySecret)
assert.Equal(t, types.HiddenKey(plainKey, 4), setupKey.KeySecret)
assert.Equal(t, "bf1c8084-ba50-4ce7-9439-34653001fc3b", setupKey.AccountID)
assert.Equal(t, "Default key", setupKey.Name)
}
func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
if err != nil {
t.Fatal(err)
@ -1201,7 +1195,7 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
if err != nil {
t.Fatal(err)
@ -1232,7 +1226,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
}
func TestSqlite_GetAccoundUsers(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
if err != nil {
t.Fatal(err)
@ -1246,7 +1240,7 @@ func TestSqlite_GetAccoundUsers(t *testing.T) {
}
func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
if err != nil {
t.Fatal(err)
@ -1292,7 +1286,7 @@ func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) {
}
func TestSqlite_GetGroupByName(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
if err != nil {
t.Fatal(err)
@ -1306,7 +1300,7 @@ func TestSqlite_GetGroupByName(t *testing.T) {
func Test_DeleteSetupKeySuccessfully(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
@ -1322,7 +1316,7 @@ func Test_DeleteSetupKeySuccessfully(t *testing.T) {
func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) {
t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
@ -1334,7 +1328,7 @@ func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) {
}
func TestSqlStore_GetGroupsByIDs(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
@ -1377,7 +1371,7 @@ func TestSqlStore_GetGroupsByIDs(t *testing.T) {
}
func TestSqlStore_SaveGroup(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
@ -1398,7 +1392,7 @@ func TestSqlStore_SaveGroup(t *testing.T) {
}
func TestSqlStore_SaveGroups(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
@ -1423,7 +1417,7 @@ func TestSqlStore_SaveGroups(t *testing.T) {
}
func TestSqlStore_DeleteGroup(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
@ -1471,7 +1465,7 @@ func TestSqlStore_DeleteGroup(t *testing.T) {
}
func TestSqlStore_DeleteGroups(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
@ -1518,7 +1512,7 @@ func TestSqlStore_DeleteGroups(t *testing.T) {
}
func TestSqlStore_GetPeerByID(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_policy_migrate.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
@ -1564,7 +1558,7 @@ func TestSqlStore_GetPeerByID(t *testing.T) {
}
func TestSqlStore_GetPeersByIDs(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_policy_migrate.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
@ -1606,7 +1600,7 @@ func TestSqlStore_GetPeersByIDs(t *testing.T) {
}
func TestSqlStore_GetPostureChecksByID(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
@ -1652,7 +1646,7 @@ func TestSqlStore_GetPostureChecksByID(t *testing.T) {
}
func TestSqlStore_GetPostureChecksByIDs(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
@ -1695,7 +1689,7 @@ func TestSqlStore_GetPostureChecksByIDs(t *testing.T) {
}
func TestSqlStore_SavePostureChecks(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
@ -1736,7 +1730,7 @@ func TestSqlStore_SavePostureChecks(t *testing.T) {
}
func TestSqlStore_DeletePostureChecks(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
@ -1783,7 +1777,7 @@ func TestSqlStore_DeletePostureChecks(t *testing.T) {
}
func TestSqlStore_GetPolicyByID(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
@ -1829,23 +1823,23 @@ func TestSqlStore_GetPolicyByID(t *testing.T) {
}
func TestSqlStore_CreatePolicy(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
policy := &Policy{
policy := &types.Policy{
ID: "policy-id",
AccountID: accountID,
Enabled: true,
Rules: []*PolicyRule{
Rules: []*types.PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupC"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
Action: types.PolicyTrafficActionAccept,
},
},
}
@ -1859,7 +1853,7 @@ func TestSqlStore_CreatePolicy(t *testing.T) {
}
func TestSqlStore_SavePolicy(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
@ -1882,7 +1876,7 @@ func TestSqlStore_SavePolicy(t *testing.T) {
}
func TestSqlStore_DeletePolicy(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
@ -1898,7 +1892,7 @@ func TestSqlStore_DeletePolicy(t *testing.T) {
}
func TestSqlStore_GetDNSSettings(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
@ -1942,7 +1936,7 @@ func TestSqlStore_GetDNSSettings(t *testing.T) {
}
func TestSqlStore_SaveDNSSettings(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
@ -1961,7 +1955,7 @@ func TestSqlStore_SaveDNSSettings(t *testing.T) {
}
func TestSqlStore_GetAccountNameServerGroups(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
@ -1998,7 +1992,7 @@ func TestSqlStore_GetAccountNameServerGroups(t *testing.T) {
}
func TestSqlStore_GetNameServerByID(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
@ -2044,7 +2038,7 @@ func TestSqlStore_GetNameServerByID(t *testing.T) {
}
func TestSqlStore_SaveNameServerGroup(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
@ -2076,7 +2070,7 @@ func TestSqlStore_SaveNameServerGroup(t *testing.T) {
}
func TestSqlStore_DeleteNameServerGroup(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
@ -2091,8 +2085,97 @@ func TestSqlStore_DeleteNameServerGroup(t *testing.T) {
require.Nil(t, nsGroup)
}
// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id
func newAccountWithId(ctx context.Context, accountID, userID, domain string) *types.Account {
log.WithContext(ctx).Debugf("creating new account")
network := types.NewNetwork()
peers := make(map[string]*nbpeer.Peer)
users := make(map[string]*types.User)
routes := make(map[nbroute.ID]*nbroute.Route)
setupKeys := map[string]*types.SetupKey{}
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
owner := types.NewOwnerUser(userID)
owner.AccountID = accountID
users[userID] = owner
dnsSettings := types.DNSSettings{
DisabledManagementGroups: make([]string, 0),
}
log.WithContext(ctx).Debugf("created new account %s", accountID)
acc := &types.Account{
Id: accountID,
CreatedAt: time.Now().UTC(),
SetupKeys: setupKeys,
Network: network,
Peers: peers,
Users: users,
CreatedBy: userID,
Domain: domain,
Routes: routes,
NameServerGroups: nameServersGroups,
DNSSettings: dnsSettings,
Settings: &types.Settings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: types.DefaultPeerLoginExpiration,
GroupsPropagationEnabled: true,
RegularUsersViewBlocked: true,
PeerInactivityExpirationEnabled: false,
PeerInactivityExpiration: types.DefaultPeerInactivityExpiration,
},
}
if err := addAllGroup(acc); err != nil {
log.WithContext(ctx).Errorf("error adding all group to account %s: %v", acc.Id, err)
}
return acc
}
// addAllGroup to account object if it doesn't exist
func addAllGroup(account *types.Account) error {
if len(account.Groups) == 0 {
allGroup := &nbgroup.Group{
ID: xid.New().String(),
Name: "All",
Issued: nbgroup.GroupIssuedAPI,
}
for _, peer := range account.Peers {
allGroup.Peers = append(allGroup.Peers, peer.ID)
}
account.Groups = map[string]*nbgroup.Group{allGroup.ID: allGroup}
id := xid.New().String()
defaultPolicy := &types.Policy{
ID: id,
Name: types.DefaultRuleName,
Description: types.DefaultRuleDescription,
Enabled: true,
Rules: []*types.PolicyRule{
{
ID: id,
Name: types.DefaultRuleName,
Description: types.DefaultRuleDescription,
Enabled: true,
Sources: []string{allGroup.ID},
Destinations: []string{allGroup.ID},
Bidirectional: true,
Protocol: types.PolicyRuleProtocolALL,
Action: types.PolicyTrafficActionAccept,
},
},
}
account.Policies = []*types.Policy{defaultPolicy}
}
return nil
}
func TestSqlStore_GetAccountNetworks(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
@ -2124,7 +2207,7 @@ func TestSqlStore_GetAccountNetworks(t *testing.T) {
}
func TestSqlStore_GetNetworkByID(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
@ -2170,12 +2253,12 @@ func TestSqlStore_GetNetworkByID(t *testing.T) {
}
func TestSqlStore_SaveNetwork(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
network := &networks.Network{
network := &networkTypes.Network{
ID: "net-id",
AccountID: accountID,
Name: "net",
@ -2190,7 +2273,7 @@ func TestSqlStore_SaveNetwork(t *testing.T) {
}
func TestSqlStore_DeleteNetwork(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
@ -2209,7 +2292,7 @@ func TestSqlStore_DeleteNetwork(t *testing.T) {
}
func TestSqlStore_GetNetworkRoutersByNetID(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
@ -2242,7 +2325,7 @@ func TestSqlStore_GetNetworkRoutersByNetID(t *testing.T) {
}
func TestSqlStore_GetNetworkRouterByID(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
@ -2288,14 +2371,14 @@ func TestSqlStore_GetNetworkRouterByID(t *testing.T) {
}
func TestSqlStore_SaveNetworkRouter(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
networkID := "ct286bi7qv930dsrrug0"
netRouter, err := networks.NewNetworkRouter(accountID, networkID, "", []string{"net-router-grp"}, true, 0)
netRouter, err := routerTypes.NewNetworkRouter(accountID, networkID, "", []string{"net-router-grp"}, true, 0)
require.NoError(t, err)
err = store.SaveNetworkRouter(context.Background(), LockingStrengthUpdate, netRouter)
@ -2307,7 +2390,7 @@ func TestSqlStore_SaveNetworkRouter(t *testing.T) {
}
func TestSqlStore_DeleteNetworkRouter(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
@ -2326,7 +2409,7 @@ func TestSqlStore_DeleteNetworkRouter(t *testing.T) {
}
func TestSqlStore_GetNetworkResourcesByNetID(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
@ -2359,7 +2442,7 @@ func TestSqlStore_GetNetworkResourcesByNetID(t *testing.T) {
}
func TestSqlStore_GetNetworkResourceByID(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
@ -2405,14 +2488,14 @@ func TestSqlStore_GetNetworkResourceByID(t *testing.T) {
}
func TestSqlStore_SaveNetworkResource(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
networkID := "ct286bi7qv930dsrrug0"
netResource, err := networks.NewNetworkResource(accountID, networkID, "resource-name", "", "example.com")
netResource, err := resourceTypes.NewNetworkResource(accountID, networkID, "resource-name", "", "example.com")
require.NoError(t, err)
err = store.SaveNetworkResource(context.Background(), LockingStrengthUpdate, netResource)
@ -2424,7 +2507,7 @@ func TestSqlStore_SaveNetworkResource(t *testing.T) {
}
func TestSqlStore_DeleteNetworkResource(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)

View File

@ -1,4 +1,4 @@
package server
package store
import (
"context"
@ -13,12 +13,12 @@ import (
"strings"
"time"
"github.com/netbirdio/netbird/management/server/networks"
log "github.com/sirupsen/logrus"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/types"
nbgroup "github.com/netbirdio/netbird/management/server/group"
@ -26,6 +26,9 @@ import (
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/management/server/migration"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/testutil"
@ -42,31 +45,31 @@ const (
)
type Store interface {
GetAllAccounts(ctx context.Context) []*Account
GetAccount(ctx context.Context, accountID string) (*Account, error)
GetAllAccounts(ctx context.Context) []*types.Account
GetAccount(ctx context.Context, accountID string) (*types.Account, error)
AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error)
GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error)
GetAccountByUser(ctx context.Context, userID string) (*Account, error)
GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error)
GetAccountByUser(ctx context.Context, userID string) (*types.Account, error)
GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*types.Account, error)
GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error)
GetAccountIDByUserID(userID string) (string, error)
GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error)
GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error)
GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error)
GetAccountBySetupKey(ctx context.Context, setupKey string) (*types.Account, error) // todo use key hash later
GetAccountByPrivateDomain(ctx context.Context, domain string) (*types.Account, error)
GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error)
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error)
SaveAccount(ctx context.Context, account *Account) error
DeleteAccount(ctx context.Context, account *Account) error
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Settings, error)
GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.DNSSettings, error)
SaveAccount(ctx context.Context, account *types.Account) error
DeleteAccount(ctx context.Context, account *types.Account) error
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error
SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error)
SaveUsers(accountID string, users map[string]*User) error
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
GetUserByTokenID(ctx context.Context, tokenID string) (*types.User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error)
GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error)
SaveUsers(accountID string, users map[string]*types.User) error
SaveUser(ctx context.Context, lockStrength LockingStrength, user *types.User) error
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
@ -81,10 +84,10 @@ type Store interface {
DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error
DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error)
CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error
SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Policy, error)
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*types.Policy, error)
CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *types.Policy) error
SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *types.Policy) error
DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
@ -106,11 +109,11 @@ type Store interface {
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error)
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error)
GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error)
SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error
GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.SetupKey, error)
GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*types.SetupKey, error)
SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *types.SetupKey) error
DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error
GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
@ -123,7 +126,7 @@ type Store interface {
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*types.Network, error)
GetInstallationID() string
SaveInstallationID(ctx context.Context, ID string) error
@ -137,45 +140,45 @@ type Store interface {
// Close should close the store persisting all unsaved data.
Close(ctx context.Context) error
// GetStoreEngine should return StoreEngine of the current store implementation.
// GetStoreEngine should return Engine of the current store implementation.
// This is also a method of metrics.DataSource interface.
GetStoreEngine() StoreEngine
GetStoreEngine() Engine
ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networks.Network, error)
GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networks.Network, error)
SaveNetwork(ctx context.Context, lockStrength LockingStrength, network *networks.Network) error
GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error)
GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networkTypes.Network, error)
SaveNetwork(ctx context.Context, lockStrength LockingStrength, network *networkTypes.Network) error
DeleteNetwork(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) error
GetNetworkRoutersByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*networks.NetworkRouter, error)
GetNetworkRouterByID(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) (*networks.NetworkRouter, error)
SaveNetworkRouter(ctx context.Context, lockStrength LockingStrength, router *networks.NetworkRouter) error
GetNetworkRoutersByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*routerTypes.NetworkRouter, error)
GetNetworkRouterByID(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) (*routerTypes.NetworkRouter, error)
SaveNetworkRouter(ctx context.Context, lockStrength LockingStrength, router *routerTypes.NetworkRouter) error
DeleteNetworkRouter(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) error
GetNetworkResourcesByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*networks.NetworkResource, error)
GetNetworkResourceByID(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) (*networks.NetworkResource, error)
SaveNetworkResource(ctx context.Context, lockStrength LockingStrength, resource *networks.NetworkResource) error
GetNetworkResourcesByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*resourceTypes.NetworkResource, error)
GetNetworkResourceByID(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) (*resourceTypes.NetworkResource, error)
SaveNetworkResource(ctx context.Context, lockStrength LockingStrength, resource *resourceTypes.NetworkResource) error
DeleteNetworkResource(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) error
}
type StoreEngine string
type Engine string
const (
FileStoreEngine StoreEngine = "jsonfile"
SqliteStoreEngine StoreEngine = "sqlite"
PostgresStoreEngine StoreEngine = "postgres"
FileStoreEngine Engine = "jsonfile"
SqliteStoreEngine Engine = "sqlite"
PostgresStoreEngine Engine = "postgres"
postgresDsnEnv = "NETBIRD_STORE_ENGINE_POSTGRES_DSN"
)
func getStoreEngineFromEnv() StoreEngine {
func getStoreEngineFromEnv() Engine {
// NETBIRD_STORE_ENGINE supposed to be used in tests. Otherwise, rely on the config file.
kind, ok := os.LookupEnv("NETBIRD_STORE_ENGINE")
if !ok {
return ""
}
value := StoreEngine(strings.ToLower(kind))
value := Engine(strings.ToLower(kind))
if value == SqliteStoreEngine || value == PostgresStoreEngine {
return value
}
@ -187,7 +190,7 @@ func getStoreEngineFromEnv() StoreEngine {
// If no engine is specified, it attempts to retrieve it from the environment.
// If still not specified, it defaults to using SQLite.
// Additionally, it handles the migration from a JSON store file to SQLite if applicable.
func getStoreEngine(ctx context.Context, dataDir string, kind StoreEngine) StoreEngine {
func getStoreEngine(ctx context.Context, dataDir string, kind Engine) Engine {
if kind == "" {
kind = getStoreEngineFromEnv()
if kind == "" {
@ -213,7 +216,7 @@ func getStoreEngine(ctx context.Context, dataDir string, kind StoreEngine) Store
}
// NewStore creates a new store based on the provided engine type, data directory, and telemetry metrics
func NewStore(ctx context.Context, kind StoreEngine, dataDir string, metrics telemetry.AppMetrics) (Store, error) {
func NewStore(ctx context.Context, kind Engine, dataDir string, metrics telemetry.AppMetrics) (Store, error) {
kind = getStoreEngine(ctx, dataDir, kind)
if err := checkFileStoreEngine(kind, dataDir); err != nil {
@ -232,7 +235,7 @@ func NewStore(ctx context.Context, kind StoreEngine, dataDir string, metrics tel
}
}
func checkFileStoreEngine(kind StoreEngine, dataDir string) error {
func checkFileStoreEngine(kind Engine, dataDir string) error {
if kind == FileStoreEngine {
storeFile := filepath.Join(dataDir, storeFileName)
if util.FileExists(storeFile) {
@ -259,7 +262,7 @@ func migrate(ctx context.Context, db *gorm.DB) error {
func getMigrations(ctx context.Context) []migrationFunc {
return []migrationFunc{
func(db *gorm.DB) error {
return migration.MigrateFieldFromGobToJSON[Account, net.IPNet](ctx, db, "network_net")
return migration.MigrateFieldFromGobToJSON[types.Account, net.IPNet](ctx, db, "network_net")
},
func(db *gorm.DB) error {
return migration.MigrateFieldFromGobToJSON[route.Route, netip.Prefix](ctx, db, "network")
@ -274,7 +277,7 @@ func getMigrations(ctx context.Context) []migrationFunc {
return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](ctx, db, "ip", "idx_peers_account_id_ip")
},
func(db *gorm.DB) error {
return migration.MigrateSetupKeyToHashedSetupKey[SetupKey](ctx, db)
return migration.MigrateSetupKeyToHashedSetupKey[types.SetupKey](ctx, db)
},
}
}

View File

@ -1,4 +1,4 @@
package server
package store
import (
"context"
@ -76,11 +76,3 @@ func BenchmarkTest_StoreRead(b *testing.B) {
})
}
}
func newStore(t *testing.T) Store {
t.Helper()
store := newSqliteStore(t)
return store
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,16 @@
package types
// DNSSettings defines dns settings at the account level
type DNSSettings struct {
// DisabledManagementGroups groups whose DNS management is disabled
DisabledManagementGroups []string `gorm:"serializer:json"`
}
// Copy returns a copy of the DNS settings
func (d DNSSettings) Copy() DNSSettings {
settings := DNSSettings{
DisabledManagementGroups: make([]string, len(d.DisabledManagementGroups)),
}
copy(settings.DisabledManagementGroups, d.DisabledManagementGroups)
return settings
}

View File

@ -0,0 +1,129 @@
package types
import (
"context"
"fmt"
"strconv"
"strings"
log "github.com/sirupsen/logrus"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
nbroute "github.com/netbirdio/netbird/route"
)
const (
FirewallRuleDirectionIN = 0
FirewallRuleDirectionOUT = 1
)
// FirewallRule is a rule of the firewall.
type FirewallRule struct {
// PeerIP of the peer
PeerIP string
// Direction of the traffic
Direction int
// Action of the traffic
Action string
// Protocol of the traffic
Protocol string
// Port of the traffic
Port string
}
// generateRouteFirewallRules generates a list of firewall rules for a given route.
func generateRouteFirewallRules(ctx context.Context, route *nbroute.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) []*RouteFirewallRule {
rulesExists := make(map[string]struct{})
rules := make([]*RouteFirewallRule, 0)
sourceRanges := make([]string, 0, len(groupPeers))
for _, peer := range groupPeers {
if peer == nil {
continue
}
sourceRanges = append(sourceRanges, fmt.Sprintf(AllowedIPsFormat, peer.IP))
}
baseRule := RouteFirewallRule{
SourceRanges: sourceRanges,
Action: string(rule.Action),
Destination: route.Network.String(),
Protocol: string(rule.Protocol),
IsDynamic: route.IsDynamic(),
}
// generate rule for port range
if len(rule.Ports) == 0 {
rules = append(rules, generateRulesWithPortRanges(baseRule, rule, rulesExists)...)
} else {
rules = append(rules, generateRulesWithPorts(ctx, baseRule, rule, rulesExists)...)
}
// TODO: generate IPv6 rules for dynamic routes
return rules
}
// generateRulesForPeer generates rules for a given peer based on ports and port ranges.
func generateRulesWithPortRanges(baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule {
rules := make([]*RouteFirewallRule, 0)
ruleIDBase := generateRuleIDBase(rule, baseRule)
if len(rule.Ports) == 0 {
if len(rule.PortRanges) == 0 {
if _, ok := rulesExists[ruleIDBase]; !ok {
rulesExists[ruleIDBase] = struct{}{}
rules = append(rules, &baseRule)
}
} else {
for _, portRange := range rule.PortRanges {
ruleID := fmt.Sprintf("%s%d-%d", ruleIDBase, portRange.Start, portRange.End)
if _, ok := rulesExists[ruleID]; !ok {
rulesExists[ruleID] = struct{}{}
pr := baseRule
pr.PortRange = portRange
rules = append(rules, &pr)
}
}
}
return rules
}
return rules
}
// generateRulesWithPorts generates rules when specific ports are provided.
func generateRulesWithPorts(ctx context.Context, baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule {
rules := make([]*RouteFirewallRule, 0)
ruleIDBase := generateRuleIDBase(rule, baseRule)
for _, port := range rule.Ports {
ruleID := ruleIDBase + port
if _, ok := rulesExists[ruleID]; ok {
continue
}
rulesExists[ruleID] = struct{}{}
pr := baseRule
p, err := strconv.ParseUint(port, 10, 16)
if err != nil {
log.WithContext(ctx).Errorf("failed to parse port %s for rule: %s", port, rule.ID)
continue
}
pr.Port = uint16(p)
rules = append(rules, &pr)
}
return rules
}
// generateRuleIDBase generates the base rule ID for checking duplicates.
func generateRuleIDBase(rule *PolicyRule, baseRule RouteFirewallRule) string {
return rule.ID + strings.Join(baseRule.SourceRanges, ",") + strconv.Itoa(FirewallRuleDirectionIN) + baseRule.Protocol + baseRule.Action
}

View File

@ -1,4 +1,4 @@
package server
package types
import (
"math/rand"
@ -43,7 +43,7 @@ type Network struct {
// Used to synchronize state to the client apps.
Serial uint64
mu sync.Mutex `json:"-" gorm:"-"`
Mu sync.Mutex `json:"-" gorm:"-"`
}
// NewNetwork creates a new Network initializing it with a Serial=0
@ -66,15 +66,15 @@ func NewNetwork() *Network {
// IncSerial increments Serial by 1 reflecting that the network state has been changed
func (n *Network) IncSerial() {
n.mu.Lock()
defer n.mu.Unlock()
n.Mu.Lock()
defer n.Mu.Unlock()
n.Serial++
}
// CurrentSerial returns the Network.Serial of the network (latest state id)
func (n *Network) CurrentSerial() uint64 {
n.mu.Lock()
defer n.mu.Unlock()
n.Mu.Lock()
defer n.Mu.Unlock()
return n.Serial
}

View File

@ -1,4 +1,4 @@
package server
package types
import (
"net"

View File

@ -1,4 +1,4 @@
package server
package types
import (
"crypto/sha256"

View File

@ -1,4 +1,4 @@
package server
package types
import (
"crypto/sha256"

View File

@ -0,0 +1,116 @@
package types
const (
// PolicyTrafficActionAccept indicates that the traffic is accepted
PolicyTrafficActionAccept = PolicyTrafficActionType("accept")
// PolicyTrafficActionDrop indicates that the traffic is dropped
PolicyTrafficActionDrop = PolicyTrafficActionType("drop")
)
const (
// PolicyRuleProtocolALL type of traffic
PolicyRuleProtocolALL = PolicyRuleProtocolType("all")
// PolicyRuleProtocolTCP type of traffic
PolicyRuleProtocolTCP = PolicyRuleProtocolType("tcp")
// PolicyRuleProtocolUDP type of traffic
PolicyRuleProtocolUDP = PolicyRuleProtocolType("udp")
// PolicyRuleProtocolICMP type of traffic
PolicyRuleProtocolICMP = PolicyRuleProtocolType("icmp")
)
const (
// PolicyRuleFlowDirect allows traffic from source to destination
PolicyRuleFlowDirect = PolicyRuleDirection("direct")
// PolicyRuleFlowBidirect allows traffic to both directions
PolicyRuleFlowBidirect = PolicyRuleDirection("bidirect")
)
const (
// DefaultRuleName is a name for the Default rule that is created for every account
DefaultRuleName = "Default"
// DefaultRuleDescription is a description for the Default rule that is created for every account
DefaultRuleDescription = "This is a default rule that allows connections between all the resources"
// DefaultPolicyName is a name for the Default policy that is created for every account
DefaultPolicyName = "Default"
// DefaultPolicyDescription is a description for the Default policy that is created for every account
DefaultPolicyDescription = "This is a default policy that allows connections between all the resources"
)
// PolicyUpdateOperation operation object with type and values to be applied
type PolicyUpdateOperation struct {
Type PolicyUpdateOperationType
Values []string
}
// Policy of the Rego query
type Policy struct {
// ID of the policy'
ID string `gorm:"primaryKey"`
// AccountID is a reference to Account that this object belongs
AccountID string `json:"-" gorm:"index"`
// Name of the Policy
Name string
// Description of the policy visible in the UI
Description string
// Enabled status of the policy
Enabled bool
// Rules of the policy
Rules []*PolicyRule `gorm:"foreignKey:PolicyID;references:id;constraint:OnDelete:CASCADE;"`
// SourcePostureChecks are ID references to Posture checks for policy source groups
SourcePostureChecks []string `gorm:"serializer:json"`
}
// Copy returns a copy of the policy.
func (p *Policy) Copy() *Policy {
c := &Policy{
ID: p.ID,
AccountID: p.AccountID,
Name: p.Name,
Description: p.Description,
Enabled: p.Enabled,
Rules: make([]*PolicyRule, len(p.Rules)),
SourcePostureChecks: make([]string, len(p.SourcePostureChecks)),
}
for i, r := range p.Rules {
c.Rules[i] = r.Copy()
}
copy(c.SourcePostureChecks, p.SourcePostureChecks)
return c
}
// EventMeta returns activity event meta related to this policy
func (p *Policy) EventMeta() map[string]any {
return map[string]any{"name": p.Name}
}
// UpgradeAndFix different version of policies to latest version
func (p *Policy) UpgradeAndFix() {
for _, r := range p.Rules {
// start migrate from version v0.20.3
if r.Protocol == "" {
r.Protocol = PolicyRuleProtocolALL
}
if r.Protocol == PolicyRuleProtocolALL && !r.Bidirectional {
r.Bidirectional = true
}
// -- v0.20.4
}
}
// RuleGroups returns a list of all groups referenced in the policy's rules,
// including sources and destinations.
func (p *Policy) RuleGroups() []string {
groups := make([]string, 0)
for _, rule := range p.Rules {
groups = append(groups, rule.Sources...)
groups = append(groups, rule.Destinations...)
}
return groups
}

View File

@ -0,0 +1,81 @@
package types
// PolicyUpdateOperationType operation type
type PolicyUpdateOperationType int
// PolicyTrafficActionType action type for the firewall
type PolicyTrafficActionType string
// PolicyRuleProtocolType type of traffic
type PolicyRuleProtocolType string
// PolicyRuleDirection direction of traffic
type PolicyRuleDirection string
// RulePortRange represents a range of ports for a firewall rule.
type RulePortRange struct {
Start uint16
End uint16
}
// PolicyRule is the metadata of the policy
type PolicyRule struct {
// ID of the policy rule
ID string `gorm:"primaryKey"`
// PolicyID is a reference to Policy that this object belongs
PolicyID string `json:"-" gorm:"index"`
// Name of the rule visible in the UI
Name string
// Description of the rule visible in the UI
Description string
// Enabled status of rule in the system
Enabled bool
// Action policy accept or drops packets
Action PolicyTrafficActionType
// Destinations policy destination groups
Destinations []string `gorm:"serializer:json"`
// Sources policy source groups
Sources []string `gorm:"serializer:json"`
// Bidirectional define if the rule is applicable in both directions, sources, and destinations
Bidirectional bool
// Protocol type of the traffic
Protocol PolicyRuleProtocolType
// Ports or it ranges list
Ports []string `gorm:"serializer:json"`
// PortRanges a list of port ranges.
PortRanges []RulePortRange `gorm:"serializer:json"`
}
// Copy returns a copy of a policy rule
func (pm *PolicyRule) Copy() *PolicyRule {
rule := &PolicyRule{
ID: pm.ID,
PolicyID: pm.PolicyID,
Name: pm.Name,
Description: pm.Description,
Enabled: pm.Enabled,
Action: pm.Action,
Destinations: make([]string, len(pm.Destinations)),
Sources: make([]string, len(pm.Sources)),
Bidirectional: pm.Bidirectional,
Protocol: pm.Protocol,
Ports: make([]string, len(pm.Ports)),
PortRanges: make([]RulePortRange, len(pm.PortRanges)),
}
copy(rule.Destinations, pm.Destinations)
copy(rule.Sources, pm.Sources)
copy(rule.Ports, pm.Ports)
copy(rule.PortRanges, pm.PortRanges)
return rule
}

View File

@ -0,0 +1,25 @@
package types
// RouteFirewallRule a firewall rule applicable for a routed network.
type RouteFirewallRule struct {
// SourceRanges IP ranges of the routing peers.
SourceRanges []string
// Action of the traffic when the rule is applicable
Action string
// Destination a network prefix for the routed traffic
Destination string
// Protocol of the traffic
Protocol string
// Port of the traffic
Port uint16
// PortRange represents the range of ports for a firewall rule
PortRange RulePortRange
// isDynamic indicates whether the rule is for DNS routing
IsDynamic bool
}

View File

@ -0,0 +1,63 @@
package types
import (
"time"
"github.com/netbirdio/netbird/management/server/account"
)
// Settings represents Account settings structure that can be modified via API and Dashboard
type Settings struct {
// PeerLoginExpirationEnabled globally enables or disables peer login expiration
PeerLoginExpirationEnabled bool
// PeerLoginExpiration is a setting that indicates when peer login expires.
// Applies to all peers that have Peer.LoginExpirationEnabled set to true.
PeerLoginExpiration time.Duration
// PeerInactivityExpirationEnabled globally enables or disables peer inactivity expiration
PeerInactivityExpirationEnabled bool
// PeerInactivityExpiration is a setting that indicates when peer inactivity expires.
// Applies to all peers that have Peer.PeerInactivityExpirationEnabled set to true.
PeerInactivityExpiration time.Duration
// RegularUsersViewBlocked allows to block regular users from viewing even their own peers and some UI elements
RegularUsersViewBlocked bool
// GroupsPropagationEnabled allows to propagate auto groups from the user to the peer
GroupsPropagationEnabled bool
// JWTGroupsEnabled allows extract groups from JWT claim, which name defined in the JWTGroupsClaimName
// and add it to account groups.
JWTGroupsEnabled bool
// JWTGroupsClaimName from which we extract groups name to add it to account groups
JWTGroupsClaimName string
// JWTAllowGroups list of groups to which users are allowed access
JWTAllowGroups []string `gorm:"serializer:json"`
// Extra is a dictionary of Account settings
Extra *account.ExtraSettings `gorm:"embedded;embeddedPrefix:extra_"`
}
// Copy copies the Settings struct
func (s *Settings) Copy() *Settings {
settings := &Settings{
PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled,
PeerLoginExpiration: s.PeerLoginExpiration,
JWTGroupsEnabled: s.JWTGroupsEnabled,
JWTGroupsClaimName: s.JWTGroupsClaimName,
GroupsPropagationEnabled: s.GroupsPropagationEnabled,
JWTAllowGroups: s.JWTAllowGroups,
RegularUsersViewBlocked: s.RegularUsersViewBlocked,
PeerInactivityExpirationEnabled: s.PeerInactivityExpirationEnabled,
PeerInactivityExpiration: s.PeerInactivityExpiration,
}
if s.Extra != nil {
settings.Extra = s.Extra.Copy()
}
return settings
}

View File

@ -0,0 +1,181 @@
package types
import (
"crypto/sha256"
b64 "encoding/base64"
"hash/fnv"
"strconv"
"strings"
"time"
"unicode/utf8"
"github.com/google/uuid"
)
const (
// SetupKeyReusable is a multi-use key (can be used for multiple machines)
SetupKeyReusable SetupKeyType = "reusable"
// SetupKeyOneOff is a single use key (can be used only once)
SetupKeyOneOff SetupKeyType = "one-off"
// DefaultSetupKeyDuration = 1 month
DefaultSetupKeyDuration = 24 * 30 * time.Hour
// DefaultSetupKeyName is a default name of the default setup key
DefaultSetupKeyName = "Default key"
// SetupKeyUnlimitedUsage indicates an unlimited usage of a setup key
SetupKeyUnlimitedUsage = 0
)
// SetupKeyType is the type of setup key
type SetupKeyType string
// SetupKey represents a pre-authorized key used to register machines (peers)
type SetupKey struct {
Id string
// AccountID is a reference to Account that this object belongs
AccountID string `json:"-" gorm:"index"`
Key string
KeySecret string
Name string
Type SetupKeyType
CreatedAt time.Time
ExpiresAt time.Time
UpdatedAt time.Time `gorm:"autoUpdateTime:false"`
// Revoked indicates whether the key was revoked or not (we don't remove them for tracking purposes)
Revoked bool
// UsedTimes indicates how many times the key was used
UsedTimes int
// LastUsed last time the key was used for peer registration
LastUsed time.Time
// AutoGroups is a list of Group IDs that are auto assigned to a Peer when it uses this key to register
AutoGroups []string `gorm:"serializer:json"`
// UsageLimit indicates the number of times this key can be used to enroll a machine.
// The value of 0 indicates the unlimited usage.
UsageLimit int
// Ephemeral indicate if the peers will be ephemeral or not
Ephemeral bool
}
// Copy copies SetupKey to a new object
func (key *SetupKey) Copy() *SetupKey {
autoGroups := make([]string, len(key.AutoGroups))
copy(autoGroups, key.AutoGroups)
if key.UpdatedAt.IsZero() {
key.UpdatedAt = key.CreatedAt
}
return &SetupKey{
Id: key.Id,
AccountID: key.AccountID,
Key: key.Key,
KeySecret: key.KeySecret,
Name: key.Name,
Type: key.Type,
CreatedAt: key.CreatedAt,
ExpiresAt: key.ExpiresAt,
UpdatedAt: key.UpdatedAt,
Revoked: key.Revoked,
UsedTimes: key.UsedTimes,
LastUsed: key.LastUsed,
AutoGroups: autoGroups,
UsageLimit: key.UsageLimit,
Ephemeral: key.Ephemeral,
}
}
// EventMeta returns activity event meta related to the setup key
func (key *SetupKey) EventMeta() map[string]any {
return map[string]any{"name": key.Name, "type": key.Type, "key": key.KeySecret}
}
// HiddenKey returns the Key value hidden with "*" and a 5 character prefix.
// E.g., "831F6*******************************"
func HiddenKey(key string, length int) string {
prefix := key[0:5]
if length > utf8.RuneCountInString(key) {
length = utf8.RuneCountInString(key) - len(prefix)
}
return prefix + strings.Repeat("*", length)
}
// IncrementUsage makes a copy of a key, increments the UsedTimes by 1 and sets LastUsed to now
func (key *SetupKey) IncrementUsage() *SetupKey {
c := key.Copy()
c.UsedTimes++
c.LastUsed = time.Now().UTC()
return c
}
// IsValid is true if the key was not revoked, is not expired and used not more than it was supposed to
func (key *SetupKey) IsValid() bool {
return !key.IsRevoked() && !key.IsExpired() && !key.IsOverUsed()
}
// IsRevoked if key was revoked
func (key *SetupKey) IsRevoked() bool {
return key.Revoked
}
// IsExpired if key was expired
func (key *SetupKey) IsExpired() bool {
if key.ExpiresAt.IsZero() {
return false
}
return time.Now().After(key.ExpiresAt)
}
// IsOverUsed if the key was used too many times. SetupKey.UsageLimit == 0 indicates the unlimited usage.
func (key *SetupKey) IsOverUsed() bool {
limit := key.UsageLimit
if key.Type == SetupKeyOneOff {
limit = 1
}
return limit > 0 && key.UsedTimes >= limit
}
// GenerateSetupKey generates a new setup key
func GenerateSetupKey(name string, t SetupKeyType, validFor time.Duration, autoGroups []string,
usageLimit int, ephemeral bool) (*SetupKey, string) {
key := strings.ToUpper(uuid.New().String())
limit := usageLimit
if t == SetupKeyOneOff {
limit = 1
}
expiresAt := time.Time{}
if validFor != 0 {
expiresAt = time.Now().UTC().Add(validFor)
}
hashedKey := sha256.Sum256([]byte(key))
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
return &SetupKey{
Id: strconv.Itoa(int(Hash(key))),
Key: encodedHashedKey,
KeySecret: HiddenKey(key, 4),
Name: name,
Type: t,
CreatedAt: time.Now().UTC(),
ExpiresAt: expiresAt,
UpdatedAt: time.Now().UTC(),
Revoked: false,
UsedTimes: 0,
AutoGroups: autoGroups,
UsageLimit: limit,
Ephemeral: ephemeral,
}, key
}
// GenerateDefaultSetupKey generates a default reusable setup key with an unlimited usage and 30 days expiration
func GenerateDefaultSetupKey() (*SetupKey, string) {
return GenerateSetupKey(DefaultSetupKeyName, SetupKeyReusable, DefaultSetupKeyDuration, []string{},
SetupKeyUnlimitedUsage, false)
}
func Hash(s string) uint32 {
h := fnv.New32a()
_, err := h.Write([]byte(s))
if err != nil {
panic(err)
}
return h.Sum32()
}

View File

@ -0,0 +1,231 @@
package types
import (
"fmt"
"strings"
"time"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/integration_reference"
)
const (
UserRoleOwner UserRole = "owner"
UserRoleAdmin UserRole = "admin"
UserRoleUser UserRole = "user"
UserRoleUnknown UserRole = "unknown"
UserRoleBillingAdmin UserRole = "billing_admin"
UserStatusActive UserStatus = "active"
UserStatusDisabled UserStatus = "disabled"
UserStatusInvited UserStatus = "invited"
UserIssuedAPI = "api"
UserIssuedIntegration = "integration"
)
// StrRoleToUserRole returns UserRole for a given strRole or UserRoleUnknown if the specified role is unknown
func StrRoleToUserRole(strRole string) UserRole {
switch strings.ToLower(strRole) {
case "owner":
return UserRoleOwner
case "admin":
return UserRoleAdmin
case "user":
return UserRoleUser
case "billing_admin":
return UserRoleBillingAdmin
default:
return UserRoleUnknown
}
}
// UserStatus is the status of a User
type UserStatus string
// UserRole is the role of a User
type UserRole string
type UserInfo struct {
ID string `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
Role string `json:"role"`
AutoGroups []string `json:"auto_groups"`
Status string `json:"-"`
IsServiceUser bool `json:"is_service_user"`
IsBlocked bool `json:"is_blocked"`
NonDeletable bool `json:"non_deletable"`
LastLogin time.Time `json:"last_login"`
Issued string `json:"issued"`
IntegrationReference integration_reference.IntegrationReference `json:"-"`
Permissions UserPermissions `json:"permissions"`
}
type UserPermissions struct {
DashboardView string `json:"dashboard_view"`
}
// User represents a user of the system
type User struct {
Id string `gorm:"primaryKey"`
// AccountID is a reference to Account that this object belongs
AccountID string `json:"-" gorm:"index"`
Role UserRole
IsServiceUser bool
// NonDeletable indicates whether the service user can be deleted
NonDeletable bool
// ServiceUserName is only set if IsServiceUser is true
ServiceUserName string
// AutoGroups is a list of Group IDs to auto-assign to peers registered by this user
AutoGroups []string `gorm:"serializer:json"`
PATs map[string]*PersonalAccessToken `gorm:"-"`
PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id"`
// Blocked indicates whether the user is blocked. Blocked users can't use the system.
Blocked bool
// LastLogin is the last time the user logged in to IdP
LastLogin time.Time
// CreatedAt records the time the user was created
CreatedAt time.Time
// Issued of the user
Issued string `gorm:"default:api"`
IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"`
}
// IsBlocked returns true if the user is blocked, false otherwise
func (u *User) IsBlocked() bool {
return u.Blocked
}
func (u *User) LastDashboardLoginChanged(LastLogin time.Time) bool {
return LastLogin.After(u.LastLogin) && !u.LastLogin.IsZero()
}
// HasAdminPower returns true if the user has admin or owner roles, false otherwise
func (u *User) HasAdminPower() bool {
return u.Role == UserRoleAdmin || u.Role == UserRoleOwner
}
// IsAdminOrServiceUser checks if the user has admin power or is a service user.
func (u *User) IsAdminOrServiceUser() bool {
return u.HasAdminPower() || u.IsServiceUser
}
// IsRegularUser checks if the user is a regular user.
func (u *User) IsRegularUser() bool {
return !u.HasAdminPower() && !u.IsServiceUser
}
// ToUserInfo converts a User object to a UserInfo object.
func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) {
autoGroups := u.AutoGroups
if autoGroups == nil {
autoGroups = []string{}
}
dashboardViewPermissions := "full"
if !u.HasAdminPower() {
dashboardViewPermissions = "limited"
if settings.RegularUsersViewBlocked {
dashboardViewPermissions = "blocked"
}
}
if userData == nil {
return &UserInfo{
ID: u.Id,
Email: "",
Name: u.ServiceUserName,
Role: string(u.Role),
AutoGroups: u.AutoGroups,
Status: string(UserStatusActive),
IsServiceUser: u.IsServiceUser,
IsBlocked: u.Blocked,
LastLogin: u.LastLogin,
Issued: u.Issued,
Permissions: UserPermissions{
DashboardView: dashboardViewPermissions,
},
}, nil
}
if userData.ID != u.Id {
return nil, fmt.Errorf("wrong UserData provided for user %s", u.Id)
}
userStatus := UserStatusActive
if userData.AppMetadata.WTPendingInvite != nil && *userData.AppMetadata.WTPendingInvite {
userStatus = UserStatusInvited
}
return &UserInfo{
ID: u.Id,
Email: userData.Email,
Name: userData.Name,
Role: string(u.Role),
AutoGroups: autoGroups,
Status: string(userStatus),
IsServiceUser: u.IsServiceUser,
IsBlocked: u.Blocked,
LastLogin: u.LastLogin,
Issued: u.Issued,
Permissions: UserPermissions{
DashboardView: dashboardViewPermissions,
},
}, nil
}
// Copy the user
func (u *User) Copy() *User {
autoGroups := make([]string, len(u.AutoGroups))
copy(autoGroups, u.AutoGroups)
pats := make(map[string]*PersonalAccessToken, len(u.PATs))
for k, v := range u.PATs {
pats[k] = v.Copy()
}
return &User{
Id: u.Id,
AccountID: u.AccountID,
Role: u.Role,
AutoGroups: autoGroups,
IsServiceUser: u.IsServiceUser,
NonDeletable: u.NonDeletable,
ServiceUserName: u.ServiceUserName,
PATs: pats,
Blocked: u.Blocked,
LastLogin: u.LastLogin,
CreatedAt: u.CreatedAt,
Issued: u.Issued,
IntegrationReference: u.IntegrationReference,
}
}
// NewUser creates a new user
func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string) *User {
return &User{
Id: id,
Role: role,
IsServiceUser: isServiceUser,
NonDeletable: nonDeletable,
ServiceUserName: serviceUserName,
AutoGroups: autoGroups,
Issued: issued,
CreatedAt: time.Now().UTC(),
}
}
// NewRegularUser creates a new user with role UserRoleUser
func NewRegularUser(id string) *User {
return NewUser(id, UserRoleUser, false, false, "", []string{}, UserIssuedAPI)
}
// NewAdminUser creates a new user with role UserRoleAdmin
func NewAdminUser(id string) *User {
return NewUser(id, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI)
}
// NewOwnerUser creates a new user with role UserRoleOwner
func NewOwnerUser(id string) *User {
return NewUser(id, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI)
}

View File

@ -9,13 +9,14 @@ import (
"github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
)
const channelBufferSize = 100
type UpdateMessage struct {
Update *proto.SyncResponse
NetworkMap *NetworkMap
NetworkMap *types.NetworkMap
}
type PeersUpdateManager struct {

View File

@ -15,215 +15,16 @@ import (
nbContext "github.com/netbirdio/netbird/management/server/context"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/integration_reference"
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
)
const (
UserRoleOwner UserRole = "owner"
UserRoleAdmin UserRole = "admin"
UserRoleUser UserRole = "user"
UserRoleUnknown UserRole = "unknown"
UserRoleBillingAdmin UserRole = "billing_admin"
UserStatusActive UserStatus = "active"
UserStatusDisabled UserStatus = "disabled"
UserStatusInvited UserStatus = "invited"
UserIssuedAPI = "api"
UserIssuedIntegration = "integration"
)
// StrRoleToUserRole returns UserRole for a given strRole or UserRoleUnknown if the specified role is unknown
func StrRoleToUserRole(strRole string) UserRole {
switch strings.ToLower(strRole) {
case "owner":
return UserRoleOwner
case "admin":
return UserRoleAdmin
case "user":
return UserRoleUser
case "billing_admin":
return UserRoleBillingAdmin
default:
return UserRoleUnknown
}
}
// UserStatus is the status of a User
type UserStatus string
// UserRole is the role of a User
type UserRole string
// User represents a user of the system
type User struct {
Id string `gorm:"primaryKey"`
// AccountID is a reference to Account that this object belongs
AccountID string `json:"-" gorm:"index"`
Role UserRole
IsServiceUser bool
// NonDeletable indicates whether the service user can be deleted
NonDeletable bool
// ServiceUserName is only set if IsServiceUser is true
ServiceUserName string
// AutoGroups is a list of Group IDs to auto-assign to peers registered by this user
AutoGroups []string `gorm:"serializer:json"`
PATs map[string]*PersonalAccessToken `gorm:"-"`
PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id"`
// Blocked indicates whether the user is blocked. Blocked users can't use the system.
Blocked bool
// LastLogin is the last time the user logged in to IdP
LastLogin time.Time
// CreatedAt records the time the user was created
CreatedAt time.Time
// Issued of the user
Issued string `gorm:"default:api"`
IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"`
}
// IsBlocked returns true if the user is blocked, false otherwise
func (u *User) IsBlocked() bool {
return u.Blocked
}
func (u *User) LastDashboardLoginChanged(LastLogin time.Time) bool {
return LastLogin.After(u.LastLogin) && !u.LastLogin.IsZero()
}
// HasAdminPower returns true if the user has admin or owner roles, false otherwise
func (u *User) HasAdminPower() bool {
return u.Role == UserRoleAdmin || u.Role == UserRoleOwner
}
// IsAdminOrServiceUser checks if the user has admin power or is a service user.
func (u *User) IsAdminOrServiceUser() bool {
return u.HasAdminPower() || u.IsServiceUser
}
// IsRegularUser checks if the user is a regular user.
func (u *User) IsRegularUser() bool {
return !u.HasAdminPower() && !u.IsServiceUser
}
// ToUserInfo converts a User object to a UserInfo object.
func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) {
autoGroups := u.AutoGroups
if autoGroups == nil {
autoGroups = []string{}
}
dashboardViewPermissions := "full"
if !u.HasAdminPower() {
dashboardViewPermissions = "limited"
if settings.RegularUsersViewBlocked {
dashboardViewPermissions = "blocked"
}
}
if userData == nil {
return &UserInfo{
ID: u.Id,
Email: "",
Name: u.ServiceUserName,
Role: string(u.Role),
AutoGroups: u.AutoGroups,
Status: string(UserStatusActive),
IsServiceUser: u.IsServiceUser,
IsBlocked: u.Blocked,
LastLogin: u.LastLogin,
Issued: u.Issued,
Permissions: UserPermissions{
DashboardView: dashboardViewPermissions,
},
}, nil
}
if userData.ID != u.Id {
return nil, fmt.Errorf("wrong UserData provided for user %s", u.Id)
}
userStatus := UserStatusActive
if userData.AppMetadata.WTPendingInvite != nil && *userData.AppMetadata.WTPendingInvite {
userStatus = UserStatusInvited
}
return &UserInfo{
ID: u.Id,
Email: userData.Email,
Name: userData.Name,
Role: string(u.Role),
AutoGroups: autoGroups,
Status: string(userStatus),
IsServiceUser: u.IsServiceUser,
IsBlocked: u.Blocked,
LastLogin: u.LastLogin,
Issued: u.Issued,
Permissions: UserPermissions{
DashboardView: dashboardViewPermissions,
},
}, nil
}
// Copy the user
func (u *User) Copy() *User {
autoGroups := make([]string, len(u.AutoGroups))
copy(autoGroups, u.AutoGroups)
pats := make(map[string]*PersonalAccessToken, len(u.PATs))
for k, v := range u.PATs {
pats[k] = v.Copy()
}
return &User{
Id: u.Id,
AccountID: u.AccountID,
Role: u.Role,
AutoGroups: autoGroups,
IsServiceUser: u.IsServiceUser,
NonDeletable: u.NonDeletable,
ServiceUserName: u.ServiceUserName,
PATs: pats,
Blocked: u.Blocked,
LastLogin: u.LastLogin,
CreatedAt: u.CreatedAt,
Issued: u.Issued,
IntegrationReference: u.IntegrationReference,
}
}
// NewUser creates a new user
func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string) *User {
return &User{
Id: id,
Role: role,
IsServiceUser: isServiceUser,
NonDeletable: nonDeletable,
ServiceUserName: serviceUserName,
AutoGroups: autoGroups,
Issued: issued,
CreatedAt: time.Now().UTC(),
}
}
// NewRegularUser creates a new user with role UserRoleUser
func NewRegularUser(id string) *User {
return NewUser(id, UserRoleUser, false, false, "", []string{}, UserIssuedAPI)
}
// NewAdminUser creates a new user with role UserRoleAdmin
func NewAdminUser(id string) *User {
return NewUser(id, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI)
}
// NewOwnerUser creates a new user with role UserRoleOwner
func NewOwnerUser(id string) *User {
return NewUser(id, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI)
}
// createServiceUser creates a new service user under the given account.
func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountID string, initiatorUserID string, role UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) {
func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountID string, initiatorUserID string, role types.UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*types.UserInfo, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
@ -240,12 +41,12 @@ func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountI
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can create service users")
}
if role == UserRoleOwner {
if role == types.UserRoleOwner {
return nil, status.Errorf(status.InvalidArgument, "can't create a service user with owner role")
}
newUserID := uuid.New().String()
newUser := NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, UserIssuedAPI)
newUser := types.NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, types.UserIssuedAPI)
log.WithContext(ctx).Debugf("New User: %v", newUser)
account.Users[newUserID] = newUser
@ -257,29 +58,29 @@ func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountI
meta := map[string]any{"name": newUser.ServiceUserName}
am.StoreEvent(ctx, initiatorUserID, newUser.Id, accountID, activity.ServiceUserCreated, meta)
return &UserInfo{
return &types.UserInfo{
ID: newUser.Id,
Email: "",
Name: newUser.ServiceUserName,
Role: string(newUser.Role),
AutoGroups: newUser.AutoGroups,
Status: string(UserStatusActive),
Status: string(types.UserStatusActive),
IsServiceUser: true,
LastLogin: time.Time{},
Issued: UserIssuedAPI,
Issued: types.UserIssuedAPI,
}, nil
}
// CreateUser creates a new user under the given account. Effectively this is a user invite.
func (am *DefaultAccountManager) CreateUser(ctx context.Context, accountID, userID string, user *UserInfo) (*UserInfo, error) {
func (am *DefaultAccountManager) CreateUser(ctx context.Context, accountID, userID string, user *types.UserInfo) (*types.UserInfo, error) {
if user.IsServiceUser {
return am.createServiceUser(ctx, accountID, userID, StrRoleToUserRole(user.Role), user.Name, user.NonDeletable, user.AutoGroups)
return am.createServiceUser(ctx, accountID, userID, types.StrRoleToUserRole(user.Role), user.Name, user.NonDeletable, user.AutoGroups)
}
return am.inviteNewUser(ctx, accountID, userID, user)
}
// inviteNewUser Invites a USer to a given account and creates reference in datastore
func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, userID string, invite *UserInfo) (*UserInfo, error) {
func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, userID string, invite *types.UserInfo) (*types.UserInfo, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
@ -291,14 +92,14 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
return nil, fmt.Errorf("provided user update is nil")
}
invitedRole := StrRoleToUserRole(invite.Role)
invitedRole := types.StrRoleToUserRole(invite.Role)
switch {
case invite.Name == "":
return nil, status.Errorf(status.InvalidArgument, "name can't be empty")
case invite.Email == "":
return nil, status.Errorf(status.InvalidArgument, "email can't be empty")
case invitedRole == UserRoleOwner:
case invitedRole == types.UserRoleOwner:
return nil, status.Errorf(status.InvalidArgument, "can't invite a user with owner role")
default:
}
@ -348,7 +149,7 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
return nil, err
}
newUser := &User{
newUser := &types.User{
Id: idpUser.ID,
Role: invitedRole,
AutoGroups: invite.AutoGroups,
@ -373,19 +174,19 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
return newUser.ToUserInfo(idpUser, account.Settings)
}
func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*User, error) {
return am.Store.GetUserByUserID(ctx, LockingStrengthShare, id)
func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) {
return am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, id)
}
// GetUser looks up a user by provided authorization claims.
// It will also create an account if didn't exist for this user before.
func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) {
func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error) {
accountID, userID, err := am.GetAccountIDFromToken(ctx, claims)
if err != nil {
return nil, fmt.Errorf("failed to get account with token claims %v", err)
}
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
if err != nil {
return nil, err
}
@ -409,7 +210,7 @@ func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.A
// ListUsers returns lists of all users under the account.
// It doesn't populate user information such as email or name.
func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string) ([]*User, error) {
func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
@ -418,7 +219,7 @@ func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string
return nil, err
}
users := make([]*User, 0, len(account.Users))
users := make([]*types.User, 0, len(account.Users))
for _, item := range account.Users {
users = append(users, item)
}
@ -426,7 +227,7 @@ func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string
return users, nil
}
func (am *DefaultAccountManager) deleteServiceUser(ctx context.Context, account *Account, initiatorUserID string, targetUser *User) {
func (am *DefaultAccountManager) deleteServiceUser(ctx context.Context, account *types.Account, initiatorUserID string, targetUser *types.User) {
meta := map[string]any{"name": targetUser.ServiceUserName, "created_at": targetUser.CreatedAt}
am.StoreEvent(ctx, initiatorUserID, targetUser.Id, account.Id, activity.ServiceUserDeleted, meta)
delete(account.Users, targetUser.Id)
@ -458,12 +259,12 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
return status.Errorf(status.NotFound, "target user not found")
}
if targetUser.Role == UserRoleOwner {
if targetUser.Role == types.UserRoleOwner {
return status.Errorf(status.PermissionDenied, "unable to delete a user with owner role")
}
// disable deleting integration user if the initiator is not admin service user
if targetUser.Issued == UserIssuedIntegration && !executingUser.IsServiceUser {
if targetUser.Issued == types.UserIssuedIntegration && !executingUser.IsServiceUser {
return status.Errorf(status.PermissionDenied, "only integration service user can delete this user")
}
@ -480,7 +281,7 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
return am.deleteRegularUser(ctx, account, initiatorUserID, targetUserID)
}
func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account *Account, initiatorUserID, targetUserID string) error {
func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account *types.Account, initiatorUserID, targetUserID string) error {
meta, updateAccountPeers, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID)
if err != nil {
return err
@ -500,7 +301,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account
return nil
}
func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorUserID string, targetUserID string, account *Account) (bool, error) {
func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorUserID string, targetUserID string, account *types.Account) (bool, error) {
peers, err := account.FindUserPeers(targetUserID)
if err != nil {
return false, status.Errorf(status.Internal, "failed to find user peers")
@ -560,7 +361,7 @@ func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID strin
}
// CreatePAT creates a new PAT for the given user
func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) {
func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
@ -591,7 +392,7 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string
return nil, status.Errorf(status.PermissionDenied, "no permission to create PAT for this user")
}
pat, err := CreateNewPAT(tokenName, expiresIn, executingUser.Id)
pat, err := types.CreateNewPAT(tokenName, expiresIn, executingUser.Id)
if err != nil {
return nil, status.Errorf(status.Internal, "failed to create PAT: %v", err)
}
@ -660,13 +461,13 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string
}
// GetPAT returns a specific PAT from a user
func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) {
initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID)
func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error) {
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
if err != nil {
return nil, err
}
targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID)
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID)
if err != nil {
return nil, err
}
@ -685,13 +486,13 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i
}
// GetAllPATs returns all PATs for a user
func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) {
initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID)
func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*types.PersonalAccessToken, error) {
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID)
if err != nil {
return nil, err
}
targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID)
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID)
if err != nil {
return nil, err
}
@ -700,7 +501,7 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
}
pats := make([]*PersonalAccessToken, 0, len(targetUser.PATsG))
pats := make([]*types.PersonalAccessToken, 0, len(targetUser.PATsG))
for _, pat := range targetUser.PATsG {
pats = append(pats, pat.Copy())
}
@ -709,13 +510,13 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin
}
// SaveUser saves updates to the given user. If the user doesn't exist, it will throw status.NotFound error.
func (am *DefaultAccountManager) SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error) {
func (am *DefaultAccountManager) SaveUser(ctx context.Context, accountID, initiatorUserID string, update *types.User) (*types.UserInfo, error) {
return am.SaveOrAddUser(ctx, accountID, initiatorUserID, update, false) // false means do not create user and throw status.NotFound
}
// SaveOrAddUser updates the given user. If addIfNotExists is set to true it will add user when no exist
// Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now.
func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) {
func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *types.User, addIfNotExists bool) (*types.UserInfo, error) {
if update == nil {
return nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
}
@ -723,7 +524,7 @@ func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, i
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
updatedUsers, err := am.SaveOrAddUsers(ctx, accountID, initiatorUserID, []*User{update}, addIfNotExists)
updatedUsers, err := am.SaveOrAddUsers(ctx, accountID, initiatorUserID, []*types.User{update}, addIfNotExists)
if err != nil {
return nil, err
}
@ -738,7 +539,7 @@ func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, i
// SaveOrAddUsers updates existing users or adds new users to the account.
// Note: This function does not acquire the global lock.
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error) {
func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*types.User, addIfNotExists bool) ([]*types.UserInfo, error) {
if len(updates) == 0 {
return nil, nil //nolint:nilnil
}
@ -757,7 +558,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are authorized to perform user update operations")
}
updatedUsers := make([]*UserInfo, 0, len(updates))
updatedUsers := make([]*types.UserInfo, 0, len(updates))
var (
expiredPeers []*nbpeer.Peer
userIDs []string
@ -808,7 +609,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
peerGroupsAdded := make(map[string][]string)
peerGroupsRemoved := make(map[string][]string)
if update.AutoGroups != nil && account.Settings.GroupsPropagationEnabled {
removedGroups := difference(oldUser.AutoGroups, update.AutoGroups)
removedGroups := util.Difference(oldUser.AutoGroups, update.AutoGroups)
// need force update all auto groups in any case they will not be duplicated
peerGroupsAdded = account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...)
peerGroupsRemoved = account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...)
@ -851,7 +652,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
}
// prepareUserUpdateEvents prepares a list user update events based on the changes between the old and new user data.
func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, transferredOwnerRole bool) []func() {
func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *types.User, account *types.Account, transferredOwnerRole bool) []func() {
var eventsToStore []func()
if oldUser.IsBlocked() != newUser.IsBlocked() {
@ -880,11 +681,11 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, in
return eventsToStore
}
func (am *DefaultAccountManager) prepareUserGroupsEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, peerGroupsAdded, peerGroupsRemoved map[string][]string) []func() {
func (am *DefaultAccountManager) prepareUserGroupsEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *types.User, account *types.Account, peerGroupsAdded, peerGroupsRemoved map[string][]string) []func() {
var eventsToStore []func()
if newUser.AutoGroups != nil {
removedGroups := difference(oldUser.AutoGroups, newUser.AutoGroups)
addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups)
removedGroups := util.Difference(oldUser.AutoGroups, newUser.AutoGroups)
addedGroups := util.Difference(newUser.AutoGroups, oldUser.AutoGroups)
removedEvents := am.handleGroupRemovedFromUser(ctx, initiatorUserID, oldUser, newUser, account, removedGroups, peerGroupsRemoved)
eventsToStore = append(eventsToStore, removedEvents...)
@ -895,7 +696,7 @@ func (am *DefaultAccountManager) prepareUserGroupsEvents(ctx context.Context, in
return eventsToStore
}
func (am *DefaultAccountManager) handleGroupAddedToUser(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, addedGroups []string, peerGroupsAdded map[string][]string) []func() {
func (am *DefaultAccountManager) handleGroupAddedToUser(ctx context.Context, initiatorUserID string, oldUser, newUser *types.User, account *types.Account, addedGroups []string, peerGroupsAdded map[string][]string) []func() {
var eventsToStore []func()
for _, g := range addedGroups {
group := account.GetGroup(g)
@ -922,7 +723,7 @@ func (am *DefaultAccountManager) handleGroupAddedToUser(ctx context.Context, ini
return eventsToStore
}
func (am *DefaultAccountManager) handleGroupRemovedFromUser(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, removedGroups []string, peerGroupsRemoved map[string][]string) []func() {
func (am *DefaultAccountManager) handleGroupRemovedFromUser(ctx context.Context, initiatorUserID string, oldUser, newUser *types.User, account *types.Account, removedGroups []string, peerGroupsRemoved map[string][]string) []func() {
var eventsToStore []func()
for _, g := range removedGroups {
group := account.GetGroup(g)
@ -952,10 +753,10 @@ func (am *DefaultAccountManager) handleGroupRemovedFromUser(ctx context.Context,
return eventsToStore
}
func handleOwnerRoleTransfer(account *Account, initiatorUser, update *User) bool {
if initiatorUser.Role == UserRoleOwner && initiatorUser.Id != update.Id && update.Role == UserRoleOwner {
func handleOwnerRoleTransfer(account *types.Account, initiatorUser, update *types.User) bool {
if initiatorUser.Role == types.UserRoleOwner && initiatorUser.Id != update.Id && update.Role == types.UserRoleOwner {
newInitiatorUser := initiatorUser.Copy()
newInitiatorUser.Role = UserRoleAdmin
newInitiatorUser.Role = types.UserRoleAdmin
account.Users[initiatorUser.Id] = newInitiatorUser
return true
}
@ -965,7 +766,7 @@ func handleOwnerRoleTransfer(account *Account, initiatorUser, update *User) bool
// getUserInfo retrieves the UserInfo for a given User and Account.
// If the AccountManager has a non-nil idpManager and the User is not a service user,
// it will attempt to look up the UserData from the cache.
func getUserInfo(ctx context.Context, am *DefaultAccountManager, user *User, account *Account) (*UserInfo, error) {
func getUserInfo(ctx context.Context, am *DefaultAccountManager, user *types.User, account *types.Account) (*types.UserInfo, error) {
if !isNil(am.idpManager) && !user.IsServiceUser {
userData, err := am.lookupUserInCache(ctx, user.Id, account)
if err != nil {
@ -977,23 +778,23 @@ func getUserInfo(ctx context.Context, am *DefaultAccountManager, user *User, acc
}
// validateUserUpdate validates the update operation for a user.
func validateUserUpdate(account *Account, initiatorUser, oldUser, update *User) error {
func validateUserUpdate(account *types.Account, initiatorUser, oldUser, update *types.User) error {
if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && oldUser.Blocked != update.Blocked {
return status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves")
}
if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && update.Role != initiatorUser.Role {
return status.Errorf(status.PermissionDenied, "admins can't change their role")
}
if initiatorUser.Role == UserRoleAdmin && oldUser.Role == UserRoleOwner && update.Role != oldUser.Role {
if initiatorUser.Role == types.UserRoleAdmin && oldUser.Role == types.UserRoleOwner && update.Role != oldUser.Role {
return status.Errorf(status.PermissionDenied, "only owners can remove owner role from their user")
}
if initiatorUser.Role == UserRoleAdmin && oldUser.Role == UserRoleOwner && update.IsBlocked() && !oldUser.IsBlocked() {
if initiatorUser.Role == types.UserRoleAdmin && oldUser.Role == types.UserRoleOwner && update.IsBlocked() && !oldUser.IsBlocked() {
return status.Errorf(status.PermissionDenied, "unable to block owner user")
}
if initiatorUser.Role == UserRoleAdmin && update.Role == UserRoleOwner && update.Role != oldUser.Role {
if initiatorUser.Role == types.UserRoleAdmin && update.Role == types.UserRoleOwner && update.Role != oldUser.Role {
return status.Errorf(status.PermissionDenied, "only owners can add owner role to other users")
}
if oldUser.IsServiceUser && update.Role == UserRoleOwner {
if oldUser.IsServiceUser && update.Role == types.UserRoleOwner {
return status.Errorf(status.PermissionDenied, "can't update a service user with owner role")
}
@ -1012,7 +813,7 @@ func validateUserUpdate(account *Account, initiatorUser, oldUser, update *User)
}
// GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist
func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, userID, domain string) (*Account, error) {
func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, userID, domain string) (*types.Account, error) {
start := time.Now()
unlock := am.Store.AcquireGlobalLock(ctx)
defer unlock()
@ -1039,7 +840,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, u
userObj := account.Users[userID]
if lowerDomain != "" && account.Domain != lowerDomain && userObj.Role == UserRoleOwner {
if lowerDomain != "" && account.Domain != lowerDomain && userObj.Role == types.UserRoleOwner {
account.Domain = lowerDomain
err = am.Store.SaveAccount(ctx, account)
if err != nil {
@ -1052,7 +853,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, u
// GetUsersFromAccount performs a batched request for users from IDP by account ID apply filter on what data to return
// based on provided user role.
func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error) {
func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*types.UserInfo, error) {
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
@ -1068,7 +869,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
users := make(map[string]userLoggedInOnce, len(account.Users))
usersFromIntegration := make([]*idp.UserData, 0)
for _, user := range account.Users {
if user.Issued == UserIssuedIntegration {
if user.Issued == types.UserIssuedIntegration {
key := user.IntegrationReference.CacheKey(accountID, user.Id)
info, err := am.externalCacheManager.Get(am.ctx, key)
if err != nil {
@ -1092,7 +893,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
queriedUsers = append(queriedUsers, usersFromIntegration...)
}
userInfos := make([]*UserInfo, 0)
userInfos := make([]*types.UserInfo, 0)
// in case of self-hosted, or IDP doesn't return anything, we will return the locally stored userInfo
if len(queriedUsers) == 0 {
@ -1116,7 +917,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
continue
}
var info *UserInfo
var info *types.UserInfo
if queriedUser, contains := findUserInIDPUserdata(localUser.Id, queriedUsers); contains {
info, err = localUser.ToUserInfo(queriedUser, account.Settings)
if err != nil {
@ -1136,16 +937,16 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
}
}
info = &UserInfo{
info = &types.UserInfo{
ID: localUser.Id,
Email: "",
Name: name,
Role: string(localUser.Role),
AutoGroups: localUser.AutoGroups,
Status: string(UserStatusActive),
Status: string(types.UserStatusActive),
IsServiceUser: localUser.IsServiceUser,
NonDeletable: localUser.NonDeletable,
Permissions: UserPermissions{DashboardView: dashboardViewPermissions},
Permissions: types.UserPermissions{DashboardView: dashboardViewPermissions},
}
}
userInfos = append(userInfos, info)
@ -1155,7 +956,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
}
// expireAndUpdatePeers expires all peers of the given user and updates them in the account
func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, account *Account, peers []*nbpeer.Peer) error {
func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, account *types.Account, peers []*nbpeer.Peer) error {
var peerIDs []string
for _, peer := range peers {
// nolint:staticcheck
@ -1260,13 +1061,13 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
continue
}
if targetUser.Role == UserRoleOwner {
if targetUser.Role == types.UserRoleOwner {
allErrors = errors.Join(allErrors, fmt.Errorf("unable to delete a user: %s with owner role", targetUserID))
continue
}
// disable deleting integration user if the initiator is not admin service user
if targetUser.Issued == UserIssuedIntegration && !executingUser.IsServiceUser {
if targetUser.Issued == types.UserIssuedIntegration && !executingUser.IsServiceUser {
allErrors = errors.Join(allErrors, errors.New("only integration service user can delete this user"))
continue
}
@ -1301,7 +1102,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
return allErrors
}
func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, account *Account, initiatorUserID, targetUserID string) (map[string]any, bool, error) {
func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, account *types.Account, initiatorUserID, targetUserID string) (map[string]any, bool, error) {
tuEmail, tuName, err := am.getEmailAndNameOfTargetUser(ctx, account.Id, initiatorUserID, targetUserID)
if err != nil {
log.WithContext(ctx).Errorf("failed to resolve email address: %s", err)
@ -1419,7 +1220,7 @@ func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserDa
}
// areUsersLinkedToPeers checks if any of the given userIDs are linked to any of the peers in the account.
func areUsersLinkedToPeers(account *Account, userIDs []string) bool {
func areUsersLinkedToPeers(account *types.Account, userIDs []string) bool {
for _, peer := range account.Peers {
if slices.Contains(userIDs, peer.UserID) {
return true

View File

@ -10,8 +10,12 @@ import (
"github.com/eko/gocache/v3/cache"
cacheStore "github.com/eko/gocache/v3/store"
"github.com/google/go-cmp/cmp"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
gocache "github.com/patrickmn/go-cache"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -41,11 +45,15 @@ const (
)
func TestUser_CreatePAT_ForSameUser(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
if err != nil {
t.Fatalf("Error when creating store: %s", err)
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account)
err = store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
@ -82,14 +90,18 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
}
func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
if err != nil {
t.Fatalf("Error when creating store: %s", err)
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockTargetUserId] = &User{
account.Users[mockTargetUserId] = &types.User{
Id: mockTargetUserId,
IsServiceUser: false,
}
err := store.SaveAccount(context.Background(), account)
err = store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
@ -104,14 +116,18 @@ func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
}
func TestUser_CreatePAT_ForServiceUser(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
if err != nil {
t.Fatalf("Error when creating store: %s", err)
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockTargetUserId] = &User{
account.Users[mockTargetUserId] = &types.User{
Id: mockTargetUserId,
IsServiceUser: true,
}
err := store.SaveAccount(context.Background(), account)
err = store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
@ -130,11 +146,15 @@ func TestUser_CreatePAT_ForServiceUser(t *testing.T) {
}
func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
if err != nil {
t.Fatalf("Error when creating store: %s", err)
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account)
err = store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
@ -149,11 +169,15 @@ func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) {
}
func TestUser_CreatePAT_WithEmptyName(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
if err != nil {
t.Fatalf("Error when creating store: %s", err)
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account)
err = store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
@ -168,19 +192,23 @@ func TestUser_CreatePAT_WithEmptyName(t *testing.T) {
}
func TestUser_DeletePAT(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
if err != nil {
t.Fatalf("Error when creating store: %s", err)
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockUserID] = &User{
account.Users[mockUserID] = &types.User{
Id: mockUserID,
PATs: map[string]*PersonalAccessToken{
PATs: map[string]*types.PersonalAccessToken{
mockTokenID1: {
ID: mockTokenID1,
HashedToken: mockToken1,
},
},
}
err := store.SaveAccount(context.Background(), account)
err = store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
@ -204,20 +232,24 @@ func TestUser_DeletePAT(t *testing.T) {
}
func TestUser_GetPAT(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
if err != nil {
t.Fatalf("Error when creating store: %s", err)
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockUserID] = &User{
account.Users[mockUserID] = &types.User{
Id: mockUserID,
AccountID: mockAccountID,
PATs: map[string]*PersonalAccessToken{
PATs: map[string]*types.PersonalAccessToken{
mockTokenID1: {
ID: mockTokenID1,
HashedToken: mockToken1,
},
},
}
err := store.SaveAccount(context.Background(), account)
err = store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
@ -237,13 +269,17 @@ func TestUser_GetPAT(t *testing.T) {
}
func TestUser_GetAllPATs(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
if err != nil {
t.Fatalf("Error when creating store: %s", err)
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockUserID] = &User{
account.Users[mockUserID] = &types.User{
Id: mockUserID,
AccountID: mockAccountID,
PATs: map[string]*PersonalAccessToken{
PATs: map[string]*types.PersonalAccessToken{
mockTokenID1: {
ID: mockTokenID1,
HashedToken: mockToken1,
@ -254,7 +290,7 @@ func TestUser_GetAllPATs(t *testing.T) {
},
},
}
err := store.SaveAccount(context.Background(), account)
err = store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
@ -274,14 +310,14 @@ func TestUser_GetAllPATs(t *testing.T) {
func TestUser_Copy(t *testing.T) {
// this is an imaginary case which will never be in DB this way
user := User{
user := types.User{
Id: "userId",
AccountID: "accountId",
Role: "role",
IsServiceUser: true,
ServiceUserName: "servicename",
AutoGroups: []string{"group1", "group2"},
PATs: map[string]*PersonalAccessToken{
PATs: map[string]*types.PersonalAccessToken{
"pat1": {
ID: "pat1",
Name: "First PAT",
@ -340,11 +376,15 @@ func validateStruct(s interface{}) (err error) {
}
func TestUser_CreateServiceUser(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
if err != nil {
t.Fatalf("Error when creating store: %s", err)
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account)
err = store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
@ -366,26 +406,30 @@ func TestUser_CreateServiceUser(t *testing.T) {
assert.NotNil(t, account.Users[user.ID])
assert.True(t, account.Users[user.ID].IsServiceUser)
assert.Equal(t, mockServiceUserName, account.Users[user.ID].ServiceUserName)
assert.Equal(t, UserRole(mockRole), account.Users[user.ID].Role)
assert.Equal(t, types.UserRole(mockRole), account.Users[user.ID].Role)
assert.Equal(t, []string{"group1", "group2"}, account.Users[user.ID].AutoGroups)
assert.Equal(t, map[string]*PersonalAccessToken{}, account.Users[user.ID].PATs)
assert.Equal(t, map[string]*types.PersonalAccessToken{}, account.Users[user.ID].PATs)
assert.Zero(t, user.Email)
assert.True(t, user.IsServiceUser)
assert.Equal(t, "active", user.Status)
_, err = am.createServiceUser(context.Background(), mockAccountID, mockUserID, UserRoleOwner, mockServiceUserName, false, nil)
_, err = am.createServiceUser(context.Background(), mockAccountID, mockUserID, types.UserRoleOwner, mockServiceUserName, false, nil)
if err == nil {
t.Fatal("should return error when creating service user with owner role")
}
}
func TestUser_CreateUser_ServiceUser(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
if err != nil {
t.Fatalf("Error when creating store: %s", err)
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account)
err = store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
@ -395,7 +439,7 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) {
eventStore: &activity.InMemoryEventStore{},
}
user, err := am.CreateUser(context.Background(), mockAccountID, mockUserID, &UserInfo{
user, err := am.CreateUser(context.Background(), mockAccountID, mockUserID, &types.UserInfo{
Name: mockServiceUserName,
Role: mockRole,
IsServiceUser: true,
@ -413,7 +457,7 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) {
assert.Equal(t, 2, len(account.Users))
assert.True(t, account.Users[user.ID].IsServiceUser)
assert.Equal(t, mockServiceUserName, account.Users[user.ID].ServiceUserName)
assert.Equal(t, UserRole(mockRole), account.Users[user.ID].Role)
assert.Equal(t, types.UserRole(mockRole), account.Users[user.ID].Role)
assert.Equal(t, []string{"group1", "group2"}, account.Users[user.ID].AutoGroups)
assert.Equal(t, mockServiceUserName, user.Name)
@ -423,11 +467,15 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) {
}
func TestUser_CreateUser_RegularUser(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
if err != nil {
t.Fatalf("Error when creating store: %s", err)
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account)
err = store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
@ -437,7 +485,7 @@ func TestUser_CreateUser_RegularUser(t *testing.T) {
eventStore: &activity.InMemoryEventStore{},
}
_, err = am.CreateUser(context.Background(), mockAccountID, mockUserID, &UserInfo{
_, err = am.CreateUser(context.Background(), mockAccountID, mockUserID, &types.UserInfo{
Name: mockServiceUserName,
Role: mockRole,
IsServiceUser: false,
@ -448,11 +496,15 @@ func TestUser_CreateUser_RegularUser(t *testing.T) {
}
func TestUser_InviteNewUser(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
if err != nil {
t.Fatalf("Error when creating store: %s", err)
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account)
err = store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
@ -495,7 +547,7 @@ func TestUser_InviteNewUser(t *testing.T) {
am.idpManager = &idpMock
// test if new invite with regular role works
_, err = am.inviteNewUser(context.Background(), mockAccountID, mockUserID, &UserInfo{
_, err = am.inviteNewUser(context.Background(), mockAccountID, mockUserID, &types.UserInfo{
Name: mockServiceUserName,
Role: mockRole,
Email: "test@teste.com",
@ -506,9 +558,9 @@ func TestUser_InviteNewUser(t *testing.T) {
assert.NoErrorf(t, err, "Invite user should not throw error")
// test if new invite with owner role fails
_, err = am.inviteNewUser(context.Background(), mockAccountID, mockUserID, &UserInfo{
_, err = am.inviteNewUser(context.Background(), mockAccountID, mockUserID, &types.UserInfo{
Name: mockServiceUserName,
Role: string(UserRoleOwner),
Role: string(types.UserRoleOwner),
Email: "test2@teste.com",
IsServiceUser: false,
AutoGroups: []string{"group1", "group2"},
@ -520,13 +572,13 @@ func TestUser_InviteNewUser(t *testing.T) {
func TestUser_DeleteUser_ServiceUser(t *testing.T) {
tests := []struct {
name string
serviceUser *User
serviceUser *types.User
assertErrFunc assert.ErrorAssertionFunc
assertErrMessage string
}{
{
name: "Can delete service user",
serviceUser: &User{
serviceUser: &types.User{
Id: mockServiceUserID,
IsServiceUser: true,
ServiceUserName: mockServiceUserName,
@ -535,7 +587,7 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) {
},
{
name: "Cannot delete non-deletable service user",
serviceUser: &User{
serviceUser: &types.User{
Id: mockServiceUserID,
IsServiceUser: true,
ServiceUserName: mockServiceUserName,
@ -548,11 +600,16 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
store := newStore(t)
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
if err != nil {
t.Fatalf("Error when creating store: %s", err)
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockServiceUserID] = tt.serviceUser
err := store.SaveAccount(context.Background(), account)
err = store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
@ -580,11 +637,15 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) {
}
func TestUser_DeleteUser_SelfDelete(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
if err != nil {
t.Fatalf("Error when creating store: %s", err)
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account)
err = store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
@ -601,38 +662,42 @@ func TestUser_DeleteUser_SelfDelete(t *testing.T) {
}
func TestUser_DeleteUser_regularUser(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
if err != nil {
t.Fatalf("Error when creating store: %s", err)
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
targetId := "user2"
account.Users[targetId] = &User{
account.Users[targetId] = &types.User{
Id: targetId,
IsServiceUser: true,
ServiceUserName: "user2username",
}
targetId = "user3"
account.Users[targetId] = &User{
account.Users[targetId] = &types.User{
Id: targetId,
IsServiceUser: false,
Issued: UserIssuedAPI,
Issued: types.UserIssuedAPI,
}
targetId = "user4"
account.Users[targetId] = &User{
account.Users[targetId] = &types.User{
Id: targetId,
IsServiceUser: false,
Issued: UserIssuedIntegration,
Issued: types.UserIssuedIntegration,
}
targetId = "user5"
account.Users[targetId] = &User{
account.Users[targetId] = &types.User{
Id: targetId,
IsServiceUser: false,
Issued: UserIssuedAPI,
Role: UserRoleOwner,
Issued: types.UserIssuedAPI,
Role: types.UserRoleOwner,
}
err := store.SaveAccount(context.Background(), account)
err = store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
@ -683,60 +748,64 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
}
func TestUser_DeleteUser_RegularUsers(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
if err != nil {
t.Fatalf("Error when creating store: %s", err)
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
targetId := "user2"
account.Users[targetId] = &User{
account.Users[targetId] = &types.User{
Id: targetId,
IsServiceUser: true,
ServiceUserName: "user2username",
}
targetId = "user3"
account.Users[targetId] = &User{
account.Users[targetId] = &types.User{
Id: targetId,
IsServiceUser: false,
Issued: UserIssuedAPI,
Issued: types.UserIssuedAPI,
}
targetId = "user4"
account.Users[targetId] = &User{
account.Users[targetId] = &types.User{
Id: targetId,
IsServiceUser: false,
Issued: UserIssuedIntegration,
Issued: types.UserIssuedIntegration,
}
targetId = "user5"
account.Users[targetId] = &User{
account.Users[targetId] = &types.User{
Id: targetId,
IsServiceUser: false,
Issued: UserIssuedAPI,
Role: UserRoleOwner,
Issued: types.UserIssuedAPI,
Role: types.UserRoleOwner,
}
account.Users["user6"] = &User{
account.Users["user6"] = &types.User{
Id: "user6",
IsServiceUser: false,
Issued: UserIssuedAPI,
Issued: types.UserIssuedAPI,
}
account.Users["user7"] = &User{
account.Users["user7"] = &types.User{
Id: "user7",
IsServiceUser: false,
Issued: UserIssuedAPI,
Issued: types.UserIssuedAPI,
}
account.Users["user8"] = &User{
account.Users["user8"] = &types.User{
Id: "user8",
IsServiceUser: false,
Issued: UserIssuedAPI,
Role: UserRoleAdmin,
Issued: types.UserIssuedAPI,
Role: types.UserRoleAdmin,
}
account.Users["user9"] = &User{
account.Users["user9"] = &types.User{
Id: "user9",
IsServiceUser: false,
Issued: UserIssuedAPI,
Role: UserRoleAdmin,
Issued: types.UserIssuedAPI,
Role: types.UserRoleAdmin,
}
err := store.SaveAccount(context.Background(), account)
err = store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
@ -834,11 +903,15 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
}
func TestDefaultAccountManager_GetUser(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
if err != nil {
t.Fatalf("Error when creating store: %s", err)
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
err := store.SaveAccount(context.Background(), account)
err = store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
@ -863,13 +936,17 @@ func TestDefaultAccountManager_GetUser(t *testing.T) {
}
func TestDefaultAccountManager_ListUsers(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users["normal_user1"] = NewRegularUser("normal_user1")
account.Users["normal_user2"] = NewRegularUser("normal_user2")
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
if err != nil {
t.Fatalf("Error when creating store: %s", err)
}
t.Cleanup(cleanup)
err := store.SaveAccount(context.Background(), account)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users["normal_user1"] = types.NewRegularUser("normal_user1")
account.Users["normal_user2"] = types.NewRegularUser("normal_user2")
err = store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
@ -901,43 +978,43 @@ func TestDefaultAccountManager_ListUsers(t *testing.T) {
func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
testCases := []struct {
name string
role UserRole
role types.UserRole
limitedViewSettings bool
expectedDashboardPermissions string
}{
{
name: "Regular user, no limited view settings",
role: UserRoleUser,
role: types.UserRoleUser,
limitedViewSettings: false,
expectedDashboardPermissions: "limited",
},
{
name: "Admin user, no limited view settings",
role: UserRoleAdmin,
role: types.UserRoleAdmin,
limitedViewSettings: false,
expectedDashboardPermissions: "full",
},
{
name: "Owner, no limited view settings",
role: UserRoleOwner,
role: types.UserRoleOwner,
limitedViewSettings: false,
expectedDashboardPermissions: "full",
},
{
name: "Regular user, limited view settings",
role: UserRoleUser,
role: types.UserRoleUser,
limitedViewSettings: true,
expectedDashboardPermissions: "blocked",
},
{
name: "Admin user, limited view settings",
role: UserRoleAdmin,
role: types.UserRoleAdmin,
limitedViewSettings: true,
expectedDashboardPermissions: "full",
},
{
name: "Owner, limited view settings",
role: UserRoleOwner,
role: types.UserRoleOwner,
limitedViewSettings: true,
expectedDashboardPermissions: "full",
},
@ -945,13 +1022,18 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
store := newStore(t)
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
if err != nil {
t.Fatalf("Error when creating store: %s", err)
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users["normal_user1"] = NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI)
account.Users["normal_user1"] = types.NewUser("normal_user1", testCase.role, false, false, "", []string{}, types.UserIssuedAPI)
account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings
delete(account.Users, mockUserID)
err := store.SaveAccount(context.Background(), account)
err = store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
@ -976,13 +1058,17 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) {
}
func TestDefaultAccountManager_ExternalCache(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
if err != nil {
t.Fatalf("Error when creating store: %s", err)
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
externalUser := &User{
externalUser := &types.User{
Id: "externalUser",
Role: UserRoleUser,
Issued: UserIssuedIntegration,
Role: types.UserRoleUser,
Issued: types.UserIssuedIntegration,
IntegrationReference: integration_reference.IntegrationReference{
ID: 1,
IntegrationType: "external",
@ -990,7 +1076,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
}
account.Users[externalUser.Id] = externalUser
err := store.SaveAccount(context.Background(), account)
err = store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
@ -1020,7 +1106,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
infos, err := am.GetUsersFromAccount(context.Background(), mockAccountID, mockUserID)
assert.NoError(t, err)
assert.Equal(t, 2, len(infos))
var user *UserInfo
var user *types.UserInfo
for _, info := range infos {
if info.ID == externalUser.Id {
user = info
@ -1032,24 +1118,28 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
func TestUser_IsAdmin(t *testing.T) {
user := NewAdminUser(mockUserID)
user := types.NewAdminUser(mockUserID)
assert.True(t, user.HasAdminPower())
user = NewRegularUser(mockUserID)
user = types.NewRegularUser(mockUserID)
assert.False(t, user.HasAdminPower())
}
func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
if err != nil {
t.Fatalf("Error when creating store: %s", err)
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockServiceUserID] = &User{
account.Users[mockServiceUserID] = &types.User{
Id: mockServiceUserID,
Role: "user",
IsServiceUser: true,
}
err := store.SaveAccount(context.Background(), account)
err = store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
@ -1068,17 +1158,20 @@ func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) {
}
func TestUser_GetUsersFromAccount_ForUser(t *testing.T) {
store := newStore(t)
defer store.Close(context.Background())
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
if err != nil {
t.Fatalf("Error when creating store: %s", err)
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockServiceUserID] = &User{
account.Users[mockServiceUserID] = &types.User{
Id: mockServiceUserID,
Role: "user",
IsServiceUser: true,
}
err := store.SaveAccount(context.Background(), account)
err = store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
@ -1112,25 +1205,25 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
tt := []struct {
name string
initiatorID string
update *User
update *types.User
expectedErr bool
}{
{
name: "Should_Fail_To_Update_Admin_Role",
expectedErr: true,
initiatorID: adminUserID,
update: &User{
update: &types.User{
Id: adminUserID,
Role: UserRoleUser,
Role: types.UserRoleUser,
Blocked: false,
},
}, {
name: "Should_Fail_When_Admin_Blocks_Themselves",
expectedErr: true,
initiatorID: adminUserID,
update: &User{
update: &types.User{
Id: adminUserID,
Role: UserRoleAdmin,
Role: types.UserRoleAdmin,
Blocked: true,
},
},
@ -1138,9 +1231,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
name: "Should_Fail_To_Update_Non_Existing_User",
expectedErr: true,
initiatorID: adminUserID,
update: &User{
update: &types.User{
Id: userID,
Role: UserRoleAdmin,
Role: types.UserRoleAdmin,
Blocked: true,
},
},
@ -1148,9 +1241,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
name: "Should_Fail_To_Update_When_Initiator_Is_Not_An_Admin",
expectedErr: true,
initiatorID: regularUserID,
update: &User{
update: &types.User{
Id: adminUserID,
Role: UserRoleAdmin,
Role: types.UserRoleAdmin,
Blocked: true,
},
},
@ -1158,9 +1251,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
name: "Should_Update_User",
expectedErr: false,
initiatorID: adminUserID,
update: &User{
update: &types.User{
Id: regularUserID,
Role: UserRoleAdmin,
Role: types.UserRoleAdmin,
Blocked: true,
},
},
@ -1168,9 +1261,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
name: "Should_Transfer_Owner_Role_To_User",
expectedErr: false,
initiatorID: ownerUserID,
update: &User{
update: &types.User{
Id: adminUserID,
Role: UserRoleAdmin,
Role: types.UserRoleAdmin,
Blocked: false,
},
},
@ -1178,9 +1271,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
name: "Should_Fail_To_Transfer_Owner_Role_To_Service_User",
expectedErr: true,
initiatorID: ownerUserID,
update: &User{
update: &types.User{
Id: serviceUserID,
Role: UserRoleOwner,
Role: types.UserRoleOwner,
Blocked: false,
},
},
@ -1188,9 +1281,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
name: "Should_Fail_To_Update_Owner_User_Role_By_Admin",
expectedErr: true,
initiatorID: adminUserID,
update: &User{
update: &types.User{
Id: ownerUserID,
Role: UserRoleAdmin,
Role: types.UserRoleAdmin,
Blocked: false,
},
},
@ -1198,9 +1291,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
name: "Should_Fail_To_Update_Owner_User_Role_By_User",
expectedErr: true,
initiatorID: regularUserID,
update: &User{
update: &types.User{
Id: ownerUserID,
Role: UserRoleAdmin,
Role: types.UserRoleAdmin,
Blocked: false,
},
},
@ -1208,9 +1301,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
name: "Should_Fail_To_Update_Owner_User_Role_By_Service_User",
expectedErr: true,
initiatorID: serviceUserID,
update: &User{
update: &types.User{
Id: ownerUserID,
Role: UserRoleAdmin,
Role: types.UserRoleAdmin,
Blocked: false,
},
},
@ -1218,9 +1311,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
name: "Should_Fail_To_Update_Owner_Role_By_Admin",
expectedErr: true,
initiatorID: adminUserID,
update: &User{
update: &types.User{
Id: regularUserID,
Role: UserRoleOwner,
Role: types.UserRoleOwner,
Blocked: false,
},
},
@ -1228,9 +1321,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
name: "Should_Fail_To_Block_Owner_Role_By_Admin",
expectedErr: true,
initiatorID: adminUserID,
update: &User{
update: &types.User{
Id: ownerUserID,
Role: UserRoleOwner,
Role: types.UserRoleOwner,
Blocked: true,
},
},
@ -1246,9 +1339,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
}
// create other users
account.Users[regularUserID] = NewRegularUser(regularUserID)
account.Users[adminUserID] = NewAdminUser(adminUserID)
account.Users[serviceUserID] = &User{IsServiceUser: true, Id: serviceUserID, Role: UserRoleAdmin, ServiceUserName: "service"}
account.Users[regularUserID] = types.NewRegularUser(regularUserID)
account.Users[adminUserID] = types.NewAdminUser(adminUserID)
account.Users[serviceUserID] = &types.User{IsServiceUser: true, Id: serviceUserID, Role: types.UserRoleAdmin, ServiceUserName: "service"}
err = manager.Store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatal(err)
@ -1279,15 +1372,15 @@ func TestUserAccountPeersUpdate(t *testing.T) {
})
require.NoError(t, err)
policy := &Policy{
policy := &types.Policy{
Enabled: true,
Rules: []*PolicyRule{
Rules: []*types.PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
Action: types.PolicyTrafficActionAccept,
},
},
}
@ -1307,11 +1400,11 @@ func TestUserAccountPeersUpdate(t *testing.T) {
close(done)
}()
_, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{
_, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &types.User{
Id: "regularUser1",
AccountID: account.Id,
Role: UserRoleUser,
Issued: UserIssuedAPI,
Role: types.UserRoleUser,
Issued: types.UserIssuedAPI,
}, true)
require.NoError(t, err)
@ -1330,11 +1423,11 @@ func TestUserAccountPeersUpdate(t *testing.T) {
close(done)
}()
_, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{
_, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &types.User{
Id: "regularUser1",
AccountID: account.Id,
Role: UserRoleUser,
Issued: UserIssuedAPI,
Role: types.UserRoleUser,
Issued: types.UserIssuedAPI,
}, false)
require.NoError(t, err)
@ -1364,11 +1457,11 @@ func TestUserAccountPeersUpdate(t *testing.T) {
})
// create a user and add new peer with the user
_, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{
_, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &types.User{
Id: "regularUser2",
AccountID: account.Id,
Role: UserRoleAdmin,
Issued: UserIssuedAPI,
Role: types.UserRoleAdmin,
Issued: types.UserIssuedAPI,
}, true)
require.NoError(t, err)
@ -1390,11 +1483,11 @@ func TestUserAccountPeersUpdate(t *testing.T) {
close(done)
}()
_, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{
_, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &types.User{
Id: "regularUser2",
AccountID: account.Id,
Role: UserRoleAdmin,
Issued: UserIssuedAPI,
Role: types.UserRoleAdmin,
Issued: types.UserIssuedAPI,
}, false)
require.NoError(t, err)

View File

@ -0,0 +1,16 @@
package util
// Difference returns the elements in `a` that aren't in `b`.
func Difference(a, b []string) []string {
mb := make(map[string]struct{}, len(b))
for _, x := range b {
mb[x] = struct{}{}
}
var diff []string
for _, x := range a {
if _, found := mb[x]; !found {
diff = append(diff, x)
}
}
return diff
}