refactor account and dns settings

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga 2024-10-01 00:54:28 +03:00
parent 9e47c94a7f
commit 43eb7261e3
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547
4 changed files with 97 additions and 76 deletions

View File

@ -128,7 +128,7 @@ type AccountManager interface {
GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error)
SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error
GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Account, error) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Settings, error)
LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API
SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API
GetAllConnectedPeers() (map[string]struct{}, error) GetAllConnectedPeers() (map[string]struct{}, error)
@ -1048,7 +1048,16 @@ func (am *DefaultAccountManager) GetIdpManager() idp.Manager {
// Only users with role UserRoleAdmin can update the account. // Only users with role UserRoleAdmin can update the account.
// User that performs the update has to belong to the account. // User that performs the update has to belong to the account.
// Returns an updated Account // Returns an updated Account
func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Account, error) { func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Settings, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
if !user.HasAdminPower() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "user is not allowed to update account")
}
halfYearLimit := 180 * 24 * time.Hour halfYearLimit := 180 * 24 * time.Hour
if newSettings.PeerLoginExpiration > halfYearLimit { if newSettings.PeerLoginExpiration > halfYearLimit {
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days") return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
@ -1058,53 +1067,57 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour") return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
} }
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) oldSettings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
user, err := account.FindUser(userID) if err = am.validateExtraSettings(ctx, newSettings, oldSettings, userID, accountID); err != nil {
if err != nil {
return nil, err return nil, err
} }
if !user.HasAdminPower() { if err = am.Store.SaveAccountSettings(ctx, LockingStrengthUpdate, accountID, newSettings); err != nil {
return nil, status.Errorf(status.PermissionDenied, "user is not allowed to update account") return nil, fmt.Errorf("failed updating account settings: %w", err)
} }
err = am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, account.Settings.Extra, account.Peers, userID, accountID)
if err != nil {
return nil, err
}
oldSettings := account.Settings
if oldSettings.PeerLoginExpirationEnabled != newSettings.PeerLoginExpirationEnabled { if oldSettings.PeerLoginExpirationEnabled != newSettings.PeerLoginExpirationEnabled {
event := activity.AccountPeerLoginExpirationEnabled event := activity.AccountPeerLoginExpirationEnabled
if !newSettings.PeerLoginExpirationEnabled { if !newSettings.PeerLoginExpirationEnabled {
event = activity.AccountPeerLoginExpirationDisabled event = activity.AccountPeerLoginExpirationDisabled
am.peerLoginExpiry.Cancel(ctx, []string{accountID}) am.peerLoginExpiry.Cancel(ctx, []string{accountID})
} else { } else {
am.checkAndSchedulePeerLoginExpiration(ctx, account) am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
} }
am.StoreEvent(ctx, userID, accountID, accountID, event, nil) am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
} }
if oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration { if oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil) am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil)
am.checkAndSchedulePeerLoginExpiration(ctx, account) am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
} }
updatedAccount := account.UpdateSettings(newSettings) account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
err = am.Store.SaveAccount(ctx, account)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("error getting account: %w", err)
}
am.updateAccountPeers(ctx, account)
return newSettings, nil
}
// validateExtraSettings validates the extra settings of the account.
func (am *DefaultAccountManager) validateExtraSettings(ctx context.Context, newSettings, oldSettings *Settings, userID, accountID string) error {
peers, err := am.Store.GetAccountPeers(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
} }
return updatedAccount, nil peerMap := make(map[string]*nbpeer.Peer, len(peers))
for _, peer := range peers {
peerMap[peer.ID] = peer
}
return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, peerMap, userID, accountID)
} }
func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) {
@ -1135,10 +1148,10 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc
} }
} }
func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, account *Account) { func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, accountID string) {
am.peerLoginExpiry.Cancel(ctx, []string{account.Id}) am.peerLoginExpiry.Cancel(ctx, []string{accountID})
if nextRun, ok := account.GetNextPeerExpiration(); ok { if nextRun, ok := am.getNextPeerExpiration(ctx, accountID); ok {
go am.peerLoginExpiry.Schedule(ctx, nextRun, account.Id, am.peerLoginExpirationJob(ctx, account.Id)) go am.peerLoginExpiry.Schedule(ctx, nextRun, accountID, am.peerLoginExpirationJob(ctx, accountID))
} }
} }
@ -1674,33 +1687,18 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str
// MarkPATUsed marks a personal access token as used // MarkPATUsed marks a personal access token as used
func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string) error { func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string) error {
user, err := am.Store.GetUserByTokenID(ctx, tokenID) user, err := am.Store.GetUserByTokenID(ctx, tokenID)
if err != nil { if err != nil {
return err return err
} }
account, err := am.Store.GetAccountByUser(ctx, user.Id) pat, err := am.Store.GetPATByID(ctx, LockingStrengthShare, tokenID, user.Id)
if err != nil { if err != nil {
return err return err
} }
unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
defer unlock()
account, err = am.Store.GetAccountByUser(ctx, user.Id)
if err != nil {
return err
}
pat, ok := account.Users[user.Id].PATs[tokenID]
if !ok {
return fmt.Errorf("token not found")
}
pat.LastUsed = time.Now().UTC() pat.LastUsed = time.Now().UTC()
return am.Store.SaveAccount(ctx, account) return am.Store.SavePAT(ctx, LockingStrengthUpdate, pat)
} }
// GetAccount returns an account associated with this account ID. // GetAccount returns an account associated with this account ID.

