refactor handlers to use GetAccountIDFromToken

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga 2024-09-22 15:14:31 +03:00
parent 26dd045da5
commit 8f98adddf6
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547
21 changed files with 485 additions and 382 deletions

View File

@ -75,12 +75,14 @@ type AccountManager interface {
SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error)
SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error) SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error)
GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error)
GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error)
GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error)
GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error) GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error)
DeleteAccount(ctx context.Context, accountID, userID string) error DeleteAccount(ctx context.Context, accountID, userID string) error
MarkPATUsed(ctx context.Context, tokenID string) error MarkPATUsed(ctx context.Context, tokenID string) error
GetUserByID(ctx context.Context, id string) (*User, error)
GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error)
ListUsers(ctx context.Context, accountID string) ([]*User, error) ListUsers(ctx context.Context, accountID string) ([]*User, error)
GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
@ -107,7 +109,7 @@ type AccountManager interface {
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error)
SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error
DeletePolicy(ctx context.Context, accountID, policyID, userID string) error DeletePolicy(ctx context.Context, accountID, policyID, userID string) error
ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error)
GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
@ -145,6 +147,7 @@ type AccountManager interface {
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error) GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error)
} }
type DefaultAccountManager struct { type DefaultAccountManager struct {
@ -1739,10 +1742,27 @@ func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token st
return account, user, pat, nil return account, user, pat, nil
} }
// GetAccountFromToken returns an account associated with this token. // GetAccountByID returns an account associated with this account ID.
func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error) { func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) {
return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data")
}
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
return am.Store.GetAccount(ctx, accountID)
}
// GetAccountIDFromToken returns an account ID associated with this token.
func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
if claims.UserId == "" { if claims.UserId == "" {
return nil, nil, fmt.Errorf("user ID is empty") return "", "", fmt.Errorf("user ID is empty")
} }
if am.singleAccountMode && am.singleAccountModeDomain != "" { if am.singleAccountMode && am.singleAccountModeDomain != "" {
// This section is mostly related to self-hosted installations. // This section is mostly related to self-hosted installations.
@ -1754,28 +1774,27 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, claims) accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, claims)
if err != nil { if err != nil {
return nil, nil, err return "", "", err
} }
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId)
if err != nil { if err != nil {
// this is not really possible because we got an account by user ID // this is not really possible because we got an account by user ID
return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId) return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId)
} }
if !user.IsServiceUser && claims.Invited { if !user.IsServiceUser && claims.Invited {
err = am.redeemInvite(ctx, accountID, user.Id) err = am.redeemInvite(ctx, accountID, user.Id)
if err != nil { if err != nil {
return nil, nil, err return "", "", err
} }
} }
if err = am.syncJWTGroups(ctx, accountID, user, claims); err != nil { if err = am.syncJWTGroups(ctx, accountID, user, claims); err != nil {
return nil, nil, err return "", "", err
} }
// TODO: return account id, user id and error return accountID, user.Id, nil
return &Account{Id: accountID}, user, nil
} }
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups, // syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
@ -2049,12 +2068,12 @@ func (am *DefaultAccountManager) GetDNSDomain() string {
// CheckUserAccessByJWTGroups checks if the user has access, particularly in cases where the admin enabled JWT // CheckUserAccessByJWTGroups checks if the user has access, particularly in cases where the admin enabled JWT
// group propagation and set the list of groups with access permissions. // group propagation and set the list of groups with access permissions.
func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error { func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error {
account, _, err := am.GetAccountFromToken(ctx, claims) accountID, _, err := am.GetAccountIDFromToken(ctx, claims)
if err != nil { if err != nil {
return err return err
} }
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, account.Id) settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return err return err
} }
@ -2133,6 +2152,19 @@ func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, store Stor
return newLabel, nil return newLabel, nil
} }
func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) {
return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data")
}
return am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
}
// addAllGroup to account object if it doesn't exist // addAllGroup to account object if it doesn't exist
func addAllGroup(account *Account) error { func addAllGroup(account *Account) error {
if len(account.Groups) == 0 { if len(account.Groups) == 0 {

View File

@ -27,26 +27,15 @@ func (e *GroupLinkError) Error() string {
// GetGroup object of the peers // GetGroup object of the peers
func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) { func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) groups, err := am.GetAllGroups(ctx, accountID, userID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
user, err := account.FindUser(userID) for _, group := range groups {
if err != nil { if group.ID == groupID {
return nil, err return group, nil
} }
if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked {
return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users")
}
group, ok := account.Groups[groupID]
if ok {
return group, nil
} }
return nil, status.Errorf(status.NotFound, "group with ID %s not found", groupID) return nil, status.Errorf(status.NotFound, "group with ID %s not found", groupID)
@ -54,43 +43,32 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
// GetAllGroups returns all groups in an account // GetAllGroups returns all groups in an account
func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID string, userID string) ([]*nbgroup.Group, error) { func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID string, userID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
user, err := account.FindUser(userID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked { if !user.HasAdminPower() && !user.IsServiceUser && settings.RegularUsersViewBlocked {
return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users") return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users")
} }
groups := make([]*nbgroup.Group, 0, len(account.Groups)) return am.Store.GetAccountGroups(ctx, accountID)
for _, item := range account.Groups {
groups = append(groups, item)
}
return groups, nil
} }
// GetGroupByName filters all groups in an account by name and returns the one with the most peers // GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) { func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) groups, err := am.Store.GetAccountGroups(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
matchingGroups := make([]*nbgroup.Group, 0) matchingGroups := make([]*nbgroup.Group, 0)
for _, group := range account.Groups { for _, group := range groups {
if group.Name == groupName { if group.Name == groupName {
matchingGroups = append(matchingGroups, group) matchingGroups = append(matchingGroups, group)
} }
@ -262,6 +240,15 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
return nil return nil
} }
allGroup, err := account.GetGroupAll()
if err != nil {
return err
}
if allGroup.ID == groupID {
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
}
if err = validateDeleteGroup(account, group, userId); err != nil { if err = validateDeleteGroup(account, group, userId); err != nil {
return err return err
} }

View File

@ -262,7 +262,7 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string
} }
claims := s.jwtClaimsExtractor.FromToken(token) claims := s.jwtClaimsExtractor.FromToken(token)
// we need to call this method because if user is new, we will automatically add it to existing or create a new account // we need to call this method because if user is new, we will automatically add it to existing or create a new account
_, _, err = s.accountManager.GetAccountFromToken(ctx, claims) _, _, err = s.accountManager.GetAccountIDFromToken(ctx, claims)
if err != nil { if err != nil {
return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err) return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err)
} }

