mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-14 19:00:50 +01:00
refactor handlers to use GetAccountIDFromToken
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
parent
26dd045da5
commit
8f98adddf6
@ -75,12 +75,14 @@ type AccountManager interface {
|
|||||||
SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error)
|
SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error)
|
||||||
SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error)
|
SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error)
|
||||||
GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error)
|
GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error)
|
||||||
|
GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error)
|
||||||
GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error)
|
GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error)
|
||||||
GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
|
GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
|
||||||
CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
|
CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
|
||||||
GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error)
|
GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error)
|
||||||
DeleteAccount(ctx context.Context, accountID, userID string) error
|
DeleteAccount(ctx context.Context, accountID, userID string) error
|
||||||
MarkPATUsed(ctx context.Context, tokenID string) error
|
MarkPATUsed(ctx context.Context, tokenID string) error
|
||||||
|
GetUserByID(ctx context.Context, id string) (*User, error)
|
||||||
GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error)
|
GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error)
|
||||||
ListUsers(ctx context.Context, accountID string) ([]*User, error)
|
ListUsers(ctx context.Context, accountID string) ([]*User, error)
|
||||||
GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
|
GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||||
@ -107,7 +109,7 @@ type AccountManager interface {
|
|||||||
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
|
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
|
||||||
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
|
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
|
||||||
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error)
|
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error)
|
||||||
SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error
|
SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error
|
||||||
DeletePolicy(ctx context.Context, accountID, policyID, userID string) error
|
DeletePolicy(ctx context.Context, accountID, policyID, userID string) error
|
||||||
ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error)
|
ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error)
|
||||||
GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
|
GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
|
||||||
@ -145,6 +147,7 @@ type AccountManager interface {
|
|||||||
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||||
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||||
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
|
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
|
||||||
|
GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type DefaultAccountManager struct {
|
type DefaultAccountManager struct {
|
||||||
@ -1739,10 +1742,27 @@ func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token st
|
|||||||
return account, user, pat, nil
|
return account, user, pat, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountFromToken returns an account associated with this token.
|
// GetAccountByID returns an account associated with this account ID.
|
||||||
func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error) {
|
func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error) {
|
||||||
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) {
|
||||||
|
return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data")
|
||||||
|
}
|
||||||
|
|
||||||
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
|
defer unlock()
|
||||||
|
|
||||||
|
return am.Store.GetAccount(ctx, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountIDFromToken returns an account ID associated with this token.
|
||||||
|
func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||||
if claims.UserId == "" {
|
if claims.UserId == "" {
|
||||||
return nil, nil, fmt.Errorf("user ID is empty")
|
return "", "", fmt.Errorf("user ID is empty")
|
||||||
}
|
}
|
||||||
if am.singleAccountMode && am.singleAccountModeDomain != "" {
|
if am.singleAccountMode && am.singleAccountModeDomain != "" {
|
||||||
// This section is mostly related to self-hosted installations.
|
// This section is mostly related to self-hosted installations.
|
||||||
@ -1754,28 +1774,27 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
|
|||||||
|
|
||||||
accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, claims)
|
accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return "", "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// this is not really possible because we got an account by user ID
|
// this is not really possible because we got an account by user ID
|
||||||
return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId)
|
return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !user.IsServiceUser && claims.Invited {
|
if !user.IsServiceUser && claims.Invited {
|
||||||
err = am.redeemInvite(ctx, accountID, user.Id)
|
err = am.redeemInvite(ctx, accountID, user.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return "", "", err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = am.syncJWTGroups(ctx, accountID, user, claims); err != nil {
|
if err = am.syncJWTGroups(ctx, accountID, user, claims); err != nil {
|
||||||
return nil, nil, err
|
return "", "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: return account id, user id and error
|
return accountID, user.Id, nil
|
||||||
return &Account{Id: accountID}, user, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
|
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
|
||||||
@ -2049,12 +2068,12 @@ func (am *DefaultAccountManager) GetDNSDomain() string {
|
|||||||
// CheckUserAccessByJWTGroups checks if the user has access, particularly in cases where the admin enabled JWT
|
// CheckUserAccessByJWTGroups checks if the user has access, particularly in cases where the admin enabled JWT
|
||||||
// group propagation and set the list of groups with access permissions.
|
// group propagation and set the list of groups with access permissions.
|
||||||
func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error {
|
func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error {
|
||||||
account, _, err := am.GetAccountFromToken(ctx, claims)
|
accountID, _, err := am.GetAccountIDFromToken(ctx, claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, account.Id)
|
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -2133,6 +2152,19 @@ func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, store Stor
|
|||||||
return newLabel, nil
|
return newLabel, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error) {
|
||||||
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) {
|
||||||
|
return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data")
|
||||||
|
}
|
||||||
|
|
||||||
|
return am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
// addAllGroup to account object if it doesn't exist
|
// addAllGroup to account object if it doesn't exist
|
||||||
func addAllGroup(account *Account) error {
|
func addAllGroup(account *Account) error {
|
||||||
if len(account.Groups) == 0 {
|
if len(account.Groups) == 0 {
|
||||||
|
@ -27,26 +27,15 @@ func (e *GroupLinkError) Error() string {
|
|||||||
|
|
||||||
// GetGroup object of the peers
|
// GetGroup object of the peers
|
||||||
func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) {
|
func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
groups, err := am.GetAllGroups(ctx, accountID, userID)
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
for _, group := range groups {
|
||||||
if err != nil {
|
if group.ID == groupID {
|
||||||
return nil, err
|
return group, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked {
|
|
||||||
return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users")
|
|
||||||
}
|
|
||||||
|
|
||||||
group, ok := account.Groups[groupID]
|
|
||||||
if ok {
|
|
||||||
return group, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, status.Errorf(status.NotFound, "group with ID %s not found", groupID)
|
return nil, status.Errorf(status.NotFound, "group with ID %s not found", groupID)
|
||||||
@ -54,43 +43,32 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
|
|||||||
|
|
||||||
// GetAllGroups returns all groups in an account
|
// GetAllGroups returns all groups in an account
|
||||||
func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID string, userID string) ([]*nbgroup.Group, error) {
|
func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID string, userID string) ([]*nbgroup.Group, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked {
|
if !user.HasAdminPower() && !user.IsServiceUser && settings.RegularUsersViewBlocked {
|
||||||
return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users")
|
return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users")
|
||||||
}
|
}
|
||||||
|
|
||||||
groups := make([]*nbgroup.Group, 0, len(account.Groups))
|
return am.Store.GetAccountGroups(ctx, accountID)
|
||||||
for _, item := range account.Groups {
|
|
||||||
groups = append(groups, item)
|
|
||||||
}
|
|
||||||
|
|
||||||
return groups, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
|
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
|
||||||
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
|
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
groups, err := am.Store.GetAccountGroups(ctx, accountID)
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
matchingGroups := make([]*nbgroup.Group, 0)
|
matchingGroups := make([]*nbgroup.Group, 0)
|
||||||
for _, group := range account.Groups {
|
for _, group := range groups {
|
||||||
if group.Name == groupName {
|
if group.Name == groupName {
|
||||||
matchingGroups = append(matchingGroups, group)
|
matchingGroups = append(matchingGroups, group)
|
||||||
}
|
}
|
||||||
@ -262,6 +240,15 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
allGroup, err := account.GetGroupAll()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if allGroup.ID == groupID {
|
||||||
|
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
|
||||||
|
}
|
||||||
|
|
||||||
if err = validateDeleteGroup(account, group, userId); err != nil {
|
if err = validateDeleteGroup(account, group, userId); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -262,7 +262,7 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string
|
|||||||
}
|
}
|
||||||
claims := s.jwtClaimsExtractor.FromToken(token)
|
claims := s.jwtClaimsExtractor.FromToken(token)
|
||||||
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
|
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
|
||||||
_, _, err = s.accountManager.GetAccountFromToken(ctx, claims)
|
_, _, err = s.accountManager.GetAccountIDFromToken(ctx, claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err)
|
return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -35,25 +35,26 @@ func NewAccountsHandler(accountManager server.AccountManager, authCfg AuthCfg) *
|
|||||||
// GetAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account.
|
// GetAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account.
|
||||||
func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) {
|
func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !(user.HasAdminPower() || user.IsServiceUser) {
|
settings, err := h.accountManager.GetAccountSettings(r.Context(), accountID, userID)
|
||||||
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "the user has no permission to access account data"), w)
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := toAccountResponse(account)
|
resp := toAccountResponse(accountID, settings)
|
||||||
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
|
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings)
|
// UpdateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings)
|
||||||
func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) {
|
func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
_, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
_, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -96,24 +97,19 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
|
|||||||
settings.JWTAllowGroups = *req.Settings.JwtAllowGroups
|
settings.JWTAllowGroups = *req.Settings.JwtAllowGroups
|
||||||
}
|
}
|
||||||
|
|
||||||
updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, user.Id, settings)
|
updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := toAccountResponse(updatedAccount)
|
resp := toAccountResponse(updatedAccount.Id, updatedAccount.Settings)
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, &resp)
|
util.WriteJSONObject(r.Context(), w, &resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteAccount is a HTTP DELETE handler to delete an account
|
// DeleteAccount is a HTTP DELETE handler to delete an account
|
||||||
func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
|
func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method != http.MethodDelete {
|
|
||||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
targetAccountID := vars["accountId"]
|
targetAccountID := vars["accountId"]
|
||||||
@ -131,28 +127,28 @@ func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request)
|
|||||||
util.WriteJSONObject(r.Context(), w, emptyObject{})
|
util.WriteJSONObject(r.Context(), w, emptyObject{})
|
||||||
}
|
}
|
||||||
|
|
||||||
func toAccountResponse(account *server.Account) *api.Account {
|
func toAccountResponse(accountID string, settings *server.Settings) *api.Account {
|
||||||
jwtAllowGroups := account.Settings.JWTAllowGroups
|
jwtAllowGroups := settings.JWTAllowGroups
|
||||||
if jwtAllowGroups == nil {
|
if jwtAllowGroups == nil {
|
||||||
jwtAllowGroups = []string{}
|
jwtAllowGroups = []string{}
|
||||||
}
|
}
|
||||||
|
|
||||||
settings := api.AccountSettings{
|
apiSettings := api.AccountSettings{
|
||||||
PeerLoginExpiration: int(account.Settings.PeerLoginExpiration.Seconds()),
|
PeerLoginExpiration: int(settings.PeerLoginExpiration.Seconds()),
|
||||||
PeerLoginExpirationEnabled: account.Settings.PeerLoginExpirationEnabled,
|
PeerLoginExpirationEnabled: settings.PeerLoginExpirationEnabled,
|
||||||
GroupsPropagationEnabled: &account.Settings.GroupsPropagationEnabled,
|
GroupsPropagationEnabled: &settings.GroupsPropagationEnabled,
|
||||||
JwtGroupsEnabled: &account.Settings.JWTGroupsEnabled,
|
JwtGroupsEnabled: &settings.JWTGroupsEnabled,
|
||||||
JwtGroupsClaimName: &account.Settings.JWTGroupsClaimName,
|
JwtGroupsClaimName: &settings.JWTGroupsClaimName,
|
||||||
JwtAllowGroups: &jwtAllowGroups,
|
JwtAllowGroups: &jwtAllowGroups,
|
||||||
RegularUsersViewBlocked: account.Settings.RegularUsersViewBlocked,
|
RegularUsersViewBlocked: settings.RegularUsersViewBlocked,
|
||||||
}
|
}
|
||||||
|
|
||||||
if account.Settings.Extra != nil {
|
if settings.Extra != nil {
|
||||||
settings.Extra = &api.AccountExtraSettings{PeerApprovalEnabled: &account.Settings.Extra.PeerApprovalEnabled}
|
apiSettings.Extra = &api.AccountExtraSettings{PeerApprovalEnabled: &settings.Extra.PeerApprovalEnabled}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &api.Account{
|
return &api.Account{
|
||||||
Id: account.Id,
|
Id: accountID,
|
||||||
Settings: settings,
|
Settings: apiSettings,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -32,14 +32,14 @@ func NewDNSSettingsHandler(accountManager server.AccountManager, authCfg AuthCfg
|
|||||||
// GetDNSSettings returns the DNS settings for the account
|
// GetDNSSettings returns the DNS settings for the account
|
||||||
func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Request) {
|
func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(r.Context()).Error(err)
|
log.WithContext(r.Context()).Error(err)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), account.Id, user.Id)
|
dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), accountID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -55,7 +55,7 @@ func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Reque
|
|||||||
// UpdateDNSSettings handles update to DNS settings of an account
|
// UpdateDNSSettings handles update to DNS settings of an account
|
||||||
func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) {
|
func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -72,7 +72,7 @@ func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Re
|
|||||||
DisabledManagementGroups: req.DisabledManagementGroups,
|
DisabledManagementGroups: req.DisabledManagementGroups,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.accountManager.SaveDNSSettings(r.Context(), account.Id, user.Id, updateDNSSettings)
|
err = h.accountManager.SaveDNSSettings(r.Context(), accountID, userID, updateDNSSettings)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
|
@ -34,14 +34,14 @@ func NewEventsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ev
|
|||||||
// GetAllEvents list of the given account
|
// GetAllEvents list of the given account
|
||||||
func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
|
func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(r.Context()).Error(err)
|
log.WithContext(r.Context()).Error(err)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
accountEvents, err := h.accountManager.GetEvents(r.Context(), account.Id, user.Id)
|
accountEvents, err := h.accountManager.GetEvents(r.Context(), accountID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -51,7 +51,7 @@ func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
|
|||||||
events[i] = toEventResponse(e)
|
events[i] = toEventResponse(e)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.fillEventsWithUserInfo(r.Context(), events, account.Id, user.Id)
|
err = h.fillEventsWithUserInfo(r.Context(), events, accountID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
|
@ -98,7 +98,12 @@ func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http.
|
|||||||
|
|
||||||
func (l *GeolocationsHandler) authenticateUser(r *http.Request) error {
|
func (l *GeolocationsHandler) authenticateUser(r *http.Request) error {
|
||||||
claims := l.claimsExtractor.FromRequestContext(r)
|
claims := l.claimsExtractor.FromRequestContext(r)
|
||||||
_, user, err := l.accountManager.GetAccountFromToken(r.Context(), claims)
|
_, userID, err := l.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := l.accountManager.GetUserByID(r.Context(), userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
@ -35,14 +36,20 @@ func NewGroupsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Gr
|
|||||||
// GetAllGroups list for the account
|
// GetAllGroups list for the account
|
||||||
func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
|
func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(r.Context()).Error(err)
|
log.WithContext(r.Context()).Error(err)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
groups, err := h.accountManager.GetAllGroups(r.Context(), account.Id, user.Id)
|
groups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -50,7 +57,7 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
groupsResponse := make([]*api.Group, 0, len(groups))
|
groupsResponse := make([]*api.Group, 0, len(groups))
|
||||||
for _, group := range groups {
|
for _, group := range groups {
|
||||||
groupsResponse = append(groupsResponse, toGroupResponse(account, group))
|
groupsResponse = append(groupsResponse, toGroupResponse(accountPeers, group))
|
||||||
}
|
}
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, groupsResponse)
|
util.WriteJSONObject(r.Context(), w, groupsResponse)
|
||||||
@ -59,7 +66,7 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
|
|||||||
// UpdateGroup handles update to a group identified by a given ID
|
// UpdateGroup handles update to a group identified by a given ID
|
||||||
func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
|
func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -76,17 +83,18 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
eg, ok := account.Groups[groupID]
|
existingGroup, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID)
|
||||||
if !ok {
|
|
||||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
allGroup, err := account.GetGroupAll()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if allGroup.ID == groupID {
|
if allGroup.ID == groupID {
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w)
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w)
|
||||||
return
|
return
|
||||||
@ -114,23 +122,29 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
|
|||||||
ID: groupID,
|
ID: groupID,
|
||||||
Name: req.Name,
|
Name: req.Name,
|
||||||
Peers: peers,
|
Peers: peers,
|
||||||
Issued: eg.Issued,
|
Issued: existingGroup.Issued,
|
||||||
IntegrationReference: eg.IntegrationReference,
|
IntegrationReference: existingGroup.IntegrationReference,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group); err != nil {
|
if err := h.accountManager.SaveGroup(r.Context(), accountID, userID, &group); err != nil {
|
||||||
log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, account.Id, err)
|
log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, accountID, err)
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group))
|
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group))
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateGroup handles group creation request
|
// CreateGroup handles group creation request
|
||||||
func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
|
func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -160,24 +174,29 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
|
|||||||
Issued: nbgroup.GroupIssuedAPI,
|
Issued: nbgroup.GroupIssuedAPI,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group)
|
err = h.accountManager.SaveGroup(r.Context(), accountID, userID, &group)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group))
|
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group))
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteGroup handles group deletion request
|
// DeleteGroup handles group deletion request
|
||||||
func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
|
func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
aID := account.Id
|
|
||||||
|
|
||||||
groupID := mux.Vars(r)["groupId"]
|
groupID := mux.Vars(r)["groupId"]
|
||||||
if len(groupID) == 0 {
|
if len(groupID) == 0 {
|
||||||
@ -185,18 +204,7 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
allGroup, err := account.GetGroupAll()
|
err = h.accountManager.DeleteGroup(r.Context(), accountID, userID, groupID)
|
||||||
if err != nil {
|
|
||||||
util.WriteError(r.Context(), err, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if allGroup.ID == groupID {
|
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed"), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = h.accountManager.DeleteGroup(r.Context(), aID, user.Id, groupID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_, ok := err.(*server.GroupLinkError)
|
_, ok := err.(*server.GroupLinkError)
|
||||||
if ok {
|
if ok {
|
||||||
@ -213,34 +221,39 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
|
|||||||
// GetGroup returns a group
|
// GetGroup returns a group
|
||||||
func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) {
|
func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
groupID := mux.Vars(r)["groupId"]
|
||||||
|
if len(groupID) == 0 {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
group, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
switch r.Method {
|
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
|
||||||
case http.MethodGet:
|
if err != nil {
|
||||||
groupID := mux.Vars(r)["groupId"]
|
util.WriteError(r.Context(), err, w)
|
||||||
if len(groupID) == 0 {
|
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
group, err := h.accountManager.GetGroup(r.Context(), account.Id, groupID, user.Id)
|
|
||||||
if err != nil {
|
|
||||||
util.WriteError(r.Context(), err, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, toGroupResponse(account, group))
|
|
||||||
default:
|
|
||||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "HTTP method not found"), w)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, group))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func toGroupResponse(account *server.Account, group *nbgroup.Group) *api.Group {
|
func toGroupResponse(peers []*nbpeer.Peer, group *nbgroup.Group) *api.Group {
|
||||||
|
peersMap := make(map[string]*nbpeer.Peer, len(peers))
|
||||||
|
for _, peer := range peers {
|
||||||
|
peersMap[peer.ID] = peer
|
||||||
|
}
|
||||||
|
|
||||||
cache := make(map[string]api.PeerMinimum)
|
cache := make(map[string]api.PeerMinimum)
|
||||||
gr := api.Group{
|
gr := api.Group{
|
||||||
Id: group.ID,
|
Id: group.ID,
|
||||||
@ -251,7 +264,7 @@ func toGroupResponse(account *server.Account, group *nbgroup.Group) *api.Group {
|
|||||||
for _, pid := range group.Peers {
|
for _, pid := range group.Peers {
|
||||||
_, ok := cache[pid]
|
_, ok := cache[pid]
|
||||||
if !ok {
|
if !ok {
|
||||||
peer, ok := account.Peers[pid]
|
peer, ok := peersMap[pid]
|
||||||
if !ok {
|
if !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -36,14 +36,14 @@ func NewNameserversHandler(accountManager server.AccountManager, authCfg AuthCfg
|
|||||||
// GetAllNameservers returns the list of nameserver groups for the account
|
// GetAllNameservers returns the list of nameserver groups for the account
|
||||||
func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Request) {
|
func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(r.Context()).Error(err)
|
log.WithContext(r.Context()).Error(err)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), account.Id, user.Id)
|
nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), accountID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -60,7 +60,7 @@ func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Re
|
|||||||
// CreateNameserverGroup handles nameserver group creation request
|
// CreateNameserverGroup handles nameserver group creation request
|
||||||
func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -79,7 +79,7 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, user.Id, req.SearchDomainsEnabled)
|
nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), accountID, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, userID, req.SearchDomainsEnabled)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -93,7 +93,7 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt
|
|||||||
// UpdateNameserverGroup handles update to a nameserver group identified by a given ID
|
// UpdateNameserverGroup handles update to a nameserver group identified by a given ID
|
||||||
func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -130,7 +130,7 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt
|
|||||||
SearchDomainsEnabled: req.SearchDomainsEnabled,
|
SearchDomainsEnabled: req.SearchDomainsEnabled,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.accountManager.SaveNameServerGroup(r.Context(), account.Id, user.Id, updatedNSGroup)
|
err = h.accountManager.SaveNameServerGroup(r.Context(), accountID, userID, updatedNSGroup)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -144,7 +144,7 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt
|
|||||||
// DeleteNameserverGroup handles nameserver group deletion request
|
// DeleteNameserverGroup handles nameserver group deletion request
|
||||||
func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -156,7 +156,7 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.accountManager.DeleteNameServerGroup(r.Context(), account.Id, nsGroupID, user.Id)
|
err = h.accountManager.DeleteNameServerGroup(r.Context(), accountID, nsGroupID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -168,7 +168,7 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt
|
|||||||
// GetNameserverGroup handles a nameserver group Get request identified by ID
|
// GetNameserverGroup handles a nameserver group Get request identified by ID
|
||||||
func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(r.Context()).Error(err)
|
log.WithContext(r.Context()).Error(err)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
@ -181,7 +181,7 @@ func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.R
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), account.Id, user.Id, nsGroupID)
|
nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), accountID, userID, nsGroupID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
|
@ -34,20 +34,20 @@ func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATH
|
|||||||
// GetAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user
|
// GetAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user
|
||||||
func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) {
|
func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
userID := vars["userId"]
|
targetUserID := vars["userId"]
|
||||||
if len(userID) == 0 {
|
if len(userID) == 0 {
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
pats, err := h.accountManager.GetAllPATs(r.Context(), account.Id, user.Id, userID)
|
pats, err := h.accountManager.GetAllPATs(r.Context(), accountID, userID, targetUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -64,7 +64,7 @@ func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) {
|
|||||||
// GetToken is HTTP GET handler that returns a personal access token for the given user
|
// GetToken is HTTP GET handler that returns a personal access token for the given user
|
||||||
func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) {
|
func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -83,7 +83,7 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
pat, err := h.accountManager.GetPAT(r.Context(), account.Id, user.Id, targetUserID, tokenID)
|
pat, err := h.accountManager.GetPAT(r.Context(), accountID, userID, targetUserID, tokenID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -95,7 +95,7 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) {
|
|||||||
// CreateToken is HTTP POST handler that creates a personal access token for the given user
|
// CreateToken is HTTP POST handler that creates a personal access token for the given user
|
||||||
func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
|
func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -115,7 +115,7 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
pat, err := h.accountManager.CreatePAT(r.Context(), account.Id, user.Id, targetUserID, req.Name, req.ExpiresIn)
|
pat, err := h.accountManager.CreatePAT(r.Context(), accountID, userID, targetUserID, req.Name, req.ExpiresIn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -127,7 +127,7 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
|
|||||||
// DeleteToken is HTTP DELETE handler that deletes a personal access token for the given user
|
// DeleteToken is HTTP DELETE handler that deletes a personal access token for the given user
|
||||||
func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) {
|
func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -146,7 +146,7 @@ func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.accountManager.DeletePAT(r.Context(), account.Id, user.Id, targetUserID, tokenID)
|
err = h.accountManager.DeletePAT(r.Context(), accountID, userID, targetUserID, tokenID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
|
@ -74,7 +74,7 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
|
|||||||
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid))
|
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) {
|
func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) {
|
||||||
req := &api.PeerRequest{}
|
req := &api.PeerRequest{}
|
||||||
err := json.NewDecoder(r.Body).Decode(&req)
|
err := json.NewDecoder(r.Body).Decode(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -96,7 +96,7 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
peer, err := h.accountManager.UpdatePeer(ctx, account.Id, user.Id, update)
|
peer, err := h.accountManager.UpdatePeer(ctx, account.Id, userID, update)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(ctx, err, w)
|
util.WriteError(ctx, err, w)
|
||||||
return
|
return
|
||||||
@ -130,7 +130,7 @@ func (h *PeersHandler) deletePeer(ctx context.Context, accountID, userID string,
|
|||||||
// HandlePeer handles all peer requests for GET, PUT and DELETE operations
|
// HandlePeer handles all peer requests for GET, PUT and DELETE operations
|
||||||
func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -144,13 +144,20 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
switch r.Method {
|
switch r.Method {
|
||||||
case http.MethodDelete:
|
case http.MethodDelete:
|
||||||
h.deletePeer(r.Context(), account.Id, user.Id, peerID, w)
|
h.deletePeer(r.Context(), accountID, userID, peerID, w)
|
||||||
return
|
return
|
||||||
case http.MethodPut:
|
case http.MethodGet, http.MethodPut:
|
||||||
h.updatePeer(r.Context(), account, user, peerID, w, r)
|
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
|
||||||
return
|
if err != nil {
|
||||||
case http.MethodGet:
|
util.WriteError(r.Context(), err, w)
|
||||||
h.getPeer(r.Context(), account, peerID, user.Id, w)
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Method == http.MethodGet {
|
||||||
|
h.getPeer(r.Context(), account, peerID, userID, w)
|
||||||
|
} else {
|
||||||
|
h.updatePeer(r.Context(), account, userID, peerID, w, r)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
|
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
|
||||||
@ -159,19 +166,14 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
// GetAllPeers returns a list of all peers associated with a provided account
|
// GetAllPeers returns a list of all peers associated with a provided account
|
||||||
func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method != http.MethodGet {
|
|
||||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
peers, err := h.accountManager.GetPeers(r.Context(), account.Id, user.Id)
|
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -179,8 +181,8 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
dnsDomain := h.accountManager.GetDNSDomain()
|
dnsDomain := h.accountManager.GetDNSDomain()
|
||||||
|
|
||||||
respBody := make([]*api.PeerBatch, 0, len(peers))
|
respBody := make([]*api.PeerBatch, 0, len(account.Peers))
|
||||||
for _, peer := range peers {
|
for _, peer := range account.Peers {
|
||||||
peerToReturn, err := h.checkPeerStatus(peer)
|
peerToReturn, err := h.checkPeerStatus(peer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
@ -214,7 +216,7 @@ func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approv
|
|||||||
// GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network.
|
// GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network.
|
||||||
func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -227,6 +229,18 @@ func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := account.FindUser(userID)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// If the user is regular user and does not own the peer
|
// If the user is regular user and does not own the peer
|
||||||
// with the given peerID return an empty list
|
// with the given peerID return an empty list
|
||||||
if !user.HasAdminPower() && !user.IsServiceUser {
|
if !user.HasAdminPower() && !user.IsServiceUser {
|
||||||
|
@ -3,9 +3,11 @@ package http
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
@ -35,21 +37,27 @@ func NewPoliciesHandler(accountManager server.AccountManager, authCfg AuthCfg) *
|
|||||||
// GetAllPolicies list for the account
|
// GetAllPolicies list for the account
|
||||||
func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) {
|
func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
accountPolicies, err := h.accountManager.ListPolicies(r.Context(), account.Id, user.Id)
|
listPolicies, err := h.accountManager.ListPolicies(r.Context(), accountID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
policies := []*api.Policy{}
|
allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
|
||||||
for _, policy := range accountPolicies {
|
if err != nil {
|
||||||
resp := toPolicyResponse(account, policy)
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
policies := make([]*api.Policy, 0, len(listPolicies))
|
||||||
|
for _, policy := range listPolicies {
|
||||||
|
resp := toPolicyResponse(allGroups, policy)
|
||||||
if len(resp.Rules) == 0 {
|
if len(resp.Rules) == 0 {
|
||||||
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
|
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
|
||||||
return
|
return
|
||||||
@ -63,7 +71,7 @@ func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) {
|
|||||||
// UpdatePolicy handles update to a policy identified by a given ID
|
// UpdatePolicy handles update to a policy identified by a given ID
|
||||||
func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
|
func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -76,41 +84,35 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
policyIdx := -1
|
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
|
||||||
for i, policy := range account.Policies {
|
|
||||||
if policy.ID == policyID {
|
|
||||||
policyIdx = i
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if policyIdx < 0 {
|
|
||||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find policy id %s", policyID), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
h.savePolicy(w, r, account, user, policyID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreatePolicy handles policy creation request
|
|
||||||
func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) {
|
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
h.savePolicy(w, r, account, user, "")
|
policyIdx := slices.IndexFunc(account.Policies, func(policy *server.Policy) bool { return policy.ID == policyID })
|
||||||
|
if policyIdx < 0 {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find policy id %s", policyID), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.savePolicy(w, r, accountID, userID, policyID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatePolicy handles policy creation request
|
||||||
|
func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) {
|
||||||
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.savePolicy(w, r, accountID, userID, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
// savePolicy handles policy creation and update
|
// savePolicy handles policy creation and update
|
||||||
func (h *Policies) savePolicy(
|
func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID string, userID string, policyID string) {
|
||||||
w http.ResponseWriter,
|
|
||||||
r *http.Request,
|
|
||||||
account *server.Account,
|
|
||||||
user *server.User,
|
|
||||||
policyID string,
|
|
||||||
) {
|
|
||||||
var req api.PutApiPoliciesPolicyIdJSONRequestBody
|
var req api.PutApiPoliciesPolicyIdJSONRequestBody
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
@ -127,6 +129,8 @@ func (h *Policies) savePolicy(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
isUpdate := policyID != ""
|
||||||
|
|
||||||
if policyID == "" {
|
if policyID == "" {
|
||||||
policyID = xid.New().String()
|
policyID = xid.New().String()
|
||||||
}
|
}
|
||||||
@ -141,8 +145,8 @@ func (h *Policies) savePolicy(
|
|||||||
pr := server.PolicyRule{
|
pr := server.PolicyRule{
|
||||||
ID: policyID, // TODO: when policy can contain multiple rules, need refactor
|
ID: policyID, // TODO: when policy can contain multiple rules, need refactor
|
||||||
Name: rule.Name,
|
Name: rule.Name,
|
||||||
Destinations: groupMinimumsToStrings(account, rule.Destinations),
|
Destinations: rule.Destinations,
|
||||||
Sources: groupMinimumsToStrings(account, rule.Sources),
|
Sources: rule.Sources,
|
||||||
Bidirectional: rule.Bidirectional,
|
Bidirectional: rule.Bidirectional,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -206,16 +210,18 @@ func (h *Policies) savePolicy(
|
|||||||
policy.Rules = append(policy.Rules, &pr)
|
policy.Rules = append(policy.Rules, &pr)
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.SourcePostureChecks != nil {
|
if err := h.accountManager.SavePolicy(r.Context(), accountID, userID, &policy, isUpdate); err != nil {
|
||||||
policy.SourcePostureChecks = sourcePostureChecksToStrings(account, *req.SourcePostureChecks)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := h.accountManager.SavePolicy(r.Context(), account.Id, user.Id, &policy); err != nil {
|
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := toPolicyResponse(account, &policy)
|
allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := toPolicyResponse(allGroups, &policy)
|
||||||
if len(resp.Rules) == 0 {
|
if len(resp.Rules) == 0 {
|
||||||
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
|
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
|
||||||
return
|
return
|
||||||
@ -227,12 +233,11 @@ func (h *Policies) savePolicy(
|
|||||||
// DeletePolicy handles policy deletion request
|
// DeletePolicy handles policy deletion request
|
||||||
func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
|
func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
aID := account.Id
|
|
||||||
|
|
||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
policyID := vars["policyId"]
|
policyID := vars["policyId"]
|
||||||
@ -241,7 +246,7 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = h.accountManager.DeletePolicy(r.Context(), aID, policyID, user.Id); err != nil {
|
if err = h.accountManager.DeletePolicy(r.Context(), accountID, policyID, userID); err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -252,40 +257,46 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
|
|||||||
// GetPolicy handles a group Get request identified by ID
|
// GetPolicy handles a group Get request identified by ID
|
||||||
func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) {
|
func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
switch r.Method {
|
vars := mux.Vars(r)
|
||||||
case http.MethodGet:
|
policyID := vars["policyId"]
|
||||||
vars := mux.Vars(r)
|
if len(policyID) == 0 {
|
||||||
policyID := vars["policyId"]
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
|
||||||
if len(policyID) == 0 {
|
return
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
policy, err := h.accountManager.GetPolicy(r.Context(), account.Id, policyID, user.Id)
|
|
||||||
if err != nil {
|
|
||||||
util.WriteError(r.Context(), err, w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := toPolicyResponse(account, policy)
|
|
||||||
if len(resp.Rules) == 0 {
|
|
||||||
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, resp)
|
|
||||||
default:
|
|
||||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "method not found"), w)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
policy, err := h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := toPolicyResponse(allGroups, policy)
|
||||||
|
if len(resp.Rules) == 0 {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
util.WriteJSONObject(r.Context(), w, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Policy {
|
func toPolicyResponse(groups []*nbgroup.Group, policy *server.Policy) *api.Policy {
|
||||||
|
groupsMap := make(map[string]*nbgroup.Group)
|
||||||
|
for _, group := range groups {
|
||||||
|
groupsMap[group.ID] = group
|
||||||
|
}
|
||||||
|
|
||||||
cache := make(map[string]api.GroupMinimum)
|
cache := make(map[string]api.GroupMinimum)
|
||||||
ap := &api.Policy{
|
ap := &api.Policy{
|
||||||
Id: &policy.ID,
|
Id: &policy.ID,
|
||||||
@ -306,16 +317,18 @@ func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Polic
|
|||||||
Protocol: api.PolicyRuleProtocol(r.Protocol),
|
Protocol: api.PolicyRuleProtocol(r.Protocol),
|
||||||
Action: api.PolicyRuleAction(r.Action),
|
Action: api.PolicyRuleAction(r.Action),
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(r.Ports) != 0 {
|
if len(r.Ports) != 0 {
|
||||||
portsCopy := r.Ports
|
portsCopy := r.Ports
|
||||||
rule.Ports = &portsCopy
|
rule.Ports = &portsCopy
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, gid := range r.Sources {
|
for _, gid := range r.Sources {
|
||||||
_, ok := cache[gid]
|
_, ok := cache[gid]
|
||||||
if ok {
|
if ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if group, ok := account.Groups[gid]; ok {
|
if group, ok := groupsMap[gid]; ok {
|
||||||
minimum := api.GroupMinimum{
|
minimum := api.GroupMinimum{
|
||||||
Id: group.ID,
|
Id: group.ID,
|
||||||
Name: group.Name,
|
Name: group.Name,
|
||||||
@ -325,13 +338,14 @@ func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Polic
|
|||||||
cache[gid] = minimum
|
cache[gid] = minimum
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, gid := range r.Destinations {
|
for _, gid := range r.Destinations {
|
||||||
cachedMinimum, ok := cache[gid]
|
cachedMinimum, ok := cache[gid]
|
||||||
if ok {
|
if ok {
|
||||||
rule.Destinations = append(rule.Destinations, cachedMinimum)
|
rule.Destinations = append(rule.Destinations, cachedMinimum)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if group, ok := account.Groups[gid]; ok {
|
if group, ok := groupsMap[gid]; ok {
|
||||||
minimum := api.GroupMinimum{
|
minimum := api.GroupMinimum{
|
||||||
Id: group.ID,
|
Id: group.ID,
|
||||||
Name: group.Name,
|
Name: group.Name,
|
||||||
@ -345,28 +359,3 @@ func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Polic
|
|||||||
}
|
}
|
||||||
return ap
|
return ap
|
||||||
}
|
}
|
||||||
|
|
||||||
func groupMinimumsToStrings(account *server.Account, gm []string) []string {
|
|
||||||
result := make([]string, 0, len(gm))
|
|
||||||
for _, g := range gm {
|
|
||||||
if _, ok := account.Groups[g]; !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
result = append(result, g)
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
func sourcePostureChecksToStrings(account *server.Account, postureChecksIds []string) []string {
|
|
||||||
result := make([]string, 0, len(postureChecksIds))
|
|
||||||
for _, id := range postureChecksIds {
|
|
||||||
for _, postureCheck := range account.PostureChecks {
|
|
||||||
if id == postureCheck.ID {
|
|
||||||
result = append(result, id)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
@ -37,20 +37,20 @@ func NewPostureChecksHandler(accountManager server.AccountManager, geolocationMa
|
|||||||
// GetAllPostureChecks list for the account
|
// GetAllPostureChecks list for the account
|
||||||
func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *http.Request) {
|
func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := p.claimsExtractor.FromRequestContext(r)
|
claims := p.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
accountPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), account.Id, user.Id)
|
listPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), accountID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
postureChecks := []*api.PostureCheck{}
|
postureChecks := make([]*api.PostureCheck, 0, len(listPostureChecks))
|
||||||
for _, postureCheck := range accountPostureChecks {
|
for _, postureCheck := range listPostureChecks {
|
||||||
postureChecks = append(postureChecks, postureCheck.ToAPIResponse())
|
postureChecks = append(postureChecks, postureCheck.ToAPIResponse())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -60,7 +60,7 @@ func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *htt
|
|||||||
// UpdatePostureCheck handles update to a posture check identified by a given ID
|
// UpdatePostureCheck handles update to a posture check identified by a given ID
|
||||||
func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http.Request) {
|
func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := p.claimsExtractor.FromRequestContext(r)
|
claims := p.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -73,37 +73,31 @@ func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
postureChecksIdx := -1
|
_, err = p.accountManager.GetPostureChecks(r.Context(), accountID, userID, postureChecksID)
|
||||||
for i, postureCheck := range account.PostureChecks {
|
if err != nil {
|
||||||
if postureCheck.ID == postureChecksID {
|
|
||||||
postureChecksIdx = i
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if postureChecksIdx < 0 {
|
|
||||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find posture checks id %s", postureChecksID), w)
|
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find posture checks id %s", postureChecksID), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p.savePostureChecks(w, r, account, user, postureChecksID)
|
p.savePostureChecks(w, r, accountID, userID, postureChecksID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreatePostureCheck handles posture check creation request
|
// CreatePostureCheck handles posture check creation request
|
||||||
func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http.Request) {
|
func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := p.claimsExtractor.FromRequestContext(r)
|
claims := p.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p.savePostureChecks(w, r, account, user, "")
|
p.savePostureChecks(w, r, accountID, userID, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPostureCheck handles a posture check Get request identified by ID
|
// GetPostureCheck handles a posture check Get request identified by ID
|
||||||
func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Request) {
|
func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := p.claimsExtractor.FromRequestContext(r)
|
claims := p.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -116,7 +110,7 @@ func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Re
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), account.Id, postureChecksID, user.Id)
|
postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -128,7 +122,7 @@ func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Re
|
|||||||
// DeletePostureCheck handles posture check deletion request
|
// DeletePostureCheck handles posture check deletion request
|
||||||
func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http.Request) {
|
func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := p.claimsExtractor.FromRequestContext(r)
|
claims := p.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -141,7 +135,7 @@ func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = p.accountManager.DeletePostureChecks(r.Context(), account.Id, postureChecksID, user.Id); err != nil {
|
if err = p.accountManager.DeletePostureChecks(r.Context(), accountID, postureChecksID, userID); err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -150,13 +144,7 @@ func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http
|
|||||||
}
|
}
|
||||||
|
|
||||||
// savePostureChecks handles posture checks create and update
|
// savePostureChecks handles posture checks create and update
|
||||||
func (p *PostureChecksHandler) savePostureChecks(
|
func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.Request, accountID, userID, postureChecksID string) {
|
||||||
w http.ResponseWriter,
|
|
||||||
r *http.Request,
|
|
||||||
account *server.Account,
|
|
||||||
user *server.User,
|
|
||||||
postureChecksID string,
|
|
||||||
) {
|
|
||||||
var (
|
var (
|
||||||
err error
|
err error
|
||||||
req api.PostureCheckUpdate
|
req api.PostureCheckUpdate
|
||||||
@ -181,7 +169,7 @@ func (p *PostureChecksHandler) savePostureChecks(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := p.accountManager.SavePostureChecks(r.Context(), account.Id, user.Id, postureChecks); err != nil {
|
if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks); err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -43,13 +43,13 @@ func NewRoutesHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ro
|
|||||||
// GetAllRoutes returns the list of routes for the account
|
// GetAllRoutes returns the list of routes for the account
|
||||||
func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) {
|
func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
routes, err := h.accountManager.ListRoutes(r.Context(), account.Id, user.Id)
|
routes, err := h.accountManager.ListRoutes(r.Context(), accountID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -70,7 +70,7 @@ func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) {
|
|||||||
// CreateRoute handles route creation request
|
// CreateRoute handles route creation request
|
||||||
func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
|
func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -117,15 +117,9 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
peerGroupIds = *req.PeerGroups
|
peerGroupIds = *req.PeerGroups
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do not allow non-Linux peers
|
newRoute, err := h.accountManager.CreateRoute(r.Context(), accountID, newPrefix, networkType, domains, peerId, peerGroupIds,
|
||||||
if peer := account.GetPeer(peerId); peer != nil {
|
req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, userID, req.KeepRoute,
|
||||||
if peer.Meta.GoOS != "linux" {
|
)
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes"), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
newRoute, err := h.accountManager.CreateRoute(r.Context(), account.Id, newPrefix, networkType, domains, peerId, peerGroupIds, req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id, req.KeepRoute)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -168,7 +162,7 @@ func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) erro
|
|||||||
// UpdateRoute handles update to a route identified by a given ID
|
// UpdateRoute handles update to a route identified by a given ID
|
||||||
func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -181,7 +175,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id)
|
_, err = h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -204,14 +198,6 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
peerID = *req.Peer
|
peerID = *req.Peer
|
||||||
}
|
}
|
||||||
|
|
||||||
// do not allow non Linux peers
|
|
||||||
if peer := account.GetPeer(peerID); peer != nil {
|
|
||||||
if peer.Meta.GoOS != "linux" {
|
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "non-linux peers are non supported as network routes"), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
newRoute := &route.Route{
|
newRoute := &route.Route{
|
||||||
ID: route.ID(routeID),
|
ID: route.ID(routeID),
|
||||||
NetID: route.NetID(req.NetworkId),
|
NetID: route.NetID(req.NetworkId),
|
||||||
@ -247,7 +233,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
newRoute.PeerGroups = *req.PeerGroups
|
newRoute.PeerGroups = *req.PeerGroups
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.accountManager.SaveRoute(r.Context(), account.Id, user.Id, newRoute)
|
err = h.accountManager.SaveRoute(r.Context(), accountID, userID, newRoute)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -265,7 +251,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
// DeleteRoute handles route deletion request
|
// DeleteRoute handles route deletion request
|
||||||
func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
|
func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -277,7 +263,7 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.accountManager.DeleteRoute(r.Context(), account.Id, route.ID(routeID), user.Id)
|
err = h.accountManager.DeleteRoute(r.Context(), accountID, route.ID(routeID), userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -289,7 +275,7 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
// GetRoute handles a route Get request identified by ID
|
// GetRoute handles a route Get request identified by ID
|
||||||
func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
|
func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -301,7 +287,7 @@ func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
foundRoute, err := h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id)
|
foundRoute, err := h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "route not found"), w)
|
util.WriteError(r.Context(), status.Errorf(status.NotFound, "route not found"), w)
|
||||||
return
|
return
|
||||||
|
@ -35,7 +35,7 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authCfg AuthCfg)
|
|||||||
// CreateSetupKey is a POST requests that creates a new SetupKey
|
// CreateSetupKey is a POST requests that creates a new SetupKey
|
||||||
func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request) {
|
func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -76,8 +76,8 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
|
|||||||
if req.Ephemeral != nil {
|
if req.Ephemeral != nil {
|
||||||
ephemeral = *req.Ephemeral
|
ephemeral = *req.Ephemeral
|
||||||
}
|
}
|
||||||
setupKey, err := h.accountManager.CreateSetupKey(r.Context(), account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn,
|
setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, server.SetupKeyType(req.Type), expiresIn,
|
||||||
req.AutoGroups, req.UsageLimit, user.Id, ephemeral)
|
req.AutoGroups, req.UsageLimit, userID, ephemeral)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -89,7 +89,7 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
|
|||||||
// GetSetupKey is a GET request to get a SetupKey by ID
|
// GetSetupKey is a GET request to get a SetupKey by ID
|
||||||
func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
|
func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -102,7 +102,7 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
key, err := h.accountManager.GetSetupKey(r.Context(), account.Id, user.Id, keyID)
|
key, err := h.accountManager.GetSetupKey(r.Context(), accountID, userID, keyID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -114,7 +114,7 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
|
|||||||
// UpdateSetupKey is a PUT request to update server.SetupKey
|
// UpdateSetupKey is a PUT request to update server.SetupKey
|
||||||
func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request) {
|
func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -150,7 +150,7 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
|
|||||||
newKey.Name = req.Name
|
newKey.Name = req.Name
|
||||||
newKey.Id = keyID
|
newKey.Id = keyID
|
||||||
|
|
||||||
newKey, err = h.accountManager.SaveSetupKey(r.Context(), account.Id, newKey, user.Id)
|
newKey, err = h.accountManager.SaveSetupKey(r.Context(), accountID, newKey, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -161,13 +161,13 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
|
|||||||
// GetAllSetupKeys is a GET request that returns a list of SetupKey
|
// GetAllSetupKeys is a GET request that returns a list of SetupKey
|
||||||
func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Request) {
|
func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), account.Id, user.Id)
|
setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), accountID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
|
@ -41,22 +41,22 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
userID := vars["userId"]
|
targetUserID := vars["userId"]
|
||||||
if len(userID) == 0 {
|
if len(userID) == 0 {
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
existingUser, ok := account.Users[userID]
|
existingUser, err := h.accountManager.GetUserByID(r.Context(), targetUserID)
|
||||||
if !ok {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find user with ID %s", userID), w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -78,7 +78,7 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
newUser, err := h.accountManager.SaveUser(r.Context(), account.Id, user.Id, &server.User{
|
newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &server.User{
|
||||||
Id: userID,
|
Id: userID,
|
||||||
Role: userRole,
|
Role: userRole,
|
||||||
AutoGroups: req.AutoGroups,
|
AutoGroups: req.AutoGroups,
|
||||||
@ -102,7 +102,7 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -115,7 +115,7 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.accountManager.DeleteUser(r.Context(), account.Id, user.Id, targetUserID)
|
err = h.accountManager.DeleteUser(r.Context(), accountID, userID, targetUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -132,7 +132,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -160,7 +160,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
|
|||||||
name = *req.Name
|
name = *req.Name
|
||||||
}
|
}
|
||||||
|
|
||||||
newUser, err := h.accountManager.CreateUser(r.Context(), account.Id, user.Id, &server.UserInfo{
|
newUser, err := h.accountManager.CreateUser(r.Context(), accountID, userID, &server.UserInfo{
|
||||||
Email: email,
|
Email: email,
|
||||||
Name: name,
|
Name: name,
|
||||||
Role: req.Role,
|
Role: req.Role,
|
||||||
@ -184,13 +184,13 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := h.accountManager.GetUsersFromAccount(r.Context(), account.Id, user.Id)
|
data, err := h.accountManager.GetUsersFromAccount(r.Context(), accountID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -231,7 +231,7 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -244,7 +244,7 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.accountManager.InviteUser(r.Context(), account.Id, user.Id, targetUserID)
|
err = h.accountManager.InviteUser(r.Context(), accountID, userID, targetUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
|
@ -26,7 +26,7 @@ type MockAccountManager struct {
|
|||||||
CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType,
|
CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType,
|
||||||
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
|
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
|
||||||
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error)
|
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error)
|
||||||
GetAccountByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (*server.Account, error)
|
GetAccountIDByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (string, error)
|
||||||
GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
||||||
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
|
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
|
||||||
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
|
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||||
@ -48,7 +48,7 @@ type MockAccountManager struct {
|
|||||||
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
||||||
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
|
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
|
||||||
GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error)
|
GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error)
|
||||||
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) error
|
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error
|
||||||
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
|
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
|
||||||
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error)
|
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error)
|
||||||
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error)
|
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error)
|
||||||
@ -79,7 +79,7 @@ type MockAccountManager struct {
|
|||||||
DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error
|
DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error
|
||||||
ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
|
ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
|
||||||
CreateUserFunc func(ctx context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
|
CreateUserFunc func(ctx context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
|
||||||
GetAccountFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)
|
GetAccountIDFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
|
||||||
CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
|
CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
|
||||||
DeleteAccountFunc func(ctx context.Context, accountID, userID string) error
|
DeleteAccountFunc func(ctx context.Context, accountID, userID string) error
|
||||||
GetDNSDomainFunc func() string
|
GetDNSDomainFunc func() string
|
||||||
@ -105,6 +105,9 @@ type MockAccountManager struct {
|
|||||||
SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||||
FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||||
GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error)
|
GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error)
|
||||||
|
GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*server.Account, error)
|
||||||
|
GetUserByIDFunc func(ctx context.Context, id string) (*server.User, error)
|
||||||
|
GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*server.Settings, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
|
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
|
||||||
@ -190,16 +193,14 @@ func (am *MockAccountManager) CreateSetupKey(
|
|||||||
return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountByUserOrAccountID mock implementation of GetAccountByUserOrAccountID from server.AccountManager interface
|
// GetAccountIDByUserOrAccountID mock implementation of GetAccountIDByUserOrAccountID from server.AccountManager interface
|
||||||
func (am *MockAccountManager) GetAccountByUserOrAccountID(
|
func (am *MockAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userId, accountId, domain string) (string, error) {
|
||||||
ctx context.Context, userId, accountId, domain string,
|
if am.GetAccountIDByUserOrAccountIdFunc != nil {
|
||||||
) (*server.Account, error) {
|
return am.GetAccountIDByUserOrAccountIdFunc(ctx, userId, accountId, domain)
|
||||||
if am.GetAccountByUserOrAccountIdFunc != nil {
|
|
||||||
return am.GetAccountByUserOrAccountIdFunc(ctx, userId, accountId, domain)
|
|
||||||
}
|
}
|
||||||
return nil, status.Errorf(
|
return "", status.Errorf(
|
||||||
codes.Unimplemented,
|
codes.Unimplemented,
|
||||||
"method GetAccountByUserOrAccountID is not implemented",
|
"method GetAccountIDByUserOrAccountID is not implemented",
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -377,9 +378,9 @@ func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SavePolicy mock implementation of SavePolicy from server.AccountManager interface
|
// SavePolicy mock implementation of SavePolicy from server.AccountManager interface
|
||||||
func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy) error {
|
func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error {
|
||||||
if am.SavePolicyFunc != nil {
|
if am.SavePolicyFunc != nil {
|
||||||
return am.SavePolicyFunc(ctx, accountID, userID, policy)
|
return am.SavePolicyFunc(ctx, accountID, userID, policy, isUpdate)
|
||||||
}
|
}
|
||||||
return status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented")
|
return status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented")
|
||||||
}
|
}
|
||||||
@ -601,12 +602,12 @@ func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID
|
|||||||
return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountFromToken mocks GetAccountFromToken of the AccountManager interface
|
// GetAccountIDFromToken mocks GetAccountIDFromToken of the AccountManager interface
|
||||||
func (am *MockAccountManager) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
func (am *MockAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||||
if am.GetAccountFromTokenFunc != nil {
|
if am.GetAccountIDFromTokenFunc != nil {
|
||||||
return am.GetAccountFromTokenFunc(ctx, claims)
|
return am.GetAccountIDFromTokenFunc(ctx, claims)
|
||||||
}
|
}
|
||||||
return nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented")
|
return "", "", status.Errorf(codes.Unimplemented, "method GetAccountIDFromToken is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *MockAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error {
|
func (am *MockAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error {
|
||||||
@ -800,3 +801,26 @@ func (am *MockAccountManager) GetAccountIDForPeerKey(ctx context.Context, peerKe
|
|||||||
}
|
}
|
||||||
return "", status.Errorf(codes.Unimplemented, "method GetAccountIDForPeerKey is not implemented")
|
return "", status.Errorf(codes.Unimplemented, "method GetAccountIDForPeerKey is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAccountByID mocks GetAccountByID of the AccountManager interface
|
||||||
|
func (am *MockAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*server.Account, error) {
|
||||||
|
if am.GetAccountByIDFunc != nil {
|
||||||
|
return am.GetAccountByIDFunc(ctx, accountID, userID)
|
||||||
|
}
|
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method GetAccountByID is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserByID mocks GetUserByID of the AccountManager interface
|
||||||
|
func (am *MockAccountManager) GetUserByID(ctx context.Context, id string) (*server.User, error) {
|
||||||
|
if am.GetUserByIDFunc != nil {
|
||||||
|
return am.GetUserByIDFunc(ctx, id)
|
||||||
|
}
|
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method GetUserByID is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *MockAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*server.Settings, error) {
|
||||||
|
if am.GetAccountSettingsFunc != nil {
|
||||||
|
return am.GetAccountSettingsFunc(ctx, accountID, userID)
|
||||||
|
}
|
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method GetAccountSettings is not implemented")
|
||||||
|
}
|
||||||
|
@ -3,6 +3,7 @@ package server
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
_ "embed"
|
_ "embed"
|
||||||
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@ -341,7 +342,7 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SavePolicy in the store
|
// SavePolicy in the store
|
||||||
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error {
|
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
@ -350,7 +351,9 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
exists := am.savePolicy(account, policy)
|
if err = am.savePolicy(account, policy, isUpdate); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
account.Network.IncSerial()
|
account.Network.IncSerial()
|
||||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||||
@ -358,7 +361,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
|||||||
}
|
}
|
||||||
|
|
||||||
action := activity.PolicyAdded
|
action := activity.PolicyAdded
|
||||||
if exists {
|
if isUpdate {
|
||||||
action = activity.PolicyUpdated
|
action = activity.PolicyUpdated
|
||||||
}
|
}
|
||||||
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
|
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
|
||||||
@ -434,18 +437,34 @@ func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string)
|
|||||||
return policy, nil
|
return policy, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) savePolicy(account *Account, policy *Policy) (exists bool) {
|
// savePolicy saves or updates a policy in the given account.
|
||||||
for i, p := range account.Policies {
|
// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy.
|
||||||
if p.ID == policy.ID {
|
func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) error {
|
||||||
account.Policies[i] = policy
|
for index, rule := range policyToSave.Rules {
|
||||||
exists = true
|
rule.Sources = filterValidGroupIDs(account, rule.Sources)
|
||||||
break
|
rule.Destinations = filterValidGroupIDs(account, rule.Destinations)
|
||||||
|
policyToSave.Rules[index] = rule
|
||||||
|
}
|
||||||
|
|
||||||
|
if policyToSave.SourcePostureChecks != nil {
|
||||||
|
policyToSave.SourcePostureChecks = filterValidPostureChecks(account, policyToSave.SourcePostureChecks)
|
||||||
|
}
|
||||||
|
|
||||||
|
if isUpdate {
|
||||||
|
policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID })
|
||||||
|
if policyIdx < 0 {
|
||||||
|
return status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Update the existing policy
|
||||||
|
account.Policies[policyIdx] = policyToSave
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
if !exists {
|
|
||||||
account.Policies = append(account.Policies, policy)
|
// Add the new policy to the account
|
||||||
}
|
account.Policies = append(account.Policies, policyToSave)
|
||||||
return
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule {
|
func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule {
|
||||||
@ -560,3 +579,29 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// filterValidPostureChecks filters and returns the posture check IDs from the given list
|
||||||
|
// that are valid within the provided account.
|
||||||
|
func filterValidPostureChecks(account *Account, postureChecksIds []string) []string {
|
||||||
|
result := make([]string, 0, len(postureChecksIds))
|
||||||
|
for _, id := range postureChecksIds {
|
||||||
|
for _, postureCheck := range account.PostureChecks {
|
||||||
|
if id == postureCheck.ID {
|
||||||
|
result = append(result, id)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// filterValidGroupIDs filters a list of group IDs and returns only the ones present in the account's group map.
|
||||||
|
func filterValidGroupIDs(account *Account, groupIDs []string) []string {
|
||||||
|
result := make([]string, 0, len(groupIDs))
|
||||||
|
for _, groupID := range groupIDs {
|
||||||
|
if _, exists := account.Groups[groupID]; exists {
|
||||||
|
result = append(result, groupID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
@ -134,6 +134,13 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Do not allow non-Linux peers
|
||||||
|
if peer := account.GetPeer(peerID); peer != nil {
|
||||||
|
if peer.Meta.GoOS != "linux" {
|
||||||
|
return nil, status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if len(domains) > 0 && prefix.IsValid() {
|
if len(domains) > 0 && prefix.IsValid() {
|
||||||
return nil, status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
|
return nil, status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
|
||||||
}
|
}
|
||||||
@ -234,6 +241,13 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Do not allow non-Linux peers
|
||||||
|
if peer := account.GetPeer(routeToSave.Peer); peer != nil {
|
||||||
|
if peer.Meta.GoOS != "linux" {
|
||||||
|
return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() {
|
if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() {
|
||||||
return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
|
return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
|
||||||
}
|
}
|
||||||
|
@ -36,6 +36,7 @@ const (
|
|||||||
idQueryCondition = "id = ?"
|
idQueryCondition = "id = ?"
|
||||||
keyQueryCondition = "key = ?"
|
keyQueryCondition = "key = ?"
|
||||||
accountAndIDQueryCondition = "account_id = ? and id = ?"
|
accountAndIDQueryCondition = "account_id = ? and id = ?"
|
||||||
|
accountIDCondition = "account_id = ?"
|
||||||
peerNotFoundFMT = "peer %s not found"
|
peerNotFoundFMT = "peer %s not found"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -500,7 +501,7 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
|
|||||||
|
|
||||||
func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
|
func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
|
||||||
var groups []*nbgroup.Group
|
var groups []*nbgroup.Group
|
||||||
result := s.db.Find(&groups, idQueryCondition, accountID)
|
result := s.db.Find(&groups, accountIDCondition, accountID)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
|
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
|
||||||
|
@ -357,26 +357,35 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
|
|||||||
return newUser.ToUserInfo(idpUser, account.Settings)
|
return newUser.ToUserInfo(idpUser, account.Settings)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*User, error) {
|
||||||
|
return am.Store.GetUserByUserID(ctx, LockingStrengthShare, id)
|
||||||
|
}
|
||||||
|
|
||||||
// GetUser looks up a user by provided authorization claims.
|
// GetUser looks up a user by provided authorization claims.
|
||||||
// It will also create an account if didn't exist for this user before.
|
// It will also create an account if didn't exist for this user before.
|
||||||
func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) {
|
func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) {
|
||||||
account, user, err := am.GetAccountFromToken(ctx, claims)
|
accountID, userID, err := am.GetAccountIDFromToken(ctx, claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get account with token claims %v", err)
|
return nil, fmt.Errorf("failed to get account with token claims %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// this code should be outside of the am.GetAccountFromToken(claims) because this method is called also by the gRPC
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// this code should be outside of the am.GetAccountIDFromToken(claims) because this method is called also by the gRPC
|
||||||
// server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event.
|
// server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event.
|
||||||
newLogin := user.LastDashboardLoginChanged(claims.LastLogin)
|
newLogin := user.LastDashboardLoginChanged(claims.LastLogin)
|
||||||
|
|
||||||
err = am.Store.SaveUserLastLogin(ctx, account.Id, claims.UserId, claims.LastLogin)
|
err = am.Store.SaveUserLastLogin(ctx, accountID, userID, claims.LastLogin)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed saving user last login: %v", err)
|
log.WithContext(ctx).Errorf("failed saving user last login: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if newLogin {
|
if newLogin {
|
||||||
meta := map[string]any{"timestamp": claims.LastLogin}
|
meta := map[string]any{"timestamp": claims.LastLogin}
|
||||||
am.StoreEvent(ctx, claims.UserId, claims.UserId, account.Id, activity.DashboardLogin, meta)
|
am.StoreEvent(ctx, claims.UserId, claims.UserId, accountID, activity.DashboardLogin, meta)
|
||||||
}
|
}
|
||||||
|
|
||||||
return user, nil
|
return user, nil
|
||||||
|
Loading…
Reference in New Issue
Block a user