From 8f98adddf6e2846cd818bfae85d7879f922a8d05 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Sun, 22 Sep 2024 15:14:31 +0300 Subject: [PATCH] refactor handlers to use GetAccountIDFromToken Signed-off-by: bcmmbaga --- management/server/account.go | 58 ++++-- management/server/group.go | 53 ++--- management/server/grpcserver.go | 2 +- management/server/http/accounts_handler.go | 46 ++--- .../server/http/dns_settings_handler.go | 8 +- management/server/http/events_handler.go | 6 +- .../server/http/geolocations_handler.go | 7 +- management/server/http/groups_handler.go | 119 +++++++----- management/server/http/nameservers_handler.go | 20 +- management/server/http/pat_handler.go | 18 +- management/server/http/peers_handler.go | 52 +++-- management/server/http/policies_handler.go | 183 ++++++++---------- .../server/http/posture_checks_handler.go | 44 ++--- management/server/http/routes_handler.go | 40 ++-- management/server/http/setupkeys_handler.go | 18 +- management/server/http/users_handler.go | 28 +-- management/server/mock_server/account_mock.go | 60 ++++-- management/server/policy.go | 71 +++++-- management/server/route.go | 14 ++ management/server/sql_store.go | 3 +- management/server/user.go | 17 +- 21 files changed, 485 insertions(+), 382 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 7c80ae6e5..fe2efee15 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -75,12 +75,14 @@ type AccountManager interface { SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) + GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) - GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error) + GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error) DeleteAccount(ctx context.Context, accountID, userID string) error MarkPATUsed(ctx context.Context, tokenID string) error + GetUserByID(ctx context.Context, id string) (*User, error) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) ListUsers(ctx context.Context, accountID string) ([]*User, error) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) @@ -107,7 +109,7 @@ type AccountManager interface { GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) - SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error + SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error DeletePolicy(ctx context.Context, accountID, policyID, userID string) error ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) @@ -145,6 +147,7 @@ type AccountManager interface { SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error) + GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error) } type DefaultAccountManager struct { @@ -1739,10 +1742,27 @@ func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token st return account, user, pat, nil } -// GetAccountFromToken returns an account associated with this token. -func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error) { +// GetAccountByID returns an account associated with this account ID. +func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error) { + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { + return nil, err + } + + if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) { + return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data") + } + + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + return am.Store.GetAccount(ctx, accountID) +} + +// GetAccountIDFromToken returns an account ID associated with this token. +func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { if claims.UserId == "" { - return nil, nil, fmt.Errorf("user ID is empty") + return "", "", fmt.Errorf("user ID is empty") } if am.singleAccountMode && am.singleAccountModeDomain != "" { // This section is mostly related to self-hosted installations. @@ -1754,28 +1774,27 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, claims) if err != nil { - return nil, nil, err + return "", "", err } user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId) if err != nil { // this is not really possible because we got an account by user ID - return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId) + return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId) } if !user.IsServiceUser && claims.Invited { err = am.redeemInvite(ctx, accountID, user.Id) if err != nil { - return nil, nil, err + return "", "", err } } if err = am.syncJWTGroups(ctx, accountID, user, claims); err != nil { - return nil, nil, err + return "", "", err } - // TODO: return account id, user id and error - return &Account{Id: accountID}, user, nil + return accountID, user.Id, nil } // syncJWTGroups processes the JWT groups for a user, updates the account based on the groups, @@ -2049,12 +2068,12 @@ func (am *DefaultAccountManager) GetDNSDomain() string { // CheckUserAccessByJWTGroups checks if the user has access, particularly in cases where the admin enabled JWT // group propagation and set the list of groups with access permissions. func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error { - account, _, err := am.GetAccountFromToken(ctx, claims) + accountID, _, err := am.GetAccountIDFromToken(ctx, claims) if err != nil { return err } - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, account.Id) + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) if err != nil { return err } @@ -2133,6 +2152,19 @@ func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, store Stor return newLabel, nil } +func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error) { + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { + return nil, err + } + + if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) { + return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data") + } + + return am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) +} + // addAllGroup to account object if it doesn't exist func addAllGroup(account *Account) error { if len(account.Groups) == 0 { diff --git a/management/server/group.go b/management/server/group.go index 49720f347..3f69c52ae 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -27,26 +27,15 @@ func (e *GroupLinkError) Error() string { // GetGroup object of the peers func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + groups, err := am.GetAllGroups(ctx, accountID, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) - if err != nil { - return nil, err - } - - if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked { - return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users") - } - - group, ok := account.Groups[groupID] - if ok { - return group, nil + for _, group := range groups { + if group.ID == groupID { + return group, nil + } } return nil, status.Errorf(status.NotFound, "group with ID %s not found", groupID) @@ -54,43 +43,32 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI // GetAllGroups returns all groups in an account func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID string, userID string) ([]*nbgroup.Group, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) if err != nil { return nil, err } - user, err := account.FindUser(userID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked { + if !user.HasAdminPower() && !user.IsServiceUser && settings.RegularUsersViewBlocked { return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users") } - groups := make([]*nbgroup.Group, 0, len(account.Groups)) - for _, item := range account.Groups { - groups = append(groups, item) - } - - return groups, nil + return am.Store.GetAccountGroups(ctx, accountID) } // GetGroupByName filters all groups in an account by name and returns the one with the most peers func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + groups, err := am.Store.GetAccountGroups(ctx, accountID) if err != nil { return nil, err } matchingGroups := make([]*nbgroup.Group, 0) - for _, group := range account.Groups { + for _, group := range groups { if group.Name == groupName { matchingGroups = append(matchingGroups, group) } @@ -262,6 +240,15 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use return nil } + allGroup, err := account.GetGroupAll() + if err != nil { + return err + } + + if allGroup.ID == groupID { + return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed") + } + if err = validateDeleteGroup(account, group, userId); err != nil { return err } diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 5d7094b6a..cda3bc748 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -262,7 +262,7 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string } claims := s.jwtClaimsExtractor.FromToken(token) // we need to call this method because if user is new, we will automatically add it to existing or create a new account - _, _, err = s.accountManager.GetAccountFromToken(ctx, claims) + _, _, err = s.accountManager.GetAccountIDFromToken(ctx, claims) if err != nil { return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err) } diff --git a/management/server/http/accounts_handler.go b/management/server/http/accounts_handler.go index ffa5b9a28..91caa1512 100644 --- a/management/server/http/accounts_handler.go +++ b/management/server/http/accounts_handler.go @@ -35,25 +35,26 @@ func NewAccountsHandler(accountManager server.AccountManager, authCfg AuthCfg) * // GetAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account. func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - if !(user.HasAdminPower() || user.IsServiceUser) { - util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "the user has no permission to access account data"), w) + settings, err := h.accountManager.GetAccountSettings(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) return } - resp := toAccountResponse(account) + resp := toAccountResponse(accountID, settings) util.WriteJSONObject(r.Context(), w, []*api.Account{resp}) } // UpdateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings) func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - _, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + _, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -96,24 +97,19 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) settings.JWTAllowGroups = *req.Settings.JwtAllowGroups } - updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, user.Id, settings) + updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings) if err != nil { util.WriteError(r.Context(), err, w) return } - resp := toAccountResponse(updatedAccount) + resp := toAccountResponse(updatedAccount.Id, updatedAccount.Settings) util.WriteJSONObject(r.Context(), w, &resp) } // DeleteAccount is a HTTP DELETE handler to delete an account func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodDelete { - util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) - return - } - claims := h.claimsExtractor.FromRequestContext(r) vars := mux.Vars(r) targetAccountID := vars["accountId"] @@ -131,28 +127,28 @@ func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) util.WriteJSONObject(r.Context(), w, emptyObject{}) } -func toAccountResponse(account *server.Account) *api.Account { - jwtAllowGroups := account.Settings.JWTAllowGroups +func toAccountResponse(accountID string, settings *server.Settings) *api.Account { + jwtAllowGroups := settings.JWTAllowGroups if jwtAllowGroups == nil { jwtAllowGroups = []string{} } - settings := api.AccountSettings{ - PeerLoginExpiration: int(account.Settings.PeerLoginExpiration.Seconds()), - PeerLoginExpirationEnabled: account.Settings.PeerLoginExpirationEnabled, - GroupsPropagationEnabled: &account.Settings.GroupsPropagationEnabled, - JwtGroupsEnabled: &account.Settings.JWTGroupsEnabled, - JwtGroupsClaimName: &account.Settings.JWTGroupsClaimName, + apiSettings := api.AccountSettings{ + PeerLoginExpiration: int(settings.PeerLoginExpiration.Seconds()), + PeerLoginExpirationEnabled: settings.PeerLoginExpirationEnabled, + GroupsPropagationEnabled: &settings.GroupsPropagationEnabled, + JwtGroupsEnabled: &settings.JWTGroupsEnabled, + JwtGroupsClaimName: &settings.JWTGroupsClaimName, JwtAllowGroups: &jwtAllowGroups, - RegularUsersViewBlocked: account.Settings.RegularUsersViewBlocked, + RegularUsersViewBlocked: settings.RegularUsersViewBlocked, } - if account.Settings.Extra != nil { - settings.Extra = &api.AccountExtraSettings{PeerApprovalEnabled: &account.Settings.Extra.PeerApprovalEnabled} + if settings.Extra != nil { + apiSettings.Extra = &api.AccountExtraSettings{PeerApprovalEnabled: &settings.Extra.PeerApprovalEnabled} } return &api.Account{ - Id: account.Id, - Settings: settings, + Id: accountID, + Settings: apiSettings, } } diff --git a/management/server/http/dns_settings_handler.go b/management/server/http/dns_settings_handler.go index 74b0e1a55..13c2101a7 100644 --- a/management/server/http/dns_settings_handler.go +++ b/management/server/http/dns_settings_handler.go @@ -32,14 +32,14 @@ func NewDNSSettingsHandler(accountManager server.AccountManager, authCfg AuthCfg // GetDNSSettings returns the DNS settings for the account func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } - dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), account.Id, user.Id) + dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -55,7 +55,7 @@ func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Reque // UpdateDNSSettings handles update to DNS settings of an account func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -72,7 +72,7 @@ func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Re DisabledManagementGroups: req.DisabledManagementGroups, } - err = h.accountManager.SaveDNSSettings(r.Context(), account.Id, user.Id, updateDNSSettings) + err = h.accountManager.SaveDNSSettings(r.Context(), accountID, userID, updateDNSSettings) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/events_handler.go b/management/server/http/events_handler.go index 428b4c164..ee0c63f28 100644 --- a/management/server/http/events_handler.go +++ b/management/server/http/events_handler.go @@ -34,14 +34,14 @@ func NewEventsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ev // GetAllEvents list of the given account func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } - accountEvents, err := h.accountManager.GetEvents(r.Context(), account.Id, user.Id) + accountEvents, err := h.accountManager.GetEvents(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -51,7 +51,7 @@ func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) { events[i] = toEventResponse(e) } - err = h.fillEventsWithUserInfo(r.Context(), events, account.Id, user.Id) + err = h.fillEventsWithUserInfo(r.Context(), events, accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/geolocations_handler.go b/management/server/http/geolocations_handler.go index af4d3116f..418228abf 100644 --- a/management/server/http/geolocations_handler.go +++ b/management/server/http/geolocations_handler.go @@ -98,7 +98,12 @@ func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http. func (l *GeolocationsHandler) authenticateUser(r *http.Request) error { claims := l.claimsExtractor.FromRequestContext(r) - _, user, err := l.accountManager.GetAccountFromToken(r.Context(), claims) + _, userID, err := l.accountManager.GetAccountIDFromToken(r.Context(), claims) + if err != nil { + return err + } + + user, err := l.accountManager.GetUserByID(r.Context(), userID) if err != nil { return err } diff --git a/management/server/http/groups_handler.go b/management/server/http/groups_handler.go index c622d873a..f369d1a00 100644 --- a/management/server/http/groups_handler.go +++ b/management/server/http/groups_handler.go @@ -5,6 +5,7 @@ import ( "net/http" "github.com/gorilla/mux" + nbpeer "github.com/netbirdio/netbird/management/server/peer" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server" @@ -35,14 +36,20 @@ func NewGroupsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Gr // GetAllGroups list for the account func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } - groups, err := h.accountManager.GetAllGroups(r.Context(), account.Id, user.Id) + groups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -50,7 +57,7 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) { groupsResponse := make([]*api.Group, 0, len(groups)) for _, group := range groups { - groupsResponse = append(groupsResponse, toGroupResponse(account, group)) + groupsResponse = append(groupsResponse, toGroupResponse(accountPeers, group)) } util.WriteJSONObject(r.Context(), w, groupsResponse) @@ -59,7 +66,7 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) { // UpdateGroup handles update to a group identified by a given ID func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -76,17 +83,18 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { return } - eg, ok := account.Groups[groupID] - if !ok { - util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w) - return - } - - allGroup, err := account.GetGroupAll() + existingGroup, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID) if err != nil { util.WriteError(r.Context(), err, w) return } + + allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + if allGroup.ID == groupID { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w) return @@ -114,23 +122,29 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { ID: groupID, Name: req.Name, Peers: peers, - Issued: eg.Issued, - IntegrationReference: eg.IntegrationReference, + Issued: existingGroup.Issued, + IntegrationReference: existingGroup.IntegrationReference, } - if err := h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group); err != nil { - log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, account.Id, err) + if err := h.accountManager.SaveGroup(r.Context(), accountID, userID, &group); err != nil { + log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, accountID, err) util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group)) + accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group)) } // CreateGroup handles group creation request func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -160,24 +174,29 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { Issued: nbgroup.GroupIssuedAPI, } - err = h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group) + err = h.accountManager.SaveGroup(r.Context(), accountID, userID, &group) if err != nil { util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group)) + accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group)) } // DeleteGroup handles group deletion request func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - aID := account.Id groupID := mux.Vars(r)["groupId"] if len(groupID) == 0 { @@ -185,18 +204,7 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) { return } - allGroup, err := account.GetGroupAll() - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - if allGroup.ID == groupID { - util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed"), w) - return - } - - err = h.accountManager.DeleteGroup(r.Context(), aID, user.Id, groupID) + err = h.accountManager.DeleteGroup(r.Context(), accountID, userID, groupID) if err != nil { _, ok := err.(*server.GroupLinkError) if ok { @@ -213,34 +221,39 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) { // GetGroup returns a group func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + groupID := mux.Vars(r)["groupId"] + if len(groupID) == 0 { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w) + return + } + + group, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID) if err != nil { util.WriteError(r.Context(), err, w) return } - switch r.Method { - case http.MethodGet: - groupID := mux.Vars(r)["groupId"] - if len(groupID) == 0 { - util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w) - return - } - - group, err := h.accountManager.GetGroup(r.Context(), account.Id, groupID, user.Id) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - util.WriteJSONObject(r.Context(), w, toGroupResponse(account, group)) - default: - util.WriteError(r.Context(), status.Errorf(status.NotFound, "HTTP method not found"), w) + accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) return } + + util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, group)) + } -func toGroupResponse(account *server.Account, group *nbgroup.Group) *api.Group { +func toGroupResponse(peers []*nbpeer.Peer, group *nbgroup.Group) *api.Group { + peersMap := make(map[string]*nbpeer.Peer, len(peers)) + for _, peer := range peers { + peersMap[peer.ID] = peer + } + cache := make(map[string]api.PeerMinimum) gr := api.Group{ Id: group.ID, @@ -251,7 +264,7 @@ func toGroupResponse(account *server.Account, group *nbgroup.Group) *api.Group { for _, pid := range group.Peers { _, ok := cache[pid] if !ok { - peer, ok := account.Peers[pid] + peer, ok := peersMap[pid] if !ok { continue } diff --git a/management/server/http/nameservers_handler.go b/management/server/http/nameservers_handler.go index c6e00bb2d..e7a2bc2ae 100644 --- a/management/server/http/nameservers_handler.go +++ b/management/server/http/nameservers_handler.go @@ -36,14 +36,14 @@ func NewNameserversHandler(accountManager server.AccountManager, authCfg AuthCfg // GetAllNameservers returns the list of nameserver groups for the account func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } - nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), account.Id, user.Id) + nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -60,7 +60,7 @@ func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Re // CreateNameserverGroup handles nameserver group creation request func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -79,7 +79,7 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt return } - nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, user.Id, req.SearchDomainsEnabled) + nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), accountID, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, userID, req.SearchDomainsEnabled) if err != nil { util.WriteError(r.Context(), err, w) return @@ -93,7 +93,7 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt // UpdateNameserverGroup handles update to a nameserver group identified by a given ID func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -130,7 +130,7 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt SearchDomainsEnabled: req.SearchDomainsEnabled, } - err = h.accountManager.SaveNameServerGroup(r.Context(), account.Id, user.Id, updatedNSGroup) + err = h.accountManager.SaveNameServerGroup(r.Context(), accountID, userID, updatedNSGroup) if err != nil { util.WriteError(r.Context(), err, w) return @@ -144,7 +144,7 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt // DeleteNameserverGroup handles nameserver group deletion request func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -156,7 +156,7 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt return } - err = h.accountManager.DeleteNameServerGroup(r.Context(), account.Id, nsGroupID, user.Id) + err = h.accountManager.DeleteNameServerGroup(r.Context(), accountID, nsGroupID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -168,7 +168,7 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt // GetNameserverGroup handles a nameserver group Get request identified by ID func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) @@ -181,7 +181,7 @@ func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.R return } - nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), account.Id, user.Id, nsGroupID) + nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), accountID, userID, nsGroupID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/pat_handler.go b/management/server/http/pat_handler.go index 9d8448d3d..dfa9563e3 100644 --- a/management/server/http/pat_handler.go +++ b/management/server/http/pat_handler.go @@ -34,20 +34,20 @@ func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATH // GetAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) - userID := vars["userId"] + targetUserID := vars["userId"] if len(userID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w) return } - pats, err := h.accountManager.GetAllPATs(r.Context(), account.Id, user.Id, userID) + pats, err := h.accountManager.GetAllPATs(r.Context(), accountID, userID, targetUserID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -64,7 +64,7 @@ func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) { // GetToken is HTTP GET handler that returns a personal access token for the given user func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -83,7 +83,7 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) { return } - pat, err := h.accountManager.GetPAT(r.Context(), account.Id, user.Id, targetUserID, tokenID) + pat, err := h.accountManager.GetPAT(r.Context(), accountID, userID, targetUserID, tokenID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -95,7 +95,7 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) { // CreateToken is HTTP POST handler that creates a personal access token for the given user func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -115,7 +115,7 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { return } - pat, err := h.accountManager.CreatePAT(r.Context(), account.Id, user.Id, targetUserID, req.Name, req.ExpiresIn) + pat, err := h.accountManager.CreatePAT(r.Context(), accountID, userID, targetUserID, req.Name, req.ExpiresIn) if err != nil { util.WriteError(r.Context(), err, w) return @@ -127,7 +127,7 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { // DeleteToken is HTTP DELETE handler that deletes a personal access token for the given user func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -146,7 +146,7 @@ func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) { return } - err = h.accountManager.DeletePAT(r.Context(), account.Id, user.Id, targetUserID, tokenID) + err = h.accountManager.DeletePAT(r.Context(), accountID, userID, targetUserID, tokenID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/peers_handler.go b/management/server/http/peers_handler.go index 5a2190d83..4fbbc3106 100644 --- a/management/server/http/peers_handler.go +++ b/management/server/http/peers_handler.go @@ -74,7 +74,7 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid)) } -func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) { +func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) { req := &api.PeerRequest{} err := json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -96,7 +96,7 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, } } - peer, err := h.accountManager.UpdatePeer(ctx, account.Id, user.Id, update) + peer, err := h.accountManager.UpdatePeer(ctx, account.Id, userID, update) if err != nil { util.WriteError(ctx, err, w) return @@ -130,7 +130,7 @@ func (h *PeersHandler) deletePeer(ctx context.Context, accountID, userID string, // HandlePeer handles all peer requests for GET, PUT and DELETE operations func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -144,13 +144,20 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodDelete: - h.deletePeer(r.Context(), account.Id, user.Id, peerID, w) + h.deletePeer(r.Context(), accountID, userID, peerID, w) return - case http.MethodPut: - h.updatePeer(r.Context(), account, user, peerID, w, r) - return - case http.MethodGet: - h.getPeer(r.Context(), account, peerID, user.Id, w) + case http.MethodGet, http.MethodPut: + account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + if r.Method == http.MethodGet { + h.getPeer(r.Context(), account, peerID, userID, w) + } else { + h.updatePeer(r.Context(), account, userID, peerID, w, r) + } return default: util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w) @@ -159,19 +166,14 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) { // GetAllPeers returns a list of all peers associated with a provided account func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w) - return - } - claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - peers, err := h.accountManager.GetPeers(r.Context(), account.Id, user.Id) + account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -179,8 +181,8 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { dnsDomain := h.accountManager.GetDNSDomain() - respBody := make([]*api.PeerBatch, 0, len(peers)) - for _, peer := range peers { + respBody := make([]*api.PeerBatch, 0, len(account.Peers)) + for _, peer := range account.Peers { peerToReturn, err := h.checkPeerStatus(peer) if err != nil { util.WriteError(r.Context(), err, w) @@ -214,7 +216,7 @@ func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approv // GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network. func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -227,6 +229,18 @@ func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request return } + account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + user, err := account.FindUser(userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + // If the user is regular user and does not own the peer // with the given peerID return an empty list if !user.HasAdminPower() && !user.IsServiceUser { diff --git a/management/server/http/policies_handler.go b/management/server/http/policies_handler.go index 9622668f4..1b0992cdd 100644 --- a/management/server/http/policies_handler.go +++ b/management/server/http/policies_handler.go @@ -3,9 +3,11 @@ package http import ( "encoding/json" "net/http" + "slices" "strconv" "github.com/gorilla/mux" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/rs/xid" "github.com/netbirdio/netbird/management/server" @@ -35,21 +37,27 @@ func NewPoliciesHandler(accountManager server.AccountManager, authCfg AuthCfg) * // GetAllPolicies list for the account func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - accountPolicies, err := h.accountManager.ListPolicies(r.Context(), account.Id, user.Id) + listPolicies, err := h.accountManager.ListPolicies(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return } - policies := []*api.Policy{} - for _, policy := range accountPolicies { - resp := toPolicyResponse(account, policy) + allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + policies := make([]*api.Policy, 0, len(listPolicies)) + for _, policy := range listPolicies { + resp := toPolicyResponse(allGroups, policy) if len(resp.Rules) == 0 { util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w) return @@ -63,7 +71,7 @@ func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) { // UpdatePolicy handles update to a policy identified by a given ID func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -76,41 +84,35 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) { return } - policyIdx := -1 - for i, policy := range account.Policies { - if policy.ID == policyID { - policyIdx = i - break - } - } - if policyIdx < 0 { - util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find policy id %s", policyID), w) - return - } - - h.savePolicy(w, r, account, user, policyID) -} - -// CreatePolicy handles policy creation request -func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return } - h.savePolicy(w, r, account, user, "") + policyIdx := slices.IndexFunc(account.Policies, func(policy *server.Policy) bool { return policy.ID == policyID }) + if policyIdx < 0 { + util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find policy id %s", policyID), w) + return + } + + h.savePolicy(w, r, accountID, userID, policyID) +} + +// CreatePolicy handles policy creation request +func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + h.savePolicy(w, r, accountID, userID, "") } // savePolicy handles policy creation and update -func (h *Policies) savePolicy( - w http.ResponseWriter, - r *http.Request, - account *server.Account, - user *server.User, - policyID string, -) { +func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID string, userID string, policyID string) { var req api.PutApiPoliciesPolicyIdJSONRequestBody if err := json.NewDecoder(r.Body).Decode(&req); err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) @@ -127,6 +129,8 @@ func (h *Policies) savePolicy( return } + isUpdate := policyID != "" + if policyID == "" { policyID = xid.New().String() } @@ -141,8 +145,8 @@ func (h *Policies) savePolicy( pr := server.PolicyRule{ ID: policyID, // TODO: when policy can contain multiple rules, need refactor Name: rule.Name, - Destinations: groupMinimumsToStrings(account, rule.Destinations), - Sources: groupMinimumsToStrings(account, rule.Sources), + Destinations: rule.Destinations, + Sources: rule.Sources, Bidirectional: rule.Bidirectional, } @@ -206,16 +210,18 @@ func (h *Policies) savePolicy( policy.Rules = append(policy.Rules, &pr) } - if req.SourcePostureChecks != nil { - policy.SourcePostureChecks = sourcePostureChecksToStrings(account, *req.SourcePostureChecks) - } - - if err := h.accountManager.SavePolicy(r.Context(), account.Id, user.Id, &policy); err != nil { + if err := h.accountManager.SavePolicy(r.Context(), accountID, userID, &policy, isUpdate); err != nil { util.WriteError(r.Context(), err, w) return } - resp := toPolicyResponse(account, &policy) + allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + resp := toPolicyResponse(allGroups, &policy) if len(resp.Rules) == 0 { util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w) return @@ -227,12 +233,11 @@ func (h *Policies) savePolicy( // DeletePolicy handles policy deletion request func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - aID := account.Id vars := mux.Vars(r) policyID := vars["policyId"] @@ -241,7 +246,7 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) { return } - if err = h.accountManager.DeletePolicy(r.Context(), aID, policyID, user.Id); err != nil { + if err = h.accountManager.DeletePolicy(r.Context(), accountID, policyID, userID); err != nil { util.WriteError(r.Context(), err, w) return } @@ -252,40 +257,46 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) { // GetPolicy handles a group Get request identified by ID func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - switch r.Method { - case http.MethodGet: - vars := mux.Vars(r) - policyID := vars["policyId"] - if len(policyID) == 0 { - util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w) - return - } - - policy, err := h.accountManager.GetPolicy(r.Context(), account.Id, policyID, user.Id) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - resp := toPolicyResponse(account, policy) - if len(resp.Rules) == 0 { - util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w) - return - } - - util.WriteJSONObject(r.Context(), w, resp) - default: - util.WriteError(r.Context(), status.Errorf(status.NotFound, "method not found"), w) + vars := mux.Vars(r) + policyID := vars["policyId"] + if len(policyID) == 0 { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w) + return } + + policy, err := h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + resp := toPolicyResponse(allGroups, policy) + if len(resp.Rules) == 0 { + util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w) + return + } + + util.WriteJSONObject(r.Context(), w, resp) } -func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Policy { +func toPolicyResponse(groups []*nbgroup.Group, policy *server.Policy) *api.Policy { + groupsMap := make(map[string]*nbgroup.Group) + for _, group := range groups { + groupsMap[group.ID] = group + } + cache := make(map[string]api.GroupMinimum) ap := &api.Policy{ Id: &policy.ID, @@ -306,16 +317,18 @@ func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Polic Protocol: api.PolicyRuleProtocol(r.Protocol), Action: api.PolicyRuleAction(r.Action), } + if len(r.Ports) != 0 { portsCopy := r.Ports rule.Ports = &portsCopy } + for _, gid := range r.Sources { _, ok := cache[gid] if ok { continue } - if group, ok := account.Groups[gid]; ok { + if group, ok := groupsMap[gid]; ok { minimum := api.GroupMinimum{ Id: group.ID, Name: group.Name, @@ -325,13 +338,14 @@ func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Polic cache[gid] = minimum } } + for _, gid := range r.Destinations { cachedMinimum, ok := cache[gid] if ok { rule.Destinations = append(rule.Destinations, cachedMinimum) continue } - if group, ok := account.Groups[gid]; ok { + if group, ok := groupsMap[gid]; ok { minimum := api.GroupMinimum{ Id: group.ID, Name: group.Name, @@ -345,28 +359,3 @@ func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Polic } return ap } - -func groupMinimumsToStrings(account *server.Account, gm []string) []string { - result := make([]string, 0, len(gm)) - for _, g := range gm { - if _, ok := account.Groups[g]; !ok { - continue - } - result = append(result, g) - } - return result -} - -func sourcePostureChecksToStrings(account *server.Account, postureChecksIds []string) []string { - result := make([]string, 0, len(postureChecksIds)) - for _, id := range postureChecksIds { - for _, postureCheck := range account.PostureChecks { - if id == postureCheck.ID { - result = append(result, id) - continue - } - } - - } - return result -} diff --git a/management/server/http/posture_checks_handler.go b/management/server/http/posture_checks_handler.go index 059cb3b80..0ab2b3a88 100644 --- a/management/server/http/posture_checks_handler.go +++ b/management/server/http/posture_checks_handler.go @@ -37,20 +37,20 @@ func NewPostureChecksHandler(accountManager server.AccountManager, geolocationMa // GetAllPostureChecks list for the account func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) - account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - accountPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), account.Id, user.Id) + listPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return } - postureChecks := []*api.PostureCheck{} - for _, postureCheck := range accountPostureChecks { + postureChecks := make([]*api.PostureCheck, 0, len(listPostureChecks)) + for _, postureCheck := range listPostureChecks { postureChecks = append(postureChecks, postureCheck.ToAPIResponse()) } @@ -60,7 +60,7 @@ func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *htt // UpdatePostureCheck handles update to a posture check identified by a given ID func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) - account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -73,37 +73,31 @@ func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http return } - postureChecksIdx := -1 - for i, postureCheck := range account.PostureChecks { - if postureCheck.ID == postureChecksID { - postureChecksIdx = i - break - } - } - if postureChecksIdx < 0 { + _, err = p.accountManager.GetPostureChecks(r.Context(), accountID, userID, postureChecksID) + if err != nil { util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find posture checks id %s", postureChecksID), w) return } - p.savePostureChecks(w, r, account, user, postureChecksID) + p.savePostureChecks(w, r, accountID, userID, postureChecksID) } // CreatePostureCheck handles posture check creation request func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) - account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - p.savePostureChecks(w, r, account, user, "") + p.savePostureChecks(w, r, accountID, userID, "") } // GetPostureCheck handles a posture check Get request identified by ID func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) - account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -116,7 +110,7 @@ func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Re return } - postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), account.Id, postureChecksID, user.Id) + postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -128,7 +122,7 @@ func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Re // DeletePostureCheck handles posture check deletion request func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) - account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -141,7 +135,7 @@ func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http return } - if err = p.accountManager.DeletePostureChecks(r.Context(), account.Id, postureChecksID, user.Id); err != nil { + if err = p.accountManager.DeletePostureChecks(r.Context(), accountID, postureChecksID, userID); err != nil { util.WriteError(r.Context(), err, w) return } @@ -150,13 +144,7 @@ func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http } // savePostureChecks handles posture checks create and update -func (p *PostureChecksHandler) savePostureChecks( - w http.ResponseWriter, - r *http.Request, - account *server.Account, - user *server.User, - postureChecksID string, -) { +func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.Request, accountID, userID, postureChecksID string) { var ( err error req api.PostureCheckUpdate @@ -181,7 +169,7 @@ func (p *PostureChecksHandler) savePostureChecks( return } - if err := p.accountManager.SavePostureChecks(r.Context(), account.Id, user.Id, postureChecks); err != nil { + if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks); err != nil { util.WriteError(r.Context(), err, w) return } diff --git a/management/server/http/routes_handler.go b/management/server/http/routes_handler.go index 18c347334..0932e6445 100644 --- a/management/server/http/routes_handler.go +++ b/management/server/http/routes_handler.go @@ -43,13 +43,13 @@ func NewRoutesHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ro // GetAllRoutes returns the list of routes for the account func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - routes, err := h.accountManager.ListRoutes(r.Context(), account.Id, user.Id) + routes, err := h.accountManager.ListRoutes(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -70,7 +70,7 @@ func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) { // CreateRoute handles route creation request func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -117,15 +117,9 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) { peerGroupIds = *req.PeerGroups } - // Do not allow non-Linux peers - if peer := account.GetPeer(peerId); peer != nil { - if peer.Meta.GoOS != "linux" { - util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes"), w) - return - } - } - - newRoute, err := h.accountManager.CreateRoute(r.Context(), account.Id, newPrefix, networkType, domains, peerId, peerGroupIds, req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id, req.KeepRoute) + newRoute, err := h.accountManager.CreateRoute(r.Context(), accountID, newPrefix, networkType, domains, peerId, peerGroupIds, + req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, userID, req.KeepRoute, + ) if err != nil { util.WriteError(r.Context(), err, w) return @@ -168,7 +162,7 @@ func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) erro // UpdateRoute handles update to a route identified by a given ID func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -181,7 +175,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { return } - _, err = h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id) + _, err = h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -204,14 +198,6 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { peerID = *req.Peer } - // do not allow non Linux peers - if peer := account.GetPeer(peerID); peer != nil { - if peer.Meta.GoOS != "linux" { - util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "non-linux peers are non supported as network routes"), w) - return - } - } - newRoute := &route.Route{ ID: route.ID(routeID), NetID: route.NetID(req.NetworkId), @@ -247,7 +233,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { newRoute.PeerGroups = *req.PeerGroups } - err = h.accountManager.SaveRoute(r.Context(), account.Id, user.Id, newRoute) + err = h.accountManager.SaveRoute(r.Context(), accountID, userID, newRoute) if err != nil { util.WriteError(r.Context(), err, w) return @@ -265,7 +251,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { // DeleteRoute handles route deletion request func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -277,7 +263,7 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) { return } - err = h.accountManager.DeleteRoute(r.Context(), account.Id, route.ID(routeID), user.Id) + err = h.accountManager.DeleteRoute(r.Context(), accountID, route.ID(routeID), userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -289,7 +275,7 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) { // GetRoute handles a route Get request identified by ID func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -301,7 +287,7 @@ func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) { return } - foundRoute, err := h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id) + foundRoute, err := h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID) if err != nil { util.WriteError(r.Context(), status.Errorf(status.NotFound, "route not found"), w) return diff --git a/management/server/http/setupkeys_handler.go b/management/server/http/setupkeys_handler.go index 8ee7dfaba..8514f0b55 100644 --- a/management/server/http/setupkeys_handler.go +++ b/management/server/http/setupkeys_handler.go @@ -35,7 +35,7 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authCfg AuthCfg) // CreateSetupKey is a POST requests that creates a new SetupKey func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -76,8 +76,8 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request if req.Ephemeral != nil { ephemeral = *req.Ephemeral } - setupKey, err := h.accountManager.CreateSetupKey(r.Context(), account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn, - req.AutoGroups, req.UsageLimit, user.Id, ephemeral) + setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, server.SetupKeyType(req.Type), expiresIn, + req.AutoGroups, req.UsageLimit, userID, ephemeral) if err != nil { util.WriteError(r.Context(), err, w) return @@ -89,7 +89,7 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request // GetSetupKey is a GET request to get a SetupKey by ID func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -102,7 +102,7 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) { return } - key, err := h.accountManager.GetSetupKey(r.Context(), account.Id, user.Id, keyID) + key, err := h.accountManager.GetSetupKey(r.Context(), accountID, userID, keyID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -114,7 +114,7 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) { // UpdateSetupKey is a PUT request to update server.SetupKey func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -150,7 +150,7 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request newKey.Name = req.Name newKey.Id = keyID - newKey, err = h.accountManager.SaveSetupKey(r.Context(), account.Id, newKey, user.Id) + newKey, err = h.accountManager.SaveSetupKey(r.Context(), accountID, newKey, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -161,13 +161,13 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request // GetAllSetupKeys is a GET request that returns a list of SetupKey func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), account.Id, user.Id) + setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/users_handler.go b/management/server/http/users_handler.go index 2c2aed842..e36b11729 100644 --- a/management/server/http/users_handler.go +++ b/management/server/http/users_handler.go @@ -41,22 +41,22 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { } claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) - userID := vars["userId"] + targetUserID := vars["userId"] if len(userID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w) return } - existingUser, ok := account.Users[userID] - if !ok { - util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find user with ID %s", userID), w) + existingUser, err := h.accountManager.GetUserByID(r.Context(), targetUserID) + if err != nil { + util.WriteError(r.Context(), err, w) return } @@ -78,7 +78,7 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { return } - newUser, err := h.accountManager.SaveUser(r.Context(), account.Id, user.Id, &server.User{ + newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &server.User{ Id: userID, Role: userRole, AutoGroups: req.AutoGroups, @@ -102,7 +102,7 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) { } claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -115,7 +115,7 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) { return } - err = h.accountManager.DeleteUser(r.Context(), account.Id, user.Id, targetUserID) + err = h.accountManager.DeleteUser(r.Context(), accountID, userID, targetUserID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -132,7 +132,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) { } claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -160,7 +160,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) { name = *req.Name } - newUser, err := h.accountManager.CreateUser(r.Context(), account.Id, user.Id, &server.UserInfo{ + newUser, err := h.accountManager.CreateUser(r.Context(), accountID, userID, &server.UserInfo{ Email: email, Name: name, Role: req.Role, @@ -184,13 +184,13 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) { } claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - data, err := h.accountManager.GetUsersFromAccount(r.Context(), account.Id, user.Id) + data, err := h.accountManager.GetUsersFromAccount(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -231,7 +231,7 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) { } claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -244,7 +244,7 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) { return } - err = h.accountManager.InviteUser(r.Context(), account.Id, user.Id, targetUserID) + err = h.accountManager.InviteUser(r.Context(), accountID, userID, targetUserID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 9ef42bff2..a43e5a18c 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -26,7 +26,7 @@ type MockAccountManager struct { CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error) GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error) - GetAccountByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (*server.Account, error) + GetAccountIDByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (string, error) GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error) GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) @@ -48,7 +48,7 @@ type MockAccountManager struct { GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error) - SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) error + SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error) GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error) @@ -79,7 +79,7 @@ type MockAccountManager struct { DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) CreateUserFunc func(ctx context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) - GetAccountFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) + GetAccountIDFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error DeleteAccountFunc func(ctx context.Context, accountID, userID string) error GetDNSDomainFunc func() string @@ -105,6 +105,9 @@ type MockAccountManager struct { SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error) + GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*server.Account, error) + GetUserByIDFunc func(ctx context.Context, id string) (*server.User, error) + GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*server.Settings, error) } func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { @@ -190,16 +193,14 @@ func (am *MockAccountManager) CreateSetupKey( return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented") } -// GetAccountByUserOrAccountID mock implementation of GetAccountByUserOrAccountID from server.AccountManager interface -func (am *MockAccountManager) GetAccountByUserOrAccountID( - ctx context.Context, userId, accountId, domain string, -) (*server.Account, error) { - if am.GetAccountByUserOrAccountIdFunc != nil { - return am.GetAccountByUserOrAccountIdFunc(ctx, userId, accountId, domain) +// GetAccountIDByUserOrAccountID mock implementation of GetAccountIDByUserOrAccountID from server.AccountManager interface +func (am *MockAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userId, accountId, domain string) (string, error) { + if am.GetAccountIDByUserOrAccountIdFunc != nil { + return am.GetAccountIDByUserOrAccountIdFunc(ctx, userId, accountId, domain) } - return nil, status.Errorf( + return "", status.Errorf( codes.Unimplemented, - "method GetAccountByUserOrAccountID is not implemented", + "method GetAccountIDByUserOrAccountID is not implemented", ) } @@ -377,9 +378,9 @@ func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID } // SavePolicy mock implementation of SavePolicy from server.AccountManager interface -func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy) error { +func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error { if am.SavePolicyFunc != nil { - return am.SavePolicyFunc(ctx, accountID, userID, policy) + return am.SavePolicyFunc(ctx, accountID, userID, policy, isUpdate) } return status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented") } @@ -601,12 +602,12 @@ func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented") } -// GetAccountFromToken mocks GetAccountFromToken of the AccountManager interface -func (am *MockAccountManager) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - if am.GetAccountFromTokenFunc != nil { - return am.GetAccountFromTokenFunc(ctx, claims) +// GetAccountIDFromToken mocks GetAccountIDFromToken of the AccountManager interface +func (am *MockAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + if am.GetAccountIDFromTokenFunc != nil { + return am.GetAccountIDFromTokenFunc(ctx, claims) } - return nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented") + return "", "", status.Errorf(codes.Unimplemented, "method GetAccountIDFromToken is not implemented") } func (am *MockAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error { @@ -800,3 +801,26 @@ func (am *MockAccountManager) GetAccountIDForPeerKey(ctx context.Context, peerKe } return "", status.Errorf(codes.Unimplemented, "method GetAccountIDForPeerKey is not implemented") } + +// GetAccountByID mocks GetAccountByID of the AccountManager interface +func (am *MockAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*server.Account, error) { + if am.GetAccountByIDFunc != nil { + return am.GetAccountByIDFunc(ctx, accountID, userID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetAccountByID is not implemented") +} + +// GetUserByID mocks GetUserByID of the AccountManager interface +func (am *MockAccountManager) GetUserByID(ctx context.Context, id string) (*server.User, error) { + if am.GetUserByIDFunc != nil { + return am.GetUserByIDFunc(ctx, id) + } + return nil, status.Errorf(codes.Unimplemented, "method GetUserByID is not implemented") +} + +func (am *MockAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*server.Settings, error) { + if am.GetAccountSettingsFunc != nil { + return am.GetAccountSettingsFunc(ctx, accountID, userID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetAccountSettings is not implemented") +} diff --git a/management/server/policy.go b/management/server/policy.go index aaf9b6e72..833f97d39 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -3,6 +3,7 @@ package server import ( "context" _ "embed" + "slices" "strconv" "strings" @@ -341,7 +342,7 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic } // SavePolicy in the store -func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error { +func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() @@ -350,7 +351,9 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user return err } - exists := am.savePolicy(account, policy) + if err = am.savePolicy(account, policy, isUpdate); err != nil { + return err + } account.Network.IncSerial() if err = am.Store.SaveAccount(ctx, account); err != nil { @@ -358,7 +361,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user } action := activity.PolicyAdded - if exists { + if isUpdate { action = activity.PolicyUpdated } am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) @@ -434,18 +437,34 @@ func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) return policy, nil } -func (am *DefaultAccountManager) savePolicy(account *Account, policy *Policy) (exists bool) { - for i, p := range account.Policies { - if p.ID == policy.ID { - account.Policies[i] = policy - exists = true - break +// savePolicy saves or updates a policy in the given account. +// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy. +func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) error { + for index, rule := range policyToSave.Rules { + rule.Sources = filterValidGroupIDs(account, rule.Sources) + rule.Destinations = filterValidGroupIDs(account, rule.Destinations) + policyToSave.Rules[index] = rule + } + + if policyToSave.SourcePostureChecks != nil { + policyToSave.SourcePostureChecks = filterValidPostureChecks(account, policyToSave.SourcePostureChecks) + } + + if isUpdate { + policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID }) + if policyIdx < 0 { + return status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID) } + + // Update the existing policy + account.Policies[policyIdx] = policyToSave + return nil } - if !exists { - account.Policies = append(account.Policies, policy) - } - return + + // Add the new policy to the account + account.Policies = append(account.Policies, policyToSave) + + return nil } func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule { @@ -560,3 +579,29 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks { } return nil } + +// filterValidPostureChecks filters and returns the posture check IDs from the given list +// that are valid within the provided account. +func filterValidPostureChecks(account *Account, postureChecksIds []string) []string { + result := make([]string, 0, len(postureChecksIds)) + for _, id := range postureChecksIds { + for _, postureCheck := range account.PostureChecks { + if id == postureCheck.ID { + result = append(result, id) + continue + } + } + } + return result +} + +// filterValidGroupIDs filters a list of group IDs and returns only the ones present in the account's group map. +func filterValidGroupIDs(account *Account, groupIDs []string) []string { + result := make([]string, 0, len(groupIDs)) + for _, groupID := range groupIDs { + if _, exists := account.Groups[groupID]; exists { + result = append(result, groupID) + } + } + return result +} diff --git a/management/server/route.go b/management/server/route.go index 064f3c105..11f89b83b 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -134,6 +134,13 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri return nil, err } + // Do not allow non-Linux peers + if peer := account.GetPeer(peerID); peer != nil { + if peer.Meta.GoOS != "linux" { + return nil, status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes") + } + } + if len(domains) > 0 && prefix.IsValid() { return nil, status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time") } @@ -234,6 +241,13 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI return err } + // Do not allow non-Linux peers + if peer := account.GetPeer(routeToSave.Peer); peer != nil { + if peer.Meta.GoOS != "linux" { + return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes") + } + } + if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() { return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time") } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 6a667b398..b4bcbfbd0 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -36,6 +36,7 @@ const ( idQueryCondition = "id = ?" keyQueryCondition = "key = ?" accountAndIDQueryCondition = "account_id = ? and id = ?" + accountIDCondition = "account_id = ?" peerNotFoundFMT = "peer %s not found" ) @@ -500,7 +501,7 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { var groups []*nbgroup.Group - result := s.db.Find(&groups, idQueryCondition, accountID) + result := s.db.Find(&groups, accountIDCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") diff --git a/management/server/user.go b/management/server/user.go index 193333685..3c2feec9f 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -357,26 +357,35 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u return newUser.ToUserInfo(idpUser, account.Settings) } +func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*User, error) { + return am.Store.GetUserByUserID(ctx, LockingStrengthShare, id) +} + // GetUser looks up a user by provided authorization claims. // It will also create an account if didn't exist for this user before. func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) { - account, user, err := am.GetAccountFromToken(ctx, claims) + accountID, userID, err := am.GetAccountIDFromToken(ctx, claims) if err != nil { return nil, fmt.Errorf("failed to get account with token claims %v", err) } - // this code should be outside of the am.GetAccountFromToken(claims) because this method is called also by the gRPC + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { + return nil, err + } + + // this code should be outside of the am.GetAccountIDFromToken(claims) because this method is called also by the gRPC // server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event. newLogin := user.LastDashboardLoginChanged(claims.LastLogin) - err = am.Store.SaveUserLastLogin(ctx, account.Id, claims.UserId, claims.LastLogin) + err = am.Store.SaveUserLastLogin(ctx, accountID, userID, claims.LastLogin) if err != nil { log.WithContext(ctx).Errorf("failed saving user last login: %v", err) } if newLogin { meta := map[string]any{"timestamp": claims.LastLogin} - am.StoreEvent(ctx, claims.UserId, claims.UserId, account.Id, activity.DashboardLogin, meta) + am.StoreEvent(ctx, claims.UserId, claims.UserId, accountID, activity.DashboardLogin, meta) } return user, nil