View File

@ -35,25 +35,26 @@ func NewAccountsHandler(accountManager server.AccountManager, authCfg AuthCfg) *
// GetAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account. // GetAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account.
func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) { func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
if !(user.HasAdminPower() || user.IsServiceUser) { settings, err := h.accountManager.GetAccountSettings(r.Context(), accountID, userID)
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "the user has no permission to access account data"), w) if err != nil {
util.WriteError(r.Context(), err, w)
return return
} }
resp := toAccountResponse(account) resp := toAccountResponse(accountID, settings)
util.WriteJSONObject(r.Context(), w, []*api.Account{resp}) util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
} }
// UpdateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings) // UpdateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings)
func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) { func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
_, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) _, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -96,24 +97,19 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
settings.JWTAllowGroups = *req.Settings.JwtAllowGroups settings.JWTAllowGroups = *req.Settings.JwtAllowGroups
} }
updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, user.Id, settings) updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
resp := toAccountResponse(updatedAccount) resp := toAccountResponse(updatedAccount.Id, updatedAccount.Settings)
util.WriteJSONObject(r.Context(), w, &resp) util.WriteJSONObject(r.Context(), w, &resp)
} }
// DeleteAccount is a HTTP DELETE handler to delete an account // DeleteAccount is a HTTP DELETE handler to delete an account
func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) { func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodDelete {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
vars := mux.Vars(r) vars := mux.Vars(r)
targetAccountID := vars["accountId"] targetAccountID := vars["accountId"]
@ -131,28 +127,28 @@ func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request)
util.WriteJSONObject(r.Context(), w, emptyObject{}) util.WriteJSONObject(r.Context(), w, emptyObject{})
} }
func toAccountResponse(account *server.Account) *api.Account { func toAccountResponse(accountID string, settings *server.Settings) *api.Account {
jwtAllowGroups := account.Settings.JWTAllowGroups jwtAllowGroups := settings.JWTAllowGroups
if jwtAllowGroups == nil { if jwtAllowGroups == nil {
jwtAllowGroups = []string{} jwtAllowGroups = []string{}
} }
settings := api.AccountSettings{ apiSettings := api.AccountSettings{
PeerLoginExpiration: int(account.Settings.PeerLoginExpiration.Seconds()), PeerLoginExpiration: int(settings.PeerLoginExpiration.Seconds()),
PeerLoginExpirationEnabled: account.Settings.PeerLoginExpirationEnabled, PeerLoginExpirationEnabled: settings.PeerLoginExpirationEnabled,
GroupsPropagationEnabled: &account.Settings.GroupsPropagationEnabled, GroupsPropagationEnabled: &settings.GroupsPropagationEnabled,
JwtGroupsEnabled: &account.Settings.JWTGroupsEnabled, JwtGroupsEnabled: &settings.JWTGroupsEnabled,
JwtGroupsClaimName: &account.Settings.JWTGroupsClaimName, JwtGroupsClaimName: &settings.JWTGroupsClaimName,
JwtAllowGroups: &jwtAllowGroups, JwtAllowGroups: &jwtAllowGroups,
RegularUsersViewBlocked: account.Settings.RegularUsersViewBlocked, RegularUsersViewBlocked: settings.RegularUsersViewBlocked,
} }
if account.Settings.Extra != nil { if settings.Extra != nil {
settings.Extra = &api.AccountExtraSettings{PeerApprovalEnabled: &account.Settings.Extra.PeerApprovalEnabled} apiSettings.Extra = &api.AccountExtraSettings{PeerApprovalEnabled: &settings.Extra.PeerApprovalEnabled}
} }
return &api.Account{ return &api.Account{
Id: account.Id, Id: accountID,
Settings: settings, Settings: apiSettings,
} }
} }

View File

@ -32,14 +32,14 @@ func NewDNSSettingsHandler(accountManager server.AccountManager, authCfg AuthCfg
// GetDNSSettings returns the DNS settings for the account // GetDNSSettings returns the DNS settings for the account
func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Request) { func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
log.WithContext(r.Context()).Error(err) log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), account.Id, user.Id) dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -55,7 +55,7 @@ func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Reque
// UpdateDNSSettings handles update to DNS settings of an account // UpdateDNSSettings handles update to DNS settings of an account
func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) { func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -72,7 +72,7 @@ func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Re
DisabledManagementGroups: req.DisabledManagementGroups, DisabledManagementGroups: req.DisabledManagementGroups,
} }
err = h.accountManager.SaveDNSSettings(r.Context(), account.Id, user.Id, updateDNSSettings) err = h.accountManager.SaveDNSSettings(r.Context(), accountID, userID, updateDNSSettings)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return

View File

