From 16174f0478f533ba0d0fb5d00b6fd50a96dd9f1b Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Wed, 25 Sep 2024 12:52:42 +0300 Subject: [PATCH] Refactor route, setupkey, nameserver and dns to get record(s) from store Signed-off-by: bcmmbaga --- management/server/account.go | 5 +++ management/server/dns.go | 16 +++------- management/server/nameserver.go | 42 +++++-------------------- management/server/route.go | 38 ++++------------------ management/server/setupkey.go | 56 ++++++++++++--------------------- 5 files changed, 42 insertions(+), 115 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 6895c9378..11c3a17e0 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -270,6 +270,11 @@ type AccountNetwork struct { Network *Network `gorm:"embedded;embeddedPrefix:network_"` } +// AccountDNSSettings used in gorm to only load dns settings and not whole account +type AccountDNSSettings struct { + DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"` +} + type UserPermissions struct { DashboardView string `json:"dashboard_view"` } diff --git a/management/server/dns.go b/management/server/dns.go index 1d156c90a..7410aaa15 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -80,24 +80,16 @@ func (d DNSSettings) Copy() DNSSettings { // GetDNSSettings validates a user role and returns the DNS settings for the provided account ID func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) - if err != nil { - return nil, err - } - - if !(user.HasAdminPower() || user.IsServiceUser) { + if !user.IsAdminOrServiceUser() || user.AccountID != accountID { return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view DNS settings") } - dnsSettings := account.DNSSettings.Copy() - return &dnsSettings, nil + + return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID) } // SaveDNSSettings validates a user role and updates the account's DNS settings diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 636f7cfee..e059d2217 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -19,30 +19,16 @@ const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$` // GetNameServerGroup gets a nameserver group object from account and nameserver group IDs func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { - - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) - if err != nil { - return nil, err + if !user.IsAdminOrServiceUser() || user.AccountID != accountID { + return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups") } - if !(user.HasAdminPower() || user.IsServiceUser) { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view nameserver groups") - } - - nsGroup, found := account.NameServerGroups[nsGroupID] - if found { - return nsGroup.Copy(), nil - } - - return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID) + return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupID, accountID) } // CreateNameServerGroup creates and saves a new nameserver group @@ -159,30 +145,16 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco // ListNameServerGroups returns a list of nameserver groups from account func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) { - - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) - if err != nil { - return nil, err - } - - if !(user.HasAdminPower() || user.IsServiceUser) { + if !user.IsAdminOrServiceUser() || user.AccountID != accountID { return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups") } - nsGroups := make([]*nbdns.NameServerGroup, 0, len(account.NameServerGroups)) - for _, item := range account.NameServerGroups { - nsGroups = append(nsGroups, item.Copy()) - } - - return nsGroups, nil + return am.Store.GetAccountNameServerGroups(ctx, accountID) } func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error { diff --git a/management/server/route.go b/management/server/route.go index 11f89b83b..fbf4e82c5 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -17,29 +17,16 @@ import ( // GetRoute gets a route object from account and route IDs func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) - if err != nil { - return nil, err - } - - if !(user.HasAdminPower() || user.IsServiceUser) { + if !user.IsAdminOrServiceUser() || user.AccountID != accountID { return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes") } - wantedRoute, found := account.Routes[routeID] - if found { - return wantedRoute, nil - } - - return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID) + return am.Store.GetRouteByID(ctx, LockingStrengthShare, string(routeID), accountID) } // checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups. @@ -325,29 +312,16 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri // ListRoutes returns a list of routes from account func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) - if err != nil { - return nil, err - } - - if !(user.HasAdminPower() || user.IsServiceUser) { + if !user.IsAdminOrServiceUser() || user.AccountID != accountID { return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes") } - routes := make([]*route.Route, 0, len(account.Routes)) - for _, item := range account.Routes { - routes = append(routes, item) - } - - return routes, nil + return am.Store.GetAccountRoutes(ctx, accountID) } func toProtocolRoute(route *route.Route) *proto.Route { diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 859f1b0b9..3e7be8a16 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -330,26 +330,24 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str // ListSetupKeys returns a list of all setup keys of the account func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) + if !user.IsAdminOrServiceUser() || user.AccountID != accountID { + return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys") + } + + setupKeys, err := am.Store.GetAccountSetupKeys(ctx, accountID) if err != nil { return nil, err } - if !user.HasAdminPower() && !user.IsServiceUser { - return nil, status.Errorf(status.Unauthorized, "only users with admin power can view policies") - } - - keys := make([]*SetupKey, 0, len(account.SetupKeys)) - for _, key := range account.SetupKeys { + keys := make([]*SetupKey, 0, len(setupKeys)) + for _, key := range setupKeys { var k *SetupKey - if !(user.HasAdminPower() || user.IsServiceUser) { + if !user.IsAdminOrServiceUser() { k = key.HiddenCopy(999) } else { k = key.Copy() @@ -362,44 +360,30 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u // GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found. func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) + if !user.IsAdminOrServiceUser() || user.AccountID != accountID { + return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys") + } + + setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID) if err != nil { return nil, err } - if !user.HasAdminPower() && !user.IsServiceUser { - return nil, status.Errorf(status.Unauthorized, "only users with admin power can view policies") - } - - var foundKey *SetupKey - for _, key := range account.SetupKeys { - if key.Id == keyID { - foundKey = key.Copy() - break - } - } - if foundKey == nil { - return nil, status.Errorf(status.NotFound, "setup key not found") - } - // the UpdatedAt field was introduced later, so there might be that some keys have a Zero value (e.g, null in the store file) - if foundKey.UpdatedAt.IsZero() { - foundKey.UpdatedAt = foundKey.CreatedAt + if setupKey.UpdatedAt.IsZero() { + setupKey.UpdatedAt = setupKey.CreatedAt } - if !(user.HasAdminPower() || user.IsServiceUser) { - foundKey = foundKey.HiddenCopy(999) + if !user.IsAdminOrServiceUser() { + setupKey = setupKey.HiddenCopy(999) } - return foundKey, nil + return setupKey, nil } func validateSetupKeyAutoGroups(account *Account, autoGroups []string) error {