refactor account and dns settings

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga 2024-10-01 00:54:28 +03:00
parent 9e47c94a7f
commit 43eb7261e3
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547
4 changed files with 97 additions and 76 deletions

View File

@ -128,7 +128,7 @@ type AccountManager interface {
GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error)
SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error
GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Account, error)
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Settings, error)
LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API
SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API
GetAllConnectedPeers() (map[string]struct{}, error)
@ -1048,7 +1048,16 @@ func (am *DefaultAccountManager) GetIdpManager() idp.Manager {
// Only users with role UserRoleAdmin can update the account.
// User that performs the update has to belong to the account.
// Returns an updated Account
func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Account, error) {
func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Settings, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
if !user.HasAdminPower() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "user is not allowed to update account")
}
halfYearLimit := 180 * 24 * time.Hour
if newSettings.PeerLoginExpiration > halfYearLimit {
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
@ -1058,53 +1067,57 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
}
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
oldSettings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
user, err := account.FindUser(userID)
if err != nil {
if err = am.validateExtraSettings(ctx, newSettings, oldSettings, userID, accountID); err != nil {
return nil, err
}
if !user.HasAdminPower() {
return nil, status.Errorf(status.PermissionDenied, "user is not allowed to update account")
if err = am.Store.SaveAccountSettings(ctx, LockingStrengthUpdate, accountID, newSettings); err != nil {
return nil, fmt.Errorf("failed updating account settings: %w", err)
}
err = am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, account.Settings.Extra, account.Peers, userID, accountID)
if err != nil {
return nil, err
}
oldSettings := account.Settings
if oldSettings.PeerLoginExpirationEnabled != newSettings.PeerLoginExpirationEnabled {
event := activity.AccountPeerLoginExpirationEnabled
if !newSettings.PeerLoginExpirationEnabled {
event = activity.AccountPeerLoginExpirationDisabled
am.peerLoginExpiry.Cancel(ctx, []string{accountID})
} else {
am.checkAndSchedulePeerLoginExpiration(ctx, account)
am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
}
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
}
if oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil)
am.checkAndSchedulePeerLoginExpiration(ctx, account)
am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
}
updatedAccount := account.UpdateSettings(newSettings)
err = am.Store.SaveAccount(ctx, account)
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, err
return nil, fmt.Errorf("error getting account: %w", err)
}
am.updateAccountPeers(ctx, account)
return newSettings, nil
}
// validateExtraSettings validates the extra settings of the account.
func (am *DefaultAccountManager) validateExtraSettings(ctx context.Context, newSettings, oldSettings *Settings, userID, accountID string) error {
peers, err := am.Store.GetAccountPeers(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
return updatedAccount, nil
peerMap := make(map[string]*nbpeer.Peer, len(peers))
for _, peer := range peers {
peerMap[peer.ID] = peer
}
return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, peerMap, userID, accountID)
}
func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) {
@ -1135,10 +1148,10 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc
}
}
func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, account *Account) {
am.peerLoginExpiry.Cancel(ctx, []string{account.Id})
if nextRun, ok := account.GetNextPeerExpiration(); ok {
go am.peerLoginExpiry.Schedule(ctx, nextRun, account.Id, am.peerLoginExpirationJob(ctx, account.Id))
func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, accountID string) {
am.peerLoginExpiry.Cancel(ctx, []string{accountID})
if nextRun, ok := am.getNextPeerExpiration(ctx, accountID); ok {
go am.peerLoginExpiry.Schedule(ctx, nextRun, accountID, am.peerLoginExpirationJob(ctx, accountID))
}
}
@ -1674,33 +1687,18 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str
// MarkPATUsed marks a personal access token as used
func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string) error {
user, err := am.Store.GetUserByTokenID(ctx, tokenID)
if err != nil {
return err
}
account, err := am.Store.GetAccountByUser(ctx, user.Id)
pat, err := am.Store.GetPATByID(ctx, LockingStrengthShare, tokenID, user.Id)
if err != nil {
return err
}
unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
defer unlock()
account, err = am.Store.GetAccountByUser(ctx, user.Id)
if err != nil {
return err
}
pat, ok := account.Users[user.Id].PATs[tokenID]
if !ok {
return fmt.Errorf("token not found")
}
pat.LastUsed = time.Now().UTC()
return am.Store.SaveAccount(ctx, account)
return am.Store.SavePAT(ctx, LockingStrengthUpdate, pat)
}
// GetAccount returns an account associated with this account ID.

View File

@ -6,6 +6,7 @@ import (
"strconv"
"sync"
nbgroup "github.com/netbirdio/netbird/management/server/group"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
@ -94,56 +95,78 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
// SaveDNSSettings validates a user role and updates the account's DNS settings
func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
user, err := account.FindUser(userID)
if err != nil {
return err
}
if !user.HasAdminPower() {
return status.Errorf(status.PermissionDenied, "only users with admin power are allowed to update DNS settings")
}
if dnsSettingsToSave == nil {
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
}
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
if !user.HasAdminPower() || user.AccountID != accountID {
return status.Errorf(status.PermissionDenied, "only users with admin power are allowed to update DNS settings")
}
oldSettings, err := am.Store.GetAccountDNSSettings(ctx, LockingStrengthUpdate, accountID)
if err != nil {
return err
}
groups, err := am.Store.GetAccountGroups(ctx, accountID)
if err != nil {
return err
}
if len(dnsSettingsToSave.DisabledManagementGroups) != 0 {
err = validateGroups(dnsSettingsToSave.DisabledManagementGroups, account.Groups)
err = validateGroups(dnsSettingsToSave.DisabledManagementGroups, groups)
if err != nil {
return err
}
}
oldSettings := account.DNSSettings.Copy()
account.DNSSettings = dnsSettingsToSave.Copy()
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
if err = transaction.SaveDNSSettings(ctx, LockingStrengthUpdate, accountID, dnsSettingsToSave); err != nil {
return fmt.Errorf("failed to update dns settings: %w", err)
}
return nil
})
if err != nil {
return err
}
groupMap := make(map[string]*nbgroup.Group, len(groups))
for _, g := range groups {
groupMap[g.ID] = g
}
addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
for _, id := range addedGroups {
group := account.GetGroup(id)
meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
group, ok := groupMap[id]
if ok {
meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
}
}
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
for _, id := range removedGroups {
group := account.GetGroup(id)
meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
group, ok := groupMap[id]
if ok {
meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
}
}
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("error getting account: %w", err)
}
am.updateAccountPeers(ctx, account)
return nil

View File

@ -97,13 +97,13 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
settings.JWTAllowGroups = *req.Settings.JwtAllowGroups
}
updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings)
updatedAccountSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
resp := toAccountResponse(updatedAccount.Id, updatedAccount.Settings)
resp := toAccountResponse(accountID, updatedAccountSettings)
util.WriteJSONObject(r.Context(), w, &resp)
}

View File

@ -89,7 +89,7 @@ type MockAccountManager struct {
GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*server.DNSSettings, error)
SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *server.DNSSettings) error
GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error)
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Settings, error)
LoginPeerFunc func(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
SyncPeerFunc func(ctx context.Context, sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error
@ -667,7 +667,7 @@ func (am *MockAccountManager) GetPeer(ctx context.Context, accountID, peerID, us
}
// UpdateAccountSettings mocks UpdateAccountSettings of the AccountManager interface
func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Settings, error) {
if am.UpdateAccountSettingsFunc != nil {
return am.UpdateAccountSettingsFunc(ctx, accountID, userID, newSettings)
}