@ -34,14 +34,14 @@ func NewEventsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ev
// GetAllEvents list of the given account // GetAllEvents list of the given account
func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) { func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
log.WithContext(r.Context()).Error(err) log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
accountEvents, err := h.accountManager.GetEvents(r.Context(), account.Id, user.Id) accountEvents, err := h.accountManager.GetEvents(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -51,7 +51,7 @@ func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
events[i] = toEventResponse(e) events[i] = toEventResponse(e)
} }
err = h.fillEventsWithUserInfo(r.Context(), events, account.Id, user.Id) err = h.fillEventsWithUserInfo(r.Context(), events, accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return

View File

@ -98,7 +98,12 @@ func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http.
func (l *GeolocationsHandler) authenticateUser(r *http.Request) error { func (l *GeolocationsHandler) authenticateUser(r *http.Request) error {
claims := l.claimsExtractor.FromRequestContext(r) claims := l.claimsExtractor.FromRequestContext(r)
_, user, err := l.accountManager.GetAccountFromToken(r.Context(), claims) _, userID, err := l.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
return err
}
user, err := l.accountManager.GetUserByID(r.Context(), userID)
if err != nil { if err != nil {
return err return err
} }

View File

@ -5,6 +5,7 @@ import (
"net/http" "net/http"
"github.com/gorilla/mux" "github.com/gorilla/mux"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
@ -35,14 +36,20 @@ func NewGroupsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Gr
// GetAllGroups list for the account // GetAllGroups list for the account
func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) { func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
log.WithContext(r.Context()).Error(err) log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
groups, err := h.accountManager.GetAllGroups(r.Context(), account.Id, user.Id) groups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -50,7 +57,7 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
groupsResponse := make([]*api.Group, 0, len(groups)) groupsResponse := make([]*api.Group, 0, len(groups))
for _, group := range groups { for _, group := range groups {
groupsResponse = append(groupsResponse, toGroupResponse(account, group)) groupsResponse = append(groupsResponse, toGroupResponse(accountPeers, group))
} }
util.WriteJSONObject(r.Context(), w, groupsResponse) util.WriteJSONObject(r.Context(), w, groupsResponse)
@ -59,7 +66,7 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
// UpdateGroup handles update to a group identified by a given ID // UpdateGroup handles update to a group identified by a given ID
func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -76,17 +83,18 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
return return
} }
eg, ok := account.Groups[groupID] existingGroup, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID)
if !ok {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w)
return
}
allGroup, err := account.GetGroupAll()
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
if allGroup.ID == groupID { if allGroup.ID == groupID {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w)
return return
@ -114,23 +122,29 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
ID: groupID, ID: groupID,
Name: req.Name, Name: req.Name,
Peers: peers, Peers: peers,
Issued: eg.Issued, Issued: existingGroup.Issued,
IntegrationReference: eg.IntegrationReference, IntegrationReference: existingGroup.IntegrationReference,
} }
if err := h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group); err != nil { if err := h.accountManager.SaveGroup(r.Context(), accountID, userID, &group); err != nil {
log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, account.Id, err) log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, accountID, err)
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group)) accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group))
} }
// CreateGroup handles group creation request // CreateGroup handles group creation request
func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -160,24 +174,29 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
Issued: nbgroup.GroupIssuedAPI, Issued: nbgroup.GroupIssuedAPI,
} }
err = h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group) err = h.accountManager.SaveGroup(r.Context(), accountID, userID, &group)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group)) accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group))
} }
// DeleteGroup handles group deletion request // DeleteGroup handles group deletion request
func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) { func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
aID := account.Id
groupID := mux.Vars(r)["groupId"] groupID := mux.Vars(r)["groupId"]
if len(groupID) == 0 { if len(groupID) == 0 {
@ -185,18 +204,7 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
return return
} }
allGroup, err := account.GetGroupAll() err = h.accountManager.DeleteGroup(r.Context(), accountID, userID, groupID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
if allGroup.ID == groupID {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed"), w)
return
}
err = h.accountManager.DeleteGroup(r.Context(), aID, user.Id, groupID)
if err != nil { if err != nil {
_, ok := err.(*server.GroupLinkError) _, ok := err.(*server.GroupLinkError)
if ok { if ok {
@ -213,34 +221,39 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
// GetGroup returns a group // GetGroup returns a group
func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) { func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
groupID := mux.Vars(r)["groupId"]
if len(groupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
return
}
group, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
switch r.Method { accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
case http.MethodGet: if err != nil {
groupID := mux.Vars(r)["groupId"] util.WriteError(r.Context(), err, w)
if len(groupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
return
}
group, err := h.accountManager.GetGroup(r.Context(), account.Id, groupID, user.Id)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, toGroupResponse(account, group))
default:
util.WriteError(r.Context(), status.Errorf(status.NotFound, "HTTP method not found"), w)
return return
} }
util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, group))
} }
func toGroupResponse(account *server.Account, group *nbgroup.Group) *api.Group { func toGroupResponse(peers []*nbpeer.Peer, group *nbgroup.Group) *api.Group {
peersMap := make(map[string]*nbpeer.Peer, len(peers))
for _, peer := range peers {
peersMap[peer.ID] = peer
}
cache := make(map[string]api.PeerMinimum) cache := make(map[string]api.PeerMinimum)
gr := api.Group{ gr := api.Group{
Id: group.ID, Id: group.ID,
@ -251,7 +264,7 @@ func toGroupResponse(account *server.Account, group *nbgroup.Group) *api.Group {
for _, pid := range group.Peers { for _, pid := range group.Peers {
_, ok := cache[pid] _, ok := cache[pid]
if !ok { if !ok {
peer, ok := account.Peers[pid] peer, ok := peersMap[pid]
if !ok { if !ok {
continue continue
} }

View File

@ -36,14 +36,14 @@ func NewNameserversHandler(accountManager server.AccountManager, authCfg AuthCfg
// GetAllNameservers returns the list of nameserver groups for the account // GetAllNameservers returns the list of nameserver groups for the account
func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Request) { func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
log.WithContext(r.Context()).Error(err) log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), account.Id, user.Id) nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -60,7 +60,7 @@ func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Re
// CreateNameserverGroup handles nameserver group creation request // CreateNameserverGroup handles nameserver group creation request
func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *http.Request) { func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -79,7 +79,7 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt
return return
} }
nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, user.Id, req.SearchDomainsEnabled) nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), accountID, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, userID, req.SearchDomainsEnabled)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -93,7 +93,7 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt
// UpdateNameserverGroup handles update to a nameserver group identified by a given ID // UpdateNameserverGroup handles update to a nameserver group identified by a given ID
func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *http.Request) { func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -130,7 +130,7 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt
SearchDomainsEnabled: req.SearchDomainsEnabled, SearchDomainsEnabled: req.SearchDomainsEnabled,
} }
err = h.accountManager.SaveNameServerGroup(r.Context(), account.Id, user.Id, updatedNSGroup) err = h.accountManager.SaveNameServerGroup(r.Context(), accountID, userID, updatedNSGroup)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -144,7 +144,7 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt
// DeleteNameserverGroup handles nameserver group deletion request // DeleteNameserverGroup handles nameserver group deletion request
func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *http.Request) { func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -156,7 +156,7 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt
return return
} }
err = h.accountManager.DeleteNameServerGroup(r.Context(), account.Id, nsGroupID, user.Id) err = h.accountManager.DeleteNameServerGroup(r.Context(), accountID, nsGroupID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -168,7 +168,7 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt
// GetNameserverGroup handles a nameserver group Get request identified by ID // GetNameserverGroup handles a nameserver group Get request identified by ID
func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.Request) { func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
log.WithContext(r.Context()).Error(err) log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
@ -181,7 +181,7 @@ func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.R
return return
} }
nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), account.Id, user.Id, nsGroupID) nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), accountID, userID, nsGroupID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return

View File

