From 7561706627421c757491eb6e5406f7a65661f40c Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Tue, 24 Sep 2024 19:55:33 +0300 Subject: [PATCH] add GetGroupByID from store and refactor Signed-off-by: bcmmbaga --- management/server/file_store.go | 3 +++ management/server/group.go | 44 +++++++++++++++++---------------- management/server/sql_store.go | 16 ++++++++++-- management/server/store.go | 1 + 4 files changed, 41 insertions(+), 23 deletions(-) diff --git a/management/server/file_store.go b/management/server/file_store.go index 316feb867..a18e0e539 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -983,6 +983,9 @@ func (s *FileStore) UpdateAccount(_ context.Context, _ LockingStrength, _ *Accou return nil } +func (s *FileStore) GetGroupByID(_ context.Context, _, _ string) (*nbgroup.Group, error) { + return nil, status.Errorf(status.Internal, "GetGroupByID is not implemented") +} func (s *FileStore) GetGroupByName(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) { return nil, status.Errorf(status.Internal, "GetGroupByName is not implemented") } diff --git a/management/server/group.go b/management/server/group.go index 9343f2dd2..60d895d0a 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -25,36 +25,38 @@ func (e *GroupLinkError) Error() string { return fmt.Sprintf("group has been linked to %s: %s", e.Resource, e.Name) } -// GetGroup object of the peers -func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) { - groups, err := am.GetAllGroups(ctx, accountID, userID) - if err != nil { - return nil, err - } - - for _, group := range groups { - if group.ID == groupID { - return group, nil - } - } - - return nil, status.Errorf(status.NotFound, "group with ID %s not found", groupID) -} - -// GetAllGroups returns all groups in an account -func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID string, userID string) ([]*nbgroup.Group, error) { +// CheckGroupPermissions validates if a user has the necessary permissions to view groups +func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error { settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) if err != nil { - return nil, err + return err } user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { - return nil, err + return err } if !user.HasAdminPower() && !user.IsServiceUser && settings.RegularUsersViewBlocked { - return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users") + return status.Errorf(status.PermissionDenied, "groups are blocked for users") + } + + return nil +} + +// GetGroup returns a specific group by groupID in an account +func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) { + if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { + return nil, err + } + + return am.Store.GetGroupByID(ctx, groupID, accountID) +} + +// GetAllGroups returns all groups in an account +func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) { + if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { + return nil, err } return am.Store.GetAccountGroups(ctx, accountID) diff --git a/management/server/sql_store.go b/management/server/sql_store.go index b76846c9f..d843e6f1d 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1087,12 +1087,24 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", "", status.Errorf(status.NotFound, "account not found") } - return "", "", status.Errorf(status.Internal, "failed to retrieve account fields") + return "", "", status.Errorf(status.Internal, "failed to get domain category from store: %v", result.Error) } return account.Domain, account.DomainCategory, nil } +func (s *SqlStore) GetGroupByID(ctx context.Context, groupID, accountID string) (*nbgroup.Group, error) { + var group nbgroup.Group + result := s.db.WithContext(ctx).Model(&nbgroup.Group{}).Where(accountAndIDQueryCondition, accountID, groupID).First(&group) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "group not found") + } + return nil, status.Errorf(status.Internal, "failed to get group from store: %s", result.Error) + } + return &group, nil +} + // GetGroupByName retrieves a group by name and account ID. func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) { var group nbgroup.Group @@ -1102,7 +1114,7 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "group not found") } - return nil, status.Errorf(status.Internal, "failed to retrieve group fields") + return nil, status.Errorf(status.Internal, "failed to get group from store: %s", result.Error) } return &group, nil } diff --git a/management/server/store.go b/management/server/store.go index 10a52db98..73e68531c 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -64,6 +64,7 @@ type Store interface { DeleteTokenID2UserIDIndex(tokenID string) error GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) + GetGroupByID(ctx context.Context, groupID, accountID string) (*nbgroup.Group, error) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error