View File

@ -6,6 +6,7 @@ import (
"strconv" "strconv"
"sync" "sync"
nbgroup "github.com/netbirdio/netbird/management/server/group"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
@ -94,56 +95,78 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
// SaveDNSSettings validates a user role and updates the account's DNS settings // SaveDNSSettings validates a user role and updates the account's DNS settings
func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error { func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
user, err := account.FindUser(userID)
if err != nil {
return err
}
if !user.HasAdminPower() {
return status.Errorf(status.PermissionDenied, "only users with admin power are allowed to update DNS settings")
}
if dnsSettingsToSave == nil { if dnsSettingsToSave == nil {
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil") return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
} }
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
if !user.HasAdminPower() || user.AccountID != accountID {
return status.Errorf(status.PermissionDenied, "only users with admin power are allowed to update DNS settings")
}
oldSettings, err := am.Store.GetAccountDNSSettings(ctx, LockingStrengthUpdate, accountID)
if err != nil {
return err
}
groups, err := am.Store.GetAccountGroups(ctx, accountID)
if err != nil {
return err
}
if len(dnsSettingsToSave.DisabledManagementGroups) != 0 { if len(dnsSettingsToSave.DisabledManagementGroups) != 0 {
err = validateGroups(dnsSettingsToSave.DisabledManagementGroups, account.Groups) err = validateGroups(dnsSettingsToSave.DisabledManagementGroups, groups)
if err != nil { if err != nil {
return err return err
} }
} }
oldSettings := account.DNSSettings.Copy() err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
account.DNSSettings = dnsSettingsToSave.Copy() if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
account.Network.IncSerial() if err = transaction.SaveDNSSettings(ctx, LockingStrengthUpdate, accountID, dnsSettingsToSave); err != nil {
if err = am.Store.SaveAccount(ctx, account); err != nil { return fmt.Errorf("failed to update dns settings: %w", err)
}
return nil
})
if err != nil {
return err return err
} }
groupMap := make(map[string]*nbgroup.Group, len(groups))
for _, g := range groups {
groupMap[g.ID] = g
}
addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups) addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
for _, id := range addedGroups { for _, id := range addedGroups {
group := account.GetGroup(id) group, ok := groupMap[id]
meta := map[string]any{"group": group.Name, "group_id": group.ID} if ok {
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta) meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
}
} }
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups) removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
for _, id := range removedGroups { for _, id := range removedGroups {
group := account.GetGroup(id) group, ok := groupMap[id]
meta := map[string]any{"group": group.Name, "group_id": group.ID} if ok {
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta) meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
}
} }
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("error getting account: %w", err)
}
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account)
return nil return nil

View File

@ -97,13 +97,13 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
settings.JWTAllowGroups = *req.Settings.JwtAllowGroups settings.JWTAllowGroups = *req.Settings.JwtAllowGroups
} }
updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings) updatedAccountSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
resp := toAccountResponse(updatedAccount.Id, updatedAccount.Settings) resp := toAccountResponse(accountID, updatedAccountSettings)
util.WriteJSONObject(r.Context(), w, &resp) util.WriteJSONObject(r.Context(), w, &resp)
} }

View File

@ -89,7 +89,7 @@ type MockAccountManager struct {
GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*server.DNSSettings, error) GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*server.DNSSettings, error)
SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *server.DNSSettings) error SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *server.DNSSettings) error
GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Settings, error)
LoginPeerFunc func(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) LoginPeerFunc func(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
SyncPeerFunc func(ctx context.Context, sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) SyncPeerFunc func(ctx context.Context, sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error
@ -667,7 +667,7 @@ func (am *MockAccountManager) GetPeer(ctx context.Context, accountID, peerID, us
} }
// UpdateAccountSettings mocks UpdateAccountSettings of the AccountManager interface // UpdateAccountSettings mocks UpdateAccountSettings of the AccountManager interface
func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) { func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Settings, error) {
if am.UpdateAccountSettingsFunc != nil { if am.UpdateAccountSettingsFunc != nil {
return am.UpdateAccountSettingsFunc(ctx, accountID, userID, newSettings) return am.UpdateAccountSettingsFunc(ctx, accountID, userID, newSettings)
} }