@ -34,20 +34,20 @@ func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATH
// GetAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user // GetAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user
func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) { func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
userID := vars["userId"] targetUserID := vars["userId"]
if len(userID) == 0 { if len(userID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return return
} }
pats, err := h.accountManager.GetAllPATs(r.Context(), account.Id, user.Id, userID) pats, err := h.accountManager.GetAllPATs(r.Context(), accountID, userID, targetUserID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -64,7 +64,7 @@ func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) {
// GetToken is HTTP GET handler that returns a personal access token for the given user // GetToken is HTTP GET handler that returns a personal access token for the given user
func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) { func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -83,7 +83,7 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) {
return return
} }
pat, err := h.accountManager.GetPAT(r.Context(), account.Id, user.Id, targetUserID, tokenID) pat, err := h.accountManager.GetPAT(r.Context(), accountID, userID, targetUserID, tokenID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -95,7 +95,7 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) {
// CreateToken is HTTP POST handler that creates a personal access token for the given user // CreateToken is HTTP POST handler that creates a personal access token for the given user
func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -115,7 +115,7 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
return return
} }
pat, err := h.accountManager.CreatePAT(r.Context(), account.Id, user.Id, targetUserID, req.Name, req.ExpiresIn) pat, err := h.accountManager.CreatePAT(r.Context(), accountID, userID, targetUserID, req.Name, req.ExpiresIn)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -127,7 +127,7 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
// DeleteToken is HTTP DELETE handler that deletes a personal access token for the given user // DeleteToken is HTTP DELETE handler that deletes a personal access token for the given user
func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) { func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -146,7 +146,7 @@ func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) {
return return
} }
err = h.accountManager.DeletePAT(r.Context(), account.Id, user.Id, targetUserID, tokenID) err = h.accountManager.DeletePAT(r.Context(), accountID, userID, targetUserID, tokenID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return

View File

@ -74,7 +74,7 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid)) util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid))
} }
func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) { func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) {
req := &api.PeerRequest{} req := &api.PeerRequest{}
err := json.NewDecoder(r.Body).Decode(&req) err := json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
@ -96,7 +96,7 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
} }
} }
peer, err := h.accountManager.UpdatePeer(ctx, account.Id, user.Id, update) peer, err := h.accountManager.UpdatePeer(ctx, account.Id, userID, update)
if err != nil { if err != nil {
util.WriteError(ctx, err, w) util.WriteError(ctx, err, w)
return return
@ -130,7 +130,7 @@ func (h *PeersHandler) deletePeer(ctx context.Context, accountID, userID string,
// HandlePeer handles all peer requests for GET, PUT and DELETE operations // HandlePeer handles all peer requests for GET, PUT and DELETE operations
func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) { func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -144,13 +144,20 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
switch r.Method { switch r.Method {
case http.MethodDelete: case http.MethodDelete:
h.deletePeer(r.Context(), account.Id, user.Id, peerID, w) h.deletePeer(r.Context(), accountID, userID, peerID, w)
return return
case http.MethodPut: case http.MethodGet, http.MethodPut:
h.updatePeer(r.Context(), account, user, peerID, w, r) account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
return if err != nil {
case http.MethodGet: util.WriteError(r.Context(), err, w)
h.getPeer(r.Context(), account, peerID, user.Id, w) return
}
if r.Method == http.MethodGet {
h.getPeer(r.Context(), account, peerID, userID, w)
} else {
h.updatePeer(r.Context(), account, userID, peerID, w, r)
}
return return
default: default:
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w) util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
@ -159,19 +166,14 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
// GetAllPeers returns a list of all peers associated with a provided account // GetAllPeers returns a list of all peers associated with a provided account
func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
return
}
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
peers, err := h.accountManager.GetPeers(r.Context(), account.Id, user.Id) account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -179,8 +181,8 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
dnsDomain := h.accountManager.GetDNSDomain() dnsDomain := h.accountManager.GetDNSDomain()
respBody := make([]*api.PeerBatch, 0, len(peers)) respBody := make([]*api.PeerBatch, 0, len(account.Peers))
for _, peer := range peers { for _, peer := range account.Peers {
peerToReturn, err := h.checkPeerStatus(peer) peerToReturn, err := h.checkPeerStatus(peer)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
@ -214,7 +216,7 @@ func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approv
// GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network. // GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network.
func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -227,6 +229,18 @@ func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request
return return
} }
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
user, err := account.FindUser(userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
// If the user is regular user and does not own the peer // If the user is regular user and does not own the peer
// with the given peerID return an empty list // with the given peerID return an empty list
if !user.HasAdminPower() && !user.IsServiceUser { if !user.HasAdminPower() && !user.IsServiceUser {

View File

@ -3,9 +3,11 @@ package http
import ( import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"slices"
"strconv" "strconv"
"github.com/gorilla/mux" "github.com/gorilla/mux"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/rs/xid" "github.com/rs/xid"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
@ -35,21 +37,27 @@ func NewPoliciesHandler(accountManager server.AccountManager, authCfg AuthCfg) *
// GetAllPolicies list for the account // GetAllPolicies list for the account
func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) { func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
accountPolicies, err := h.accountManager.ListPolicies(r.Context(), account.Id, user.Id) listPolicies, err := h.accountManager.ListPolicies(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
policies := []*api.Policy{} allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
for _, policy := range accountPolicies { if err != nil {
resp := toPolicyResponse(account, policy) util.WriteError(r.Context(), err, w)
return
}
policies := make([]*api.Policy, 0, len(listPolicies))
for _, policy := range listPolicies {
resp := toPolicyResponse(allGroups, policy)
if len(resp.Rules) == 0 { if len(resp.Rules) == 0 {
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w) util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
return return
@ -63,7 +71,7 @@ func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) {
// UpdatePolicy handles update to a policy identified by a given ID // UpdatePolicy handles update to a policy identified by a given ID
func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) { func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -76,41 +84,35 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
return return
} }
policyIdx := -1 account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
for i, policy := range account.Policies {
if policy.ID == policyID {
policyIdx = i
break
}
}
if policyIdx < 0 {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find policy id %s", policyID), w)
return
}
h.savePolicy(w, r, account, user, policyID)
}
// CreatePolicy handles policy creation request
func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
h.savePolicy(w, r, account, user, "") policyIdx := slices.IndexFunc(account.Policies, func(policy *server.Policy) bool { return policy.ID == policyID })
if policyIdx < 0 {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find policy id %s", policyID), w)
return
}
h.savePolicy(w, r, accountID, userID, policyID)
}
// CreatePolicy handles policy creation request
func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
h.savePolicy(w, r, accountID, userID, "")
} }
// savePolicy handles policy creation and update // savePolicy handles policy creation and update
func (h *Policies) savePolicy( func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID string, userID string, policyID string) {
w http.ResponseWriter,
r *http.Request,
account *server.Account,
user *server.User,
policyID string,
) {
var req api.PutApiPoliciesPolicyIdJSONRequestBody var req api.PutApiPoliciesPolicyIdJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
@ -127,6 +129,8 @@ func (h *Policies) savePolicy(
return return
} }
isUpdate := policyID != ""
if policyID == "" { if policyID == "" {
policyID = xid.New().String() policyID = xid.New().String()
} }
@ -141,8 +145,8 @@ func (h *Policies) savePolicy(
pr := server.PolicyRule{ pr := server.PolicyRule{
ID: policyID, // TODO: when policy can contain multiple rules, need refactor ID: policyID, // TODO: when policy can contain multiple rules, need refactor
Name: rule.Name, Name: rule.Name,
Destinations: groupMinimumsToStrings(account, rule.Destinations), Destinations: rule.Destinations,
Sources: groupMinimumsToStrings(account, rule.Sources), Sources: rule.Sources,
Bidirectional: rule.Bidirectional, Bidirectional: rule.Bidirectional,
} }
@ -206,16 +210,18 @@ func (h *Policies) savePolicy(
policy.Rules = append(policy.Rules, &pr) policy.Rules = append(policy.Rules, &pr)
} }
if req.SourcePostureChecks != nil { if err := h.accountManager.SavePolicy(r.Context(), accountID, userID, &policy, isUpdate); err != nil {
policy.SourcePostureChecks = sourcePostureChecksToStrings(account, *req.SourcePostureChecks)
}
if err := h.accountManager.SavePolicy(r.Context(), account.Id, user.Id, &policy); err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
resp := toPolicyResponse(account, &policy) allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
resp := toPolicyResponse(allGroups, &policy)
if len(resp.Rules) == 0 { if len(resp.Rules) == 0 {
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w) util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
return return
@ -227,12 +233,11 @@ func (h *Policies) savePolicy(
// DeletePolicy handles policy deletion request // DeletePolicy handles policy deletion request
func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) { func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
aID := account.Id
vars := mux.Vars(r) vars := mux.Vars(r)
policyID := vars["policyId"] policyID := vars["policyId"]
@ -241,7 +246,7 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
return return
} }
if err = h.accountManager.DeletePolicy(r.Context(), aID, policyID, user.Id); err != nil { if err = h.accountManager.DeletePolicy(r.Context(), accountID, policyID, userID); err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
@ -252,40 +257,46 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
// GetPolicy handles a group Get request identified by ID // GetPolicy handles a group Get request identified by ID
func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) { func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
switch r.Method { vars := mux.Vars(r)
case http.MethodGet: policyID := vars["policyId"]
vars := mux.Vars(r) if len(policyID) == 0 {
policyID := vars["policyId"] util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
if len(policyID) == 0 { return
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
return
}
policy, err := h.accountManager.GetPolicy(r.Context(), account.Id, policyID, user.Id)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
resp := toPolicyResponse(account, policy)
if len(resp.Rules) == 0 {
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
return
}
util.WriteJSONObject(r.Context(), w, resp)
default:
util.WriteError(r.Context(), status.Errorf(status.NotFound, "method not found"), w)
} }
policy, err := h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
resp := toPolicyResponse(allGroups, policy)
if len(resp.Rules) == 0 {
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
return
}
util.WriteJSONObject(r.Context(), w, resp)
} }
func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Policy { func toPolicyResponse(groups []*nbgroup.Group, policy *server.Policy) *api.Policy {
groupsMap := make(map[string]*nbgroup.Group)
for _, group := range groups {
groupsMap[group.ID] = group
}
cache := make(map[string]api.GroupMinimum) cache := make(map[string]api.GroupMinimum)
ap := &api.Policy{ ap := &api.Policy{
Id: &policy.ID, Id: &policy.ID,
@ -306,16 +317,18 @@ func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Polic
Protocol: api.PolicyRuleProtocol(r.Protocol), Protocol: api.PolicyRuleProtocol(r.Protocol),
Action: api.PolicyRuleAction(r.Action), Action: api.PolicyRuleAction(r.Action),
} }
if len(r.Ports) != 0 { if len(r.Ports) != 0 {
portsCopy := r.Ports portsCopy := r.Ports
rule.Ports = &portsCopy rule.Ports = &portsCopy
} }
for _, gid := range r.Sources { for _, gid := range r.Sources {
_, ok := cache[gid] _, ok := cache[gid]
if ok { if ok {
continue continue
} }
if group, ok := account.Groups[gid]; ok { if group, ok := groupsMap[gid]; ok {
minimum := api.GroupMinimum{ minimum := api.GroupMinimum{
Id: group.ID, Id: group.ID,
Name: group.Name, Name: group.Name,
@ -325,13 +338,14 @@ func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Polic
cache[gid] = minimum cache[gid] = minimum
} }
} }
for _, gid := range r.Destinations { for _, gid := range r.Destinations {
cachedMinimum, ok := cache[gid] cachedMinimum, ok := cache[gid]
if ok { if ok {
rule.Destinations = append(rule.Destinations, cachedMinimum) rule.Destinations = append(rule.Destinations, cachedMinimum)
continue continue
} }
if group, ok := account.Groups[gid]; ok { if group, ok := groupsMap[gid]; ok {
minimum := api.GroupMinimum{ minimum := api.GroupMinimum{
Id: group.ID, Id: group.ID,
Name: group.Name, Name: group.Name,
@ -345,28 +359,3 @@ func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Polic
} }
return ap return ap
} }
func groupMinimumsToStrings(account *server.Account, gm []string) []string {
result := make([]string, 0, len(gm))
for _, g := range gm {
if _, ok := account.Groups[g]; !ok {
continue
}
result = append(result, g)
}
return result
}
func sourcePostureChecksToStrings(account *server.Account, postureChecksIds []string) []string {
result := make([]string, 0, len(postureChecksIds))
for _, id := range postureChecksIds {
for _, postureCheck := range account.PostureChecks {
if id == postureCheck.ID {
result = append(result, id)
continue
}
}
}
return result
}

View File

@ -37,20 +37,20 @@ func NewPostureChecksHandler(accountManager server.AccountManager, geolocationMa
// GetAllPostureChecks list for the account // GetAllPostureChecks list for the account
func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *http.Request) { func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r) claims := p.claimsExtractor.FromRequestContext(r)
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
accountPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), account.Id, user.Id) listPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
postureChecks := []*api.PostureCheck{} postureChecks := make([]*api.PostureCheck, 0, len(listPostureChecks))
for _, postureCheck := range accountPostureChecks { for _, postureCheck := range listPostureChecks {
postureChecks = append(postureChecks, postureCheck.ToAPIResponse()) postureChecks = append(postureChecks, postureCheck.ToAPIResponse())
} }
@ -60,7 +60,7 @@ func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *htt
// UpdatePostureCheck handles update to a posture check identified by a given ID // UpdatePostureCheck handles update to a posture check identified by a given ID
func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http.Request) { func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r) claims := p.claimsExtractor.FromRequestContext(r)
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -73,37 +73,31 @@ func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http
return return
} }
postureChecksIdx := -1 _, err = p.accountManager.GetPostureChecks(r.Context(), accountID, userID, postureChecksID)
for i, postureCheck := range account.PostureChecks { if err != nil {
if postureCheck.ID == postureChecksID {
postureChecksIdx = i
break
}
}
if postureChecksIdx < 0 {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find posture checks id %s", postureChecksID), w) util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find posture checks id %s", postureChecksID), w)
return return
} }
p.savePostureChecks(w, r, account, user, postureChecksID) p.savePostureChecks(w, r, accountID, userID, postureChecksID)
} }
// CreatePostureCheck handles posture check creation request // CreatePostureCheck handles posture check creation request
func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http.Request) { func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r) claims := p.claimsExtractor.FromRequestContext(r)
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
p.savePostureChecks(w, r, account, user, "") p.savePostureChecks(w, r, accountID, userID, "")
} }
// GetPostureCheck handles a posture check Get request identified by ID // GetPostureCheck handles a posture check Get request identified by ID
func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Request) { func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r) claims := p.claimsExtractor.FromRequestContext(r)
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -116,7 +110,7 @@ func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Re
return return
} }
postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), account.Id, postureChecksID, user.Id) postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -128,7 +122,7 @@ func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Re
// DeletePostureCheck handles posture check deletion request // DeletePostureCheck handles posture check deletion request
func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http.Request) { func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r) claims := p.claimsExtractor.FromRequestContext(r)
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -141,7 +135,7 @@ func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http
return return
} }
if err = p.accountManager.DeletePostureChecks(r.Context(), account.Id, postureChecksID, user.Id); err != nil { if err = p.accountManager.DeletePostureChecks(r.Context(), accountID, postureChecksID, userID); err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
@ -150,13 +144,7 @@ func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http
} }
// savePostureChecks handles posture checks create and update // savePostureChecks handles posture checks create and update
func (p *PostureChecksHandler) savePostureChecks( func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.Request, accountID, userID, postureChecksID string) {
w http.ResponseWriter,
r *http.Request,
account *server.Account,
user *server.User,
postureChecksID string,
) {
var ( var (
err error err error
req api.PostureCheckUpdate req api.PostureCheckUpdate
@ -181,7 +169,7 @@ func (p *PostureChecksHandler) savePostureChecks(
return return
} }
if err := p.accountManager.SavePostureChecks(r.Context(), account.Id, user.Id, postureChecks); err != nil { if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks); err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }

View File

@ -43,13 +43,13 @@ func NewRoutesHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ro
// GetAllRoutes returns the list of routes for the account // GetAllRoutes returns the list of routes for the account
func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) { func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
routes, err := h.accountManager.ListRoutes(r.Context(), account.Id, user.Id) routes, err := h.accountManager.ListRoutes(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -70,7 +70,7 @@ func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) {
// CreateRoute handles route creation request // CreateRoute handles route creation request
func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) { func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -117,15 +117,9 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
peerGroupIds = *req.PeerGroups peerGroupIds = *req.PeerGroups
} }
// Do not allow non-Linux peers newRoute, err := h.accountManager.CreateRoute(r.Context(), accountID, newPrefix, networkType, domains, peerId, peerGroupIds,
if peer := account.GetPeer(peerId); peer != nil { req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, userID, req.KeepRoute,
if peer.Meta.GoOS != "linux" { )
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes"), w)
return
}
}
newRoute, err := h.accountManager.CreateRoute(r.Context(), account.Id, newPrefix, networkType, domains, peerId, peerGroupIds, req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id, req.KeepRoute)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -168,7 +162,7 @@ func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) erro
// UpdateRoute handles update to a route identified by a given ID // UpdateRoute handles update to a route identified by a given ID
func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -181,7 +175,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
return return
} }
_, err = h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id) _, err = h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -204,14 +198,6 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
peerID = *req.Peer peerID = *req.Peer
} }
// do not allow non Linux peers
if peer := account.GetPeer(peerID); peer != nil {
if peer.Meta.GoOS != "linux" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "non-linux peers are non supported as network routes"), w)
return
}
}
newRoute := &route.Route{ newRoute := &route.Route{
ID: route.ID(routeID), ID: route.ID(routeID),
NetID: route.NetID(req.NetworkId), NetID: route.NetID(req.NetworkId),
@ -247,7 +233,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
newRoute.PeerGroups = *req.PeerGroups newRoute.PeerGroups = *req.PeerGroups
} }
err = h.accountManager.SaveRoute(r.Context(), account.Id, user.Id, newRoute) err = h.accountManager.SaveRoute(r.Context(), accountID, userID, newRoute)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -265,7 +251,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
// DeleteRoute handles route deletion request // DeleteRoute handles route deletion request
func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) { func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -277,7 +263,7 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
return return
} }
err = h.accountManager.DeleteRoute(r.Context(), account.Id, route.ID(routeID), user.Id) err = h.accountManager.DeleteRoute(r.Context(), accountID, route.ID(routeID), userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -289,7 +275,7 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
// GetRoute handles a route Get request identified by ID // GetRoute handles a route Get request identified by ID
func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) { func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -301,7 +287,7 @@ func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
return return
} }
foundRoute, err := h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id) foundRoute, err := h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "route not found"), w) util.WriteError(r.Context(), status.Errorf(status.NotFound, "route not found"), w)
return return

