mirror of
https://github.com/netbirdio/netbird.git
synced 2025-04-11 21:18:57 +02:00
[management] Add AccountExists to AccountManager (#2694)
* Add AccountExists method to account manager interface Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * remove unused code Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> --------- Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
parent
7f09b39769
commit
8bf729c7b4
@ -76,6 +76,7 @@ type AccountManager interface {
|
||||
SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error)
|
||||
GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error)
|
||||
GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error)
|
||||
AccountExists(ctx context.Context, accountID string) (bool, error)
|
||||
GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error)
|
||||
GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
|
||||
CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
|
||||
@ -1261,6 +1262,11 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
|
||||
return nil
|
||||
}
|
||||
|
||||
// AccountExists checks if an account exists.
|
||||
func (am *DefaultAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) {
|
||||
return am.Store.AccountExists(ctx, LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
// GetAccountIDByUserID retrieves the account ID based on the userID provided.
|
||||
// If user does have an account, it returns the user's account ID.
|
||||
// If the user doesn't have an account, it creates one using the provided domain.
|
||||
|
@ -27,6 +27,7 @@ type MockAccountManager struct {
|
||||
CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType,
|
||||
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
|
||||
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error)
|
||||
AccountExistsFunc func(ctx context.Context, accountID string) (bool, error)
|
||||
GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error)
|
||||
GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
||||
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
|
||||
@ -58,7 +59,7 @@ type MockAccountManager struct {
|
||||
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
|
||||
UpdatePeerSSHKeyFunc func(ctx context.Context, peerID string, sshKey string) error
|
||||
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups,accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
|
||||
CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
|
||||
GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
|
||||
SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error
|
||||
DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error
|
||||
@ -194,6 +195,14 @@ func (am *MockAccountManager) CreateSetupKey(
|
||||
return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented")
|
||||
}
|
||||
|
||||
// AccountExists mock implementation of AccountExists from server.AccountManager interface
|
||||
func (am *MockAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) {
|
||||
if am.GetAccountIDByUserIdFunc != nil {
|
||||
return am.AccountExistsFunc(ctx, accountID)
|
||||
}
|
||||
return false, status.Errorf(codes.Unimplemented, "method AccountExists is not implemented")
|
||||
}
|
||||
|
||||
// GetAccountIDByUserID mock implementation of GetAccountIDByUserID from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userId, domain string) (string, error) {
|
||||
if am.GetAccountIDByUserIdFunc != nil {
|
||||
@ -444,7 +453,7 @@ func (am *MockAccountManager) UpdatePeer(ctx context.Context, accountID, userID
|
||||
// CreateRoute mock implementation of CreateRoute from server.AccountManager interface
|
||||
func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupID []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
|
||||
if am.CreateRouteFunc != nil {
|
||||
return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups,accessControlGroupID, enabled, userID, keepRoute)
|
||||
return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, accessControlGroupID, enabled, userID, keepRoute)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented")
|
||||
}
|
||||
|
@ -10,7 +10,6 @@ import (
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@ -1033,84 +1032,6 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddUserPeersToGroups adds the user's peers to specified groups in database.
|
||||
func (s *SqlStore) AddUserPeersToGroups(ctx context.Context, accountID string, userID string, groupIDs []string) error {
|
||||
if len(groupIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var userPeerIDs []string
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(LockingStrengthShare)}).Select("id").
|
||||
Where("account_id = ? AND user_id = ?", accountID, userID).Model(&nbpeer.Peer{}).Find(&userPeerIDs)
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "issue finding user peers")
|
||||
}
|
||||
|
||||
groupsToUpdate := make([]*nbgroup.Group, 0, len(groupIDs))
|
||||
for _, gid := range groupIDs {
|
||||
group, err := s.GetGroupByID(ctx, LockingStrengthShare, gid, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
groupPeers := make(map[string]struct{})
|
||||
for _, pid := range group.Peers {
|
||||
groupPeers[pid] = struct{}{}
|
||||
}
|
||||
|
||||
for _, pid := range userPeerIDs {
|
||||
groupPeers[pid] = struct{}{}
|
||||
}
|
||||
|
||||
group.Peers = group.Peers[:0]
|
||||
for pid := range groupPeers {
|
||||
group.Peers = append(group.Peers, pid)
|
||||
}
|
||||
|
||||
groupsToUpdate = append(groupsToUpdate, group)
|
||||
}
|
||||
|
||||
return s.SaveGroups(ctx, LockingStrengthUpdate, groupsToUpdate)
|
||||
}
|
||||
|
||||
// RemoveUserPeersFromGroups removes the user's peers from specified groups in database.
|
||||
func (s *SqlStore) RemoveUserPeersFromGroups(ctx context.Context, accountID string, userID string, groupIDs []string) error {
|
||||
if len(groupIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var userPeerIDs []string
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(LockingStrengthShare)}).Select("id").
|
||||
Where("account_id = ? AND user_id = ?", accountID, userID).Model(&nbpeer.Peer{}).Find(&userPeerIDs)
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "issue finding user peers")
|
||||
}
|
||||
|
||||
groupsToUpdate := make([]*nbgroup.Group, 0, len(groupIDs))
|
||||
for _, gid := range groupIDs {
|
||||
group, err := s.GetGroupByID(ctx, LockingStrengthShare, gid, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if group.Name == "All" {
|
||||
continue
|
||||
}
|
||||
|
||||
update := make([]string, 0, len(group.Peers))
|
||||
for _, pid := range group.Peers {
|
||||
if !slices.Contains(userPeerIDs, pid) {
|
||||
update = append(update, pid)
|
||||
}
|
||||
}
|
||||
|
||||
group.Peers = update
|
||||
groupsToUpdate = append(groupsToUpdate, group)
|
||||
}
|
||||
|
||||
return s.SaveGroups(ctx, LockingStrengthUpdate, groupsToUpdate)
|
||||
}
|
||||
|
||||
// GetUserPeers retrieves peers for a user.
|
||||
func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) {
|
||||
return getRecords[*nbpeer.Peer](s.db.WithContext(ctx).Where("user_id = ?", userID), lockStrength, accountID)
|
||||
|
Loading…
Reference in New Issue
Block a user