View File

@ -35,7 +35,7 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authCfg AuthCfg)
// CreateSetupKey is a POST requests that creates a new SetupKey // CreateSetupKey is a POST requests that creates a new SetupKey
func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request) { func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -76,8 +76,8 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
if req.Ephemeral != nil { if req.Ephemeral != nil {
ephemeral = *req.Ephemeral ephemeral = *req.Ephemeral
} }
setupKey, err := h.accountManager.CreateSetupKey(r.Context(), account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn, setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, server.SetupKeyType(req.Type), expiresIn,
req.AutoGroups, req.UsageLimit, user.Id, ephemeral) req.AutoGroups, req.UsageLimit, userID, ephemeral)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -89,7 +89,7 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
// GetSetupKey is a GET request to get a SetupKey by ID // GetSetupKey is a GET request to get a SetupKey by ID
func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) { func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -102,7 +102,7 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
return return
} }
key, err := h.accountManager.GetSetupKey(r.Context(), account.Id, user.Id, keyID) key, err := h.accountManager.GetSetupKey(r.Context(), accountID, userID, keyID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -114,7 +114,7 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
// UpdateSetupKey is a PUT request to update server.SetupKey // UpdateSetupKey is a PUT request to update server.SetupKey
func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request) { func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -150,7 +150,7 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
newKey.Name = req.Name newKey.Name = req.Name
newKey.Id = keyID newKey.Id = keyID
newKey, err = h.accountManager.SaveSetupKey(r.Context(), account.Id, newKey, user.Id) newKey, err = h.accountManager.SaveSetupKey(r.Context(), accountID, newKey, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -161,13 +161,13 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
// GetAllSetupKeys is a GET request that returns a list of SetupKey // GetAllSetupKeys is a GET request that returns a list of SetupKey
func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Request) { func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), account.Id, user.Id) setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return

View File

@ -41,22 +41,22 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
} }
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
userID := vars["userId"] targetUserID := vars["userId"]
if len(userID) == 0 { if len(userID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return return
} }
existingUser, ok := account.Users[userID] existingUser, err := h.accountManager.GetUserByID(r.Context(), targetUserID)
if !ok { if err != nil {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find user with ID %s", userID), w) util.WriteError(r.Context(), err, w)
return return
} }
@ -78,7 +78,7 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
return return
} }
newUser, err := h.accountManager.SaveUser(r.Context(), account.Id, user.Id, &server.User{ newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &server.User{
Id: userID, Id: userID,
Role: userRole, Role: userRole,
AutoGroups: req.AutoGroups, AutoGroups: req.AutoGroups,
@ -102,7 +102,7 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) {
} }
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -115,7 +115,7 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) {
return return
} }
err = h.accountManager.DeleteUser(r.Context(), account.Id, user.Id, targetUserID) err = h.accountManager.DeleteUser(r.Context(), accountID, userID, targetUserID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -132,7 +132,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
} }
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -160,7 +160,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
name = *req.Name name = *req.Name
} }
newUser, err := h.accountManager.CreateUser(r.Context(), account.Id, user.Id, &server.UserInfo{ newUser, err := h.accountManager.CreateUser(r.Context(), accountID, userID, &server.UserInfo{
Email: email, Email: email,
Name: name, Name: name,
Role: req.Role, Role: req.Role,
@ -184,13 +184,13 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) {
} }
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
data, err := h.accountManager.GetUsersFromAccount(r.Context(), account.Id, user.Id) data, err := h.accountManager.GetUsersFromAccount(r.Context(), accountID, userID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -231,7 +231,7 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) {
} }
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -244,7 +244,7 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) {
return return
} }
err = h.accountManager.InviteUser(r.Context(), account.Id, user.Id, targetUserID) err = h.accountManager.InviteUser(r.Context(), accountID, userID, targetUserID)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return

View File

@ -26,7 +26,7 @@ type MockAccountManager struct {
CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType, CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType,
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error) expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error) GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error)
GetAccountByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (*server.Account, error) GetAccountIDByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (string, error)
GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error) ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
@ -48,7 +48,7 @@ type MockAccountManager struct {
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error) GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error)
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) error SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error) ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error)
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error) GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error)
@ -79,7 +79,7 @@ type MockAccountManager struct {
DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error
ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
CreateUserFunc func(ctx context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) CreateUserFunc func(ctx context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
GetAccountFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) GetAccountIDFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
DeleteAccountFunc func(ctx context.Context, accountID, userID string) error DeleteAccountFunc func(ctx context.Context, accountID, userID string) error
GetDNSDomainFunc func() string GetDNSDomainFunc func() string
@ -105,6 +105,9 @@ type MockAccountManager struct {
SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error) GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error)
GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*server.Account, error)
GetUserByIDFunc func(ctx context.Context, id string) (*server.User, error)
GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*server.Settings, error)
} }
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
@ -190,16 +193,14 @@ func (am *MockAccountManager) CreateSetupKey(
return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented") return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented")
} }
// GetAccountByUserOrAccountID mock implementation of GetAccountByUserOrAccountID from server.AccountManager interface // GetAccountIDByUserOrAccountID mock implementation of GetAccountIDByUserOrAccountID from server.AccountManager interface
func (am *MockAccountManager) GetAccountByUserOrAccountID( func (am *MockAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userId, accountId, domain string) (string, error) {
ctx context.Context, userId, accountId, domain string, if am.GetAccountIDByUserOrAccountIdFunc != nil {
) (*server.Account, error) { return am.GetAccountIDByUserOrAccountIdFunc(ctx, userId, accountId, domain)
if am.GetAccountByUserOrAccountIdFunc != nil {
return am.GetAccountByUserOrAccountIdFunc(ctx, userId, accountId, domain)
} }
return nil, status.Errorf( return "", status.Errorf(
codes.Unimplemented, codes.Unimplemented,
"method GetAccountByUserOrAccountID is not implemented", "method GetAccountIDByUserOrAccountID is not implemented",
) )
} }
@ -377,9 +378,9 @@ func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID
} }
// SavePolicy mock implementation of SavePolicy from server.AccountManager interface // SavePolicy mock implementation of SavePolicy from server.AccountManager interface
func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy) error { func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error {
if am.SavePolicyFunc != nil { if am.SavePolicyFunc != nil {
return am.SavePolicyFunc(ctx, accountID, userID, policy) return am.SavePolicyFunc(ctx, accountID, userID, policy, isUpdate)
} }
return status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented") return status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented")
} }
@ -601,12 +602,12 @@ func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID
return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented") return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented")
} }
// GetAccountFromToken mocks GetAccountFromToken of the AccountManager interface // GetAccountIDFromToken mocks GetAccountIDFromToken of the AccountManager interface
func (am *MockAccountManager) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { func (am *MockAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
if am.GetAccountFromTokenFunc != nil { if am.GetAccountIDFromTokenFunc != nil {
return am.GetAccountFromTokenFunc(ctx, claims) return am.GetAccountIDFromTokenFunc(ctx, claims)
} }
return nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented") return "", "", status.Errorf(codes.Unimplemented, "method GetAccountIDFromToken is not implemented")
} }
func (am *MockAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error { func (am *MockAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error {
@ -800,3 +801,26 @@ func (am *MockAccountManager) GetAccountIDForPeerKey(ctx context.Context, peerKe
} }
return "", status.Errorf(codes.Unimplemented, "method GetAccountIDForPeerKey is not implemented") return "", status.Errorf(codes.Unimplemented, "method GetAccountIDForPeerKey is not implemented")
} }
// GetAccountByID mocks GetAccountByID of the AccountManager interface
func (am *MockAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*server.Account, error) {
if am.GetAccountByIDFunc != nil {
return am.GetAccountByIDFunc(ctx, accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetAccountByID is not implemented")
}
// GetUserByID mocks GetUserByID of the AccountManager interface
func (am *MockAccountManager) GetUserByID(ctx context.Context, id string) (*server.User, error) {
if am.GetUserByIDFunc != nil {
return am.GetUserByIDFunc(ctx, id)
}
return nil, status.Errorf(codes.Unimplemented, "method GetUserByID is not implemented")
}
func (am *MockAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*server.Settings, error) {
if am.GetAccountSettingsFunc != nil {
return am.GetAccountSettingsFunc(ctx, accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetAccountSettings is not implemented")
}

View File

@ -3,6 +3,7 @@ package server
import ( import (
"context" "context"
_ "embed" _ "embed"
"slices"
"strconv" "strconv"
"strings" "strings"
@ -341,7 +342,7 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic
} }
// SavePolicy in the store // SavePolicy in the store
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error { func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
@ -350,7 +351,9 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
return err return err
} }
exists := am.savePolicy(account, policy) if err = am.savePolicy(account, policy, isUpdate); err != nil {
return err
}
account.Network.IncSerial() account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil { if err = am.Store.SaveAccount(ctx, account); err != nil {
@ -358,7 +361,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
} }
action := activity.PolicyAdded action := activity.PolicyAdded
if exists { if isUpdate {
action = activity.PolicyUpdated action = activity.PolicyUpdated
} }
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
@ -434,18 +437,34 @@ func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string)
return policy, nil return policy, nil
} }
func (am *DefaultAccountManager) savePolicy(account *Account, policy *Policy) (exists bool) { // savePolicy saves or updates a policy in the given account.
for i, p := range account.Policies { // If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy.
if p.ID == policy.ID { func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) error {
account.Policies[i] = policy for index, rule := range policyToSave.Rules {
exists = true rule.Sources = filterValidGroupIDs(account, rule.Sources)
break rule.Destinations = filterValidGroupIDs(account, rule.Destinations)
policyToSave.Rules[index] = rule
}
if policyToSave.SourcePostureChecks != nil {
policyToSave.SourcePostureChecks = filterValidPostureChecks(account, policyToSave.SourcePostureChecks)
}
if isUpdate {
policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID })
if policyIdx < 0 {
return status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID)
} }
// Update the existing policy
account.Policies[policyIdx] = policyToSave
return nil
} }
if !exists {
account.Policies = append(account.Policies, policy) // Add the new policy to the account
} account.Policies = append(account.Policies, policyToSave)
return
return nil
} }
func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule { func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule {
@ -560,3 +579,29 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks {
} }
return nil return nil
} }
// filterValidPostureChecks filters and returns the posture check IDs from the given list
// that are valid within the provided account.
func filterValidPostureChecks(account *Account, postureChecksIds []string) []string {
result := make([]string, 0, len(postureChecksIds))
for _, id := range postureChecksIds {
for _, postureCheck := range account.PostureChecks {
if id == postureCheck.ID {
result = append(result, id)
continue
}
}
}
return result
}
// filterValidGroupIDs filters a list of group IDs and returns only the ones present in the account's group map.
func filterValidGroupIDs(account *Account, groupIDs []string) []string {
result := make([]string, 0, len(groupIDs))
for _, groupID := range groupIDs {
if _, exists := account.Groups[groupID]; exists {
result = append(result, groupID)
}
}
return result
}

View File

@ -134,6 +134,13 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
return nil, err return nil, err
} }
// Do not allow non-Linux peers
if peer := account.GetPeer(peerID); peer != nil {
if peer.Meta.GoOS != "linux" {
return nil, status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
}
}
if len(domains) > 0 && prefix.IsValid() { if len(domains) > 0 && prefix.IsValid() {
return nil, status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time") return nil, status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
} }
@ -234,6 +241,13 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return err return err
} }
// Do not allow non-Linux peers
if peer := account.GetPeer(routeToSave.Peer); peer != nil {
if peer.Meta.GoOS != "linux" {
return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
}
}
if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() { if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() {
return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time") return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
} }

View File

@ -36,6 +36,7 @@ const (
idQueryCondition = "id = ?" idQueryCondition = "id = ?"
keyQueryCondition = "key = ?" keyQueryCondition = "key = ?"
accountAndIDQueryCondition = "account_id = ? and id = ?" accountAndIDQueryCondition = "account_id = ? and id = ?"
accountIDCondition = "account_id = ?"
peerNotFoundFMT = "peer %s not found" peerNotFoundFMT = "peer %s not found"
) )
@ -500,7 +501,7 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
var groups []*nbgroup.Group var groups []*nbgroup.Group
result := s.db.Find(&groups, idQueryCondition, accountID) result := s.db.Find(&groups, accountIDCondition, accountID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")

View File

@ -357,26 +357,35 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
return newUser.ToUserInfo(idpUser, account.Settings) return newUser.ToUserInfo(idpUser, account.Settings)
} }
func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*User, error) {
return am.Store.GetUserByUserID(ctx, LockingStrengthShare, id)
}
// GetUser looks up a user by provided authorization claims. // GetUser looks up a user by provided authorization claims.
// It will also create an account if didn't exist for this user before. // It will also create an account if didn't exist for this user before.
func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) { func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) {
account, user, err := am.GetAccountFromToken(ctx, claims) accountID, userID, err := am.GetAccountIDFromToken(ctx, claims)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get account with token claims %v", err) return nil, fmt.Errorf("failed to get account with token claims %v", err)
} }
// this code should be outside of the am.GetAccountFromToken(claims) because this method is called also by the gRPC user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
// this code should be outside of the am.GetAccountIDFromToken(claims) because this method is called also by the gRPC
// server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event. // server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event.
newLogin := user.LastDashboardLoginChanged(claims.LastLogin) newLogin := user.LastDashboardLoginChanged(claims.LastLogin)
err = am.Store.SaveUserLastLogin(ctx, account.Id, claims.UserId, claims.LastLogin) err = am.Store.SaveUserLastLogin(ctx, accountID, userID, claims.LastLogin)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed saving user last login: %v", err) log.WithContext(ctx).Errorf("failed saving user last login: %v", err)
} }
if newLogin { if newLogin {
meta := map[string]any{"timestamp": claims.LastLogin} meta := map[string]any{"timestamp": claims.LastLogin}
am.StoreEvent(ctx, claims.UserId, claims.UserId, account.Id, activity.DashboardLogin, meta) am.StoreEvent(ctx, claims.UserId, claims.UserId, accountID, activity.DashboardLogin, meta)
} }
return user, nil return user, nil