diff --git a/management/server/account.go b/management/server/account.go index 006d3f67d..d34f493e9 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -8,12 +8,11 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/route" gocache "github.com/patrickmn/go-cache" "github.com/rs/xid" log "github.com/sirupsen/logrus" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" "math/rand" "net" "net/netip" @@ -52,7 +51,7 @@ type AccountManager interface { SaveUser(accountID string, key *User) (*UserInfo, error) GetSetupKey(accountID, userID, keyID string) (*SetupKey, error) GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error) - GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, error) + GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) AccountExists(accountId string) (*bool, error) GetPeer(peerKey string) (*Peer, error) @@ -265,14 +264,14 @@ func (a *Account) FindPeerByPubKey(peerPubKey string) (*Peer, error) { } } - return nil, status.Errorf(codes.NotFound, "peer with the public key %s not found", peerPubKey) + return nil, status.Errorf(status.NotFound, "peer with the public key %s not found", peerPubKey) } // FindUser looks for a given user in the Account or returns error if user wasn't found. func (a *Account) FindUser(userID string) (*User, error) { user := a.Users[userID] if user == nil { - return nil, Errorf(UserNotFound, "user %s not found", userID) + return nil, status.Errorf(status.NotFound, "user %s not found", userID) } return user, nil @@ -282,7 +281,7 @@ func (a *Account) FindUser(userID string) (*User, error) { func (a *Account) FindSetupKey(setupKey string) (*SetupKey, error) { key := a.SetupKeys[setupKey] if key == nil { - return nil, Errorf(SetupKeyNotFound, "setup key not found") + return nil, status.Errorf(status.NotFound, "setup key not found") } return key, nil @@ -458,14 +457,14 @@ func (am *DefaultAccountManager) newAccount(userID, domain string) (*Account, er if err == nil { log.Warnf("an account with ID already exists, retrying...") continue - } else if statusErr.Code() == codes.NotFound { + } else if statusErr.Type() == status.NotFound { return newAccountWithId(accountId, userID, domain), nil } else { return nil, err } } - return nil, status.Errorf(codes.Internal, "error while creating new account") + return nil, status.Errorf(status.Internal, "error while creating new account") } func (am *DefaultAccountManager) warmupIDPCache() error { @@ -492,7 +491,7 @@ func (am *DefaultAccountManager) GetAccountByUserOrAccountID(userID, accountID, } else if userID != "" { account, err := am.GetOrCreateAccountByUser(userID, domain) if err != nil { - return nil, status.Errorf(codes.NotFound, "account not found using user id: %s", userID) + return nil, status.Errorf(status.NotFound, "account not found using user id: %s", userID) } err = am.addAccountIDToIDPAppMeta(userID, account) if err != nil { @@ -501,7 +500,7 @@ func (am *DefaultAccountManager) GetAccountByUserOrAccountID(userID, accountID, return account, nil } - return nil, status.Errorf(codes.NotFound, "no valid user or account Id provided") + return nil, status.Errorf(status.NotFound, "no valid user or account Id provided") } func isNil(i idp.Manager) bool { @@ -531,11 +530,7 @@ func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(userID string, account } if err != nil { - return status.Errorf( - codes.Internal, - "updating user's app metadata failed with: %v", - err, - ) + return status.Errorf(status.Internal, "updating user's app metadata failed with: %v", err) } // refresh cache to reflect the update _, err = am.refreshCache(account.Id) @@ -662,11 +657,8 @@ func (am *DefaultAccountManager) lookupCache(accountUsers map[string]struct{}, a } // updateAccountDomainAttributes updates the account domain attributes and then, saves the account -func (am *DefaultAccountManager) updateAccountDomainAttributes( - account *Account, - claims jwtclaims.AuthorizationClaims, - primaryDomain bool, -) error { +func (am *DefaultAccountManager) updateAccountDomainAttributes(account *Account, claims jwtclaims.AuthorizationClaims, + primaryDomain bool) error { account.IsDomainPrimaryAccount = primaryDomain lowerDomain := strings.ToLower(claims.Domain) @@ -681,7 +673,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributes( err := am.Store.SaveAccount(account) if err != nil { - return status.Errorf(codes.Internal, "failed saving updated account") + return err } return nil } @@ -723,10 +715,7 @@ func (am *DefaultAccountManager) handleExistingUserAccount( // handleNewUserAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account, // otherwise it will create a new account and make it primary account for the domain. -func (am *DefaultAccountManager) handleNewUserAccount( - domainAcc *Account, - claims jwtclaims.AuthorizationClaims, -) (*Account, error) { +func (am *DefaultAccountManager) handleNewUserAccount(domainAcc *Account, claims jwtclaims.AuthorizationClaims) (*Account, error) { var ( account *Account err error @@ -738,7 +727,7 @@ func (am *DefaultAccountManager) handleNewUserAccount( account.Users[claims.UserId] = NewRegularUser(claims.UserId) err = am.Store.SaveAccount(account) if err != nil { - return nil, status.Errorf(codes.Internal, "failed saving updated account") + return nil, err } } else { account, err = am.newAccount(claims.UserId, lowerDomain) @@ -773,7 +762,7 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e } if user == nil { - return status.Errorf(codes.NotFound, "user %s not found in the IdP", userID) + return status.Errorf(status.NotFound, "user %s not found in the IdP", userID) } if user.AppMetadata.WTPendingInvite != nil && *user.AppMetadata.WTPendingInvite { @@ -794,7 +783,7 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e } // GetAccountFromToken returns an account associated with this token -func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, error) { +func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error) { if am.singleAccountMode && am.singleAccountModeDomain != "" { // This section is mostly related to self-hosted installations. @@ -806,15 +795,21 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat account, err := am.getAccountWithAuthorizationClaims(claims) if err != nil { - return nil, err + return nil, nil, err + } + + user := account.Users[claims.UserId] + if user == nil { + // this is not really possible because we got an account by user ID + return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId) } err = am.redeemInvite(account, claims.UserId) if err != nil { - return nil, err + return nil, nil, err } - return account, nil + return account, user, nil } // getAccountWithAuthorizationClaims retrievs an account using JWT Claims. @@ -857,9 +852,12 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla // We checked if the domain has a primary account already domainAccount, err := am.Store.GetAccountByPrivateDomain(claims.Domain) - accStatus, _ := status.FromError(err) - if accStatus.Code() != codes.OK && accStatus.Code() != codes.NotFound { - return nil, err + if err != nil { + // if NotFound we are good to continue, otherwise return error + e, ok := status.FromError(err) + if !ok || e.Type() != status.NotFound { + return nil, err + } } account, err := am.Store.GetAccountByUser(claims.UserId) @@ -869,7 +867,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla return nil, err } return account, nil - } else if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { + } else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { return am.handleNewUserAccount(domainAccount, claims) } else { // other error @@ -891,7 +889,7 @@ func (am *DefaultAccountManager) AccountExists(accountID string) (*bool, error) var res bool _, err := am.Store.GetAccount(accountID) if err != nil { - if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { + if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { res = false return &res, nil } else { diff --git a/management/server/account_test.go b/management/server/account_test.go index b413d7b22..28d1e991b 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -314,7 +314,7 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) { testCase.inputClaims.AccountId = initAccount.Id } - account, err := manager.GetAccountFromToken(testCase.inputClaims) + account, _, err := manager.GetAccountFromToken(testCase.inputClaims) require.NoError(t, err, "support function failed") verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers) verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy) diff --git a/management/server/error.go b/management/server/error.go deleted file mode 100644 index 72858d5b9..000000000 --- a/management/server/error.go +++ /dev/null @@ -1,61 +0,0 @@ -package server - -import ( - "fmt" -) - -const ( - // UserAlreadyExists indicates that user already exists - UserAlreadyExists ErrorType = iota - // AccountNotFound indicates that specified account hasn't been found - AccountNotFound - // PreconditionFailed indicates that some pre-condition for the operation hasn't been fulfilled - PreconditionFailed - - // UserNotFound indicates that user wasn't found in the system (or under a given Account) - UserNotFound - - // PermissionDenied indicates that user has no permissions to view data - PermissionDenied - - // SetupKeyNotFound indicates that the setup key wasn't found in the system (or under a given Account) - SetupKeyNotFound -) - -// ErrorType is a type of the Error -type ErrorType int32 - -// Error is an internal error -type Error struct { - errorType ErrorType - message string -} - -// Type returns the Type of the error -func (e *Error) Type() ErrorType { - return e.errorType -} - -// Error is an error string -func (e *Error) Error() string { - return e.message -} - -// Errorf returns Error(errorType, fmt.Sprintf(format, a...)). -func Errorf(errorType ErrorType, format string, a ...interface{}) error { - return &Error{ - errorType: errorType, - message: fmt.Sprintf(format, a...), - } -} - -// FromError returns Error, true if the provided error is of type of Error. nil, false otherwise -func FromError(err error) (s *Error, ok bool) { - if err == nil { - return nil, true - } - if e, ok := err.(*Error); ok { - return e, true - } - return nil, false -} diff --git a/management/server/file_store.go b/management/server/file_store.go index 8206713ed..aab3eb8d8 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -1,6 +1,7 @@ package server import ( + "github.com/netbirdio/netbird/management/server/status" log "github.com/sirupsen/logrus" "os" "path/filepath" @@ -8,9 +9,6 @@ import ( "sync" "time" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - "github.com/netbirdio/netbird/util" ) @@ -192,10 +190,7 @@ func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) { accountID, accountIDFound := s.PrivateDomain2AccountID[strings.ToLower(domain)] if !accountIDFound { - return nil, status.Errorf( - codes.NotFound, - "account not found: provided domain is not registered or is not private", - ) + return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private") } account, err := s.getAccount(accountID) @@ -213,7 +208,7 @@ func (s *FileStore) GetAccountBySetupKey(setupKey string) (*Account, error) { accountID, accountIDFound := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)] if !accountIDFound { - return nil, status.Errorf(codes.NotFound, "account not found: provided setup key doesn't exists") + return nil, status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists") } account, err := s.getAccount(accountID) @@ -239,7 +234,7 @@ func (s *FileStore) GetAllAccounts() (all []*Account) { func (s *FileStore) getAccount(accountID string) (*Account, error) { account, accountFound := s.Accounts[accountID] if !accountFound { - return nil, status.Errorf(codes.NotFound, "account not found") + return nil, status.Errorf(status.NotFound, "account not found") } return account, nil @@ -265,7 +260,7 @@ func (s *FileStore) GetAccountByUser(userID string) (*Account, error) { accountID, accountIDFound := s.UserID2AccountID[userID] if !accountIDFound { - return nil, status.Errorf(codes.NotFound, "account not found") + return nil, status.Errorf(status.NotFound, "account not found") } account, err := s.getAccount(accountID) @@ -283,7 +278,7 @@ func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) { accountID, accountIDFound := s.PeerKeyID2AccountID[peerKey] if !accountIDFound { - return nil, status.Errorf(codes.NotFound, "Provided peer key doesn't exists %s", peerKey) + return nil, status.Errorf(status.NotFound, "provided peer key doesn't exists %s", peerKey) } account, err := s.getAccount(accountID) @@ -322,7 +317,7 @@ func (s *FileStore) SavePeerStatus(accountID, peerKey string, peerStatus PeerSta peer := account.Peers[peerKey] if peer == nil { - return status.Errorf(codes.NotFound, "peer %s not found", peerKey) + return status.Errorf(status.NotFound, "peer %s not found", peerKey) } peer.Status = &peerStatus diff --git a/management/server/group.go b/management/server/group.go index cbe606463..cfc784487 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -1,9 +1,6 @@ package server -import ( - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" -) +import "github.com/netbirdio/netbird/management/server/status" // Group of the peers for ACL type Group struct { @@ -53,7 +50,7 @@ func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, er account, err := am.Store.GetAccount(accountID) if err != nil { - return nil, status.Errorf(codes.NotFound, "account not found") + return nil, err } group, ok := account.Groups[groupID] @@ -61,7 +58,7 @@ func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, er return group, nil } - return nil, status.Errorf(codes.NotFound, "group with ID %s not found", groupID) + return nil, status.Errorf(status.NotFound, "group with ID %s not found", groupID) } // SaveGroup object of the peers @@ -72,7 +69,7 @@ func (am *DefaultAccountManager) SaveGroup(accountID string, group *Group) error account, err := am.Store.GetAccount(accountID) if err != nil { - return status.Errorf(codes.NotFound, "account not found") + return err } account.Groups[group.ID] = group @@ -94,12 +91,12 @@ func (am *DefaultAccountManager) UpdateGroup(accountID string, account, err := am.Store.GetAccount(accountID) if err != nil { - return nil, status.Errorf(codes.NotFound, "account not found") + return nil, err } groupToUpdate, ok := account.Groups[groupID] if !ok { - return nil, status.Errorf(codes.NotFound, "group %s no longer exists", groupID) + return nil, status.Errorf(status.NotFound, "group with ID %s no longer exists", groupID) } group := groupToUpdate.Copy() @@ -130,7 +127,7 @@ func (am *DefaultAccountManager) UpdateGroup(accountID string, err = am.updateAccountPeers(account) if err != nil { - return nil, status.Errorf(codes.Internal, "failed to update account peers") + return nil, err } return group, nil @@ -144,7 +141,7 @@ func (am *DefaultAccountManager) DeleteGroup(accountID, groupID string) error { account, err := am.Store.GetAccount(accountID) if err != nil { - return status.Errorf(codes.NotFound, "account not found") + return err } delete(account.Groups, groupID) @@ -165,7 +162,7 @@ func (am *DefaultAccountManager) ListGroups(accountID string) ([]*Group, error) account, err := am.Store.GetAccount(accountID) if err != nil { - return nil, status.Errorf(codes.NotFound, "account not found") + return nil, err } groups := make([]*Group, 0, len(account.Groups)) @@ -184,12 +181,12 @@ func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerKey string account, err := am.Store.GetAccount(accountID) if err != nil { - return status.Errorf(codes.NotFound, "account not found") + return err } group, ok := account.Groups[groupID] if !ok { - return status.Errorf(codes.NotFound, "group with ID %s not found", groupID) + return status.Errorf(status.NotFound, "group with ID %s not found", groupID) } add := true @@ -219,12 +216,12 @@ func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerKey str account, err := am.Store.GetAccount(accountID) if err != nil { - return status.Errorf(codes.NotFound, "account not found") + return err } group, ok := account.Groups[groupID] if !ok { - return status.Errorf(codes.NotFound, "group with ID %s not found", groupID) + return status.Errorf(status.NotFound, "group with ID %s not found", groupID) } account.Network.IncSerial() @@ -232,7 +229,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerKey str if itemID == peerKey { group.Peers = append(group.Peers[:i], group.Peers[i+1:]...) if err := am.Store.SaveAccount(account); err != nil { - return status.Errorf(codes.Internal, "can't save account") + return err } } } @@ -248,12 +245,12 @@ func (am *DefaultAccountManager) GroupListPeers(accountID, groupID string) ([]*P account, err := am.Store.GetAccount(accountID) if err != nil { - return nil, status.Errorf(codes.NotFound, "account not found") + return nil, status.Errorf(status.NotFound, "account not found") } group, ok := account.Groups[groupID] if !ok { - return nil, status.Errorf(codes.NotFound, "group with ID %s not found", groupID) + return nil, status.Errorf(status.NotFound, "group with ID %s not found", groupID) } peers := make([]*Peer, 0, len(account.Groups)) diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 193bad39b..e77de237a 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -14,6 +14,7 @@ import ( "github.com/golang/protobuf/ptypes/timestamp" "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/management/proto" + internalStatus "github.com/netbirdio/netbird/management/server/status" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc/codes" @@ -200,10 +201,6 @@ func (s *GRPCServer) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest) return nil, status.Errorf(codes.Internal, "invalid jwt token, err: %v", err) } claims := jwtclaims.ExtractClaimsWithToken(token, s.config.HttpConfig.AuthAudience) - _, err = s.accountManager.GetAccountFromToken(claims) - if err != nil { - return nil, status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err) - } userID = claims.UserId } else { log.Debugln("using setup key to register peer") @@ -237,14 +234,12 @@ func (s *GRPCServer) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest) }, }) if err != nil { - if e, ok := FromError(err); ok { + if e, ok := internalStatus.FromError(err); ok { switch e.Type() { - case PreconditionFailed: - return nil, status.Errorf(codes.FailedPrecondition, e.message) - case AccountNotFound: - case SetupKeyNotFound: - case UserNotFound: - return nil, status.Errorf(codes.NotFound, e.message) + case internalStatus.PreconditionFailed: + return nil, status.Errorf(codes.FailedPrecondition, e.Message) + case internalStatus.NotFound: + return nil, status.Errorf(codes.NotFound, e.Message) default: } } @@ -301,7 +296,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p peer, err := s.accountManager.GetPeer(peerKey.String()) if err != nil { - if errStatus, ok := status.FromError(err); ok && errStatus.Code() == codes.NotFound { + if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound { // peer doesn't exist -> check if setup key was provided if loginReq.GetJwtToken() == "" && loginReq.GetSetupKey() == "" { // absent setup key or jwt -> permission denied @@ -387,7 +382,6 @@ func ToResponseProto(configProto Protocol) proto.HostConfig_Protocol { case TCP: return proto.HostConfig_TCP default: - // mbragin: todo something better? panic(fmt.Errorf("unexpected config protocol type %v", configProto)) } } diff --git a/management/server/http/groups.go b/management/server/http/groups.go index fbf660ad8..a636da375 100644 --- a/management/server/http/groups.go +++ b/management/server/http/groups.go @@ -2,10 +2,9 @@ package http import ( "encoding/json" - "fmt" "github.com/netbirdio/netbird/management/server/http/api" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" + "github.com/netbirdio/netbird/management/server/http/util" + "github.com/netbirdio/netbird/management/server/status" "net/http" "github.com/netbirdio/netbird/management/server" @@ -33,7 +32,8 @@ func NewGroups(accountManager server.AccountManager, authAudience string) *Group // GetAllGroupsHandler list for the account func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) { - account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + account, _, err := h.accountManager.GetAccountFromToken(claims) if err != nil { log.Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) @@ -45,52 +45,54 @@ func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) { groups = append(groups, toGroupResponse(account, g)) } - writeJSONObject(w, groups) + util.WriteJSONObject(w, groups) } // UpdateGroupHandler handles update to a group identified by a given ID func (h *Groups) UpdateGroupHandler(w http.ResponseWriter, r *http.Request) { - account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + account, _, err := h.accountManager.GetAccountFromToken(claims) if err != nil { - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } vars := mux.Vars(r) groupID, ok := vars["id"] if !ok { - http.Error(w, "group ID field is missing", http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, "group ID field is missing"), w) return } if len(groupID) == 0 { - http.Error(w, "group ID can't be empty", http.StatusUnprocessableEntity) + util.WriteError(status.Errorf(status.InvalidArgument, "group ID can't be empty"), w) return } _, ok = account.Groups[groupID] if !ok { - http.Error(w, fmt.Sprintf("couldn't find group with ID %s", groupID), http.StatusNotFound) + util.WriteError(status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w) return } allGroup, err := account.GetGroupAll() if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + util.WriteError(err, w) return } if allGroup.ID == groupID { - http.Error(w, "updating group ALL is not allowed", http.StatusMethodNotAllowed) + util.WriteError(status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w) return } var req api.PutApiGroupsIdJSONRequestBody - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return } if *req.Name == "" { - http.Error(w, "group name shouldn't be empty", http.StatusUnprocessableEntity) + util.WriteError(status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w) return } @@ -102,53 +104,55 @@ func (h *Groups) UpdateGroupHandler(w http.ResponseWriter, r *http.Request) { if err := h.accountManager.SaveGroup(account.Id, &group); err != nil { log.Errorf("failed updating group %s under account %s %v", groupID, account.Id, err) - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } - writeJSONObject(w, toGroupResponse(account, &group)) + util.WriteJSONObject(w, toGroupResponse(account, &group)) } // PatchGroupHandler handles patch updates to a group identified by a given ID func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) { - account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + account, _, err := h.accountManager.GetAccountFromToken(claims) if err != nil { - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } vars := mux.Vars(r) groupID := vars["id"] if len(groupID) == 0 { - http.Error(w, "invalid group Id", http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w) return } _, ok := account.Groups[groupID] if !ok { - http.Error(w, fmt.Sprintf("couldn't find group id %s", groupID), http.StatusNotFound) + util.WriteError(status.Errorf(status.NotFound, "couldn't find group ID %s", groupID), w) return } allGroup, err := account.GetGroupAll() if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + util.WriteError(err, w) return } if allGroup.ID == groupID { - http.Error(w, "updating group ALL is not allowed", http.StatusMethodNotAllowed) + util.WriteError(status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w) return } var req api.PatchApiGroupsIdJSONRequestBody - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return } if len(req) == 0 { - http.Error(w, "no patch instruction received", http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, "no patch instruction received"), w) return } @@ -158,13 +162,13 @@ func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) { switch patch.Path { case api.GroupPatchOperationPathName: if patch.Op != api.GroupPatchOperationOpReplace { - http.Error(w, fmt.Sprintf("Name field only accepts replace operation, got %s", patch.Op), - http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, + "name field only accepts replace operation, got %s", patch.Op), w) return } if len(patch.Value) == 0 || patch.Value[0] == "" { - http.Error(w, "Group name shouldn't be empty", http.StatusUnprocessableEntity) + util.WriteError(status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w) return } @@ -193,53 +197,43 @@ func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) { Values: peerKeys, }) default: - http.Error(w, "invalid operation, \"%s\", for Peers field", http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, + "invalid operation, \"%v\", for Peers field", patch.Op), w) return } default: - http.Error(w, "invalid patch path", http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, "invalid patch path"), w) return } } group, err := h.accountManager.UpdateGroup(account.Id, groupID, operations) - if err != nil { - errStatus, ok := status.FromError(err) - if ok && errStatus.Code() == codes.Internal { - http.Error(w, errStatus.String(), http.StatusInternalServerError) - return - } - - if ok && errStatus.Code() == codes.NotFound { - http.Error(w, errStatus.String(), http.StatusNotFound) - return - } - - log.Errorf("failed updating group %s under account %s %v", groupID, account.Id, err) - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } - writeJSONObject(w, toGroupResponse(account, group)) + util.WriteJSONObject(w, toGroupResponse(account, group)) } // CreateGroupHandler handles group creation request func (h *Groups) CreateGroupHandler(w http.ResponseWriter, r *http.Request) { - account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + account, _, err := h.accountManager.GetAccountFromToken(claims) if err != nil { - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } var req api.PostApiGroupsJSONRequestBody - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return } if req.Name == "" { - http.Error(w, "Group name shouldn't be empty", http.StatusUnprocessableEntity) + util.WriteError(status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w) return } @@ -249,55 +243,57 @@ func (h *Groups) CreateGroupHandler(w http.ResponseWriter, r *http.Request) { Peers: peerIPsToKeys(account, req.Peers), } - if err := h.accountManager.SaveGroup(account.Id, &group); err != nil { - log.Errorf("failed creating group \"%s\" under account %s %v", req.Name, account.Id, err) - http.Redirect(w, r, "/", http.StatusInternalServerError) + err = h.accountManager.SaveGroup(account.Id, &group) + if err != nil { + util.WriteError(err, w) return } - writeJSONObject(w, toGroupResponse(account, &group)) + util.WriteJSONObject(w, toGroupResponse(account, &group)) } // DeleteGroupHandler handles group deletion request func (h *Groups) DeleteGroupHandler(w http.ResponseWriter, r *http.Request) { - account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + account, _, err := h.accountManager.GetAccountFromToken(claims) if err != nil { - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } aID := account.Id groupID := mux.Vars(r)["id"] if len(groupID) == 0 { - http.Error(w, "invalid group ID", http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w) return } allGroup, err := account.GetGroupAll() if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + util.WriteError(err, w) return } if allGroup.ID == groupID { - http.Error(w, "deleting group ALL is not allowed", http.StatusMethodNotAllowed) + util.WriteError(status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed"), w) return } - if err := h.accountManager.DeleteGroup(aID, groupID); err != nil { - log.Errorf("failed delete group %s under account %s %v", groupID, aID, err) - http.Redirect(w, r, "/", http.StatusInternalServerError) + err = h.accountManager.DeleteGroup(aID, groupID) + if err != nil { + util.WriteError(err, w) return } - writeJSONObject(w, "") + util.WriteJSONObject(w, "") } // GetGroupHandler returns a group func (h *Groups) GetGroupHandler(w http.ResponseWriter, r *http.Request) { - account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + account, _, err := h.accountManager.GetAccountFromToken(claims) if err != nil { - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } @@ -305,19 +301,22 @@ func (h *Groups) GetGroupHandler(w http.ResponseWriter, r *http.Request) { case http.MethodGet: groupID := mux.Vars(r)["id"] if len(groupID) == 0 { - http.Error(w, "invalid group ID", http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w) return } group, err := h.accountManager.GetGroup(account.Id, groupID) if err != nil { - http.Error(w, "group not found", http.StatusNotFound) + util.WriteError(err, w) return } - writeJSONObject(w, toGroupResponse(account, group)) + util.WriteJSONObject(w, toGroupResponse(account, group)) default: - http.Error(w, "", http.StatusNotFound) + if err != nil { + util.WriteError(status.Errorf(status.NotFound, "HTTP method not found"), w) + return + } } } diff --git a/management/server/http/groups_test.go b/management/server/http/groups_test.go index dbdadbc10..4c6d5b0e0 100644 --- a/management/server/http/groups_test.go +++ b/management/server/http/groups_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/status" "io" "net" "net/http" @@ -36,7 +37,7 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *Groups { }, GetGroupFunc: func(_, groupID string) (*server.Group, error) { if groupID != "idofthegroup" { - return nil, fmt.Errorf("not found") + return nil, status.Errorf(status.NotFound, "not found") } return &server.Group{ ID: "idofthegroup", @@ -67,7 +68,7 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *Groups { } return nil, fmt.Errorf("peer not found") }, - GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { + GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { return &server.Account{ Id: claims.AccountId, Domain: "hotmail.com", @@ -78,7 +79,7 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *Groups { Groups: map[string]*server.Group{ "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}}, "id-all": {ID: "id-all", Name: "All"}}, - }, nil + }, user, nil }, }, authAudience: "", @@ -223,7 +224,7 @@ func TestWriteGroup(t *testing.T) { requestPath: "/api/groups/id-all", requestBody: bytes.NewBuffer( []byte(`{"Name":"super"}`)), - expectedStatus: http.StatusMethodNotAllowed, + expectedStatus: http.StatusUnprocessableEntity, expectedBody: false, }, { @@ -244,7 +245,7 @@ func TestWriteGroup(t *testing.T) { requestPath: "/api/groups/id-existed", requestBody: bytes.NewBuffer( []byte(`[{"op":"insert","path":"name","value":[""]}]`)), - expectedStatus: http.StatusBadRequest, + expectedStatus: http.StatusUnprocessableEntity, expectedBody: false, }, { diff --git a/management/server/http/middleware/access_control.go b/management/server/http/middleware/access_control.go index be4276206..3d12eb4a5 100644 --- a/management/server/http/middleware/access_control.go +++ b/management/server/http/middleware/access_control.go @@ -1,7 +1,8 @@ package middleware import ( - "fmt" + "github.com/netbirdio/netbird/management/server/http/util" + "github.com/netbirdio/netbird/management/server/status" "net/http" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -33,14 +34,15 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler { ok, err := a.isUserAdmin(jwtClaims) if err != nil { - http.Error(w, fmt.Sprintf("error get user from JWT: %v", err), http.StatusUnauthorized) + util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w) return } if !ok { switch r.Method { + case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut: - http.Error(w, "user is not admin", http.StatusForbidden) + util.WriteError(status.Errorf(status.PermissionDenied, "only admin can perform this operation"), w) return } } diff --git a/management/server/http/middleware/handler.go b/management/server/http/middleware/handler.go index 82b4c6800..89b8410dd 100644 --- a/management/server/http/middleware/handler.go +++ b/management/server/http/middleware/handler.go @@ -7,12 +7,12 @@ import ( "net/http" ) -//Jwks is a collection of JSONWebKeys obtained from Config.HttpServerConfig.AuthKeysLocation +// Jwks is a collection of JSONWebKeys obtained from Config.HttpServerConfig.AuthKeysLocation type Jwks struct { Keys []JSONWebKeys `json:"keys"` } -//JSONWebKeys is a representation of a Jason Web Key +// JSONWebKeys is a representation of a Jason Web Key type JSONWebKeys struct { Kty string `json:"kty"` Kid string `json:"kid"` @@ -22,7 +22,7 @@ type JSONWebKeys struct { X5c []string `json:"x5c"` } -//NewJwtMiddleware creates new middleware to verify the JWT token sent via Authorization header +// NewJwtMiddleware creates new middleware to verify the JWT token sent via Authorization header func NewJwtMiddleware(issuer string, audience string, keysLocation string) (*JWTMiddleware, error) { keys, err := getPemKeys(keysLocation) @@ -66,7 +66,6 @@ func getPemKeys(keysLocation string) (*Jwks, error) { var jwks = &Jwks{} err = json.NewDecoder(resp.Body).Decode(jwks) - if err != nil { return jwks, err } diff --git a/management/server/http/middleware/jwt.go b/management/server/http/middleware/jwt.go index c8daaa1c4..feb00ec86 100644 --- a/management/server/http/middleware/jwt.go +++ b/management/server/http/middleware/jwt.go @@ -5,6 +5,8 @@ import ( "errors" "fmt" "github.com/golang-jwt/jwt" + "github.com/netbirdio/netbird/management/server/http/util" + "github.com/netbirdio/netbird/management/server/status" "log" "net/http" "strings" @@ -57,7 +59,7 @@ type JWTMiddleware struct { } func OnError(w http.ResponseWriter, r *http.Request, err string) { - http.Error(w, err, http.StatusUnauthorized) + util.WriteError(status.Errorf(status.Unauthorized, ""), w) } // New constructs a new Secure instance with supplied options. diff --git a/management/server/http/nameservers.go b/management/server/http/nameservers.go index bfef14f02..77b39c87a 100644 --- a/management/server/http/nameservers.go +++ b/management/server/http/nameservers.go @@ -7,7 +7,9 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/status" log "github.com/sirupsen/logrus" "net/http" ) @@ -30,7 +32,8 @@ func NewNameservers(accountManager server.AccountManager, authAudience string) * // GetAllNameserversHandler returns the list of nameserver groups for the account func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Request) { - account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + account, _, err := h.accountManager.GetAccountFromToken(claims) if err != nil { log.Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) @@ -39,7 +42,7 @@ func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Re nsGroups, err := h.accountManager.ListNameServerGroups(account.Id) if err != nil { - toHTTPError(err, w) + util.WriteError(err, w) return } @@ -48,64 +51,67 @@ func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Re apiNameservers = append(apiNameservers, toNameserverGroupResponse(r)) } - writeJSONObject(w, apiNameservers) + util.WriteJSONObject(w, apiNameservers) } // CreateNameserverGroupHandler handles nameserver group creation request func (h *Nameservers) CreateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) { - account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + account, _, err := h.accountManager.GetAccountFromToken(claims) if err != nil { - log.Error(err) - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } + var req api.PostApiDnsNameserversJSONRequestBody - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return } nsList, err := toServerNSList(req.Nameservers) if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, "invalid NS servers format"), w) return } nsGroup, err := h.accountManager.CreateNameServerGroup(account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled) if err != nil { - toHTTPError(err, w) + util.WriteError(err, w) return } resp := toNameserverGroupResponse(nsGroup) - writeJSONObject(w, &resp) + util.WriteJSONObject(w, &resp) } // UpdateNameserverGroupHandler handles update to a nameserver group identified by a given ID func (h *Nameservers) UpdateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) { - account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + account, _, err := h.accountManager.GetAccountFromToken(claims) if err != nil { - log.Error(err) - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } nsGroupID := mux.Vars(r)["id"] if len(nsGroupID) == 0 { - http.Error(w, "invalid nameserver group ID", http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w) return } var req api.PutApiDnsNameserversIdJSONRequestBody - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return } nsList, err := toServerNSList(req.Nameservers) if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, "invalid NS servers format"), w) return } @@ -122,41 +128,42 @@ func (h *Nameservers) UpdateNameserverGroupHandler(w http.ResponseWriter, r *htt err = h.accountManager.SaveNameServerGroup(account.Id, updatedNSGroup) if err != nil { - toHTTPError(err, w) + util.WriteError(err, w) return } resp := toNameserverGroupResponse(updatedNSGroup) - writeJSONObject(w, &resp) + util.WriteJSONObject(w, &resp) } // PatchNameserverGroupHandler handles patch updates to a nameserver group identified by a given ID func (h *Nameservers) PatchNameserverGroupHandler(w http.ResponseWriter, r *http.Request) { - account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + account, _, err := h.accountManager.GetAccountFromToken(claims) if err != nil { - log.Error(err) - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } nsGroupID := mux.Vars(r)["id"] if len(nsGroupID) == 0 { - http.Error(w, "invalid nameserver group ID", http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w) return } var req api.PatchApiDnsNameserversIdJSONRequestBody - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return } var operations []server.NameServerGroupUpdateOperation for _, patch := range req { if patch.Op != api.NameserverGroupPatchOperationOpReplace { - http.Error(w, fmt.Sprintf("nameserver groups only accepts replace operations, got %s", patch.Op), - http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, + "nameserver groups only accepts replace operations, got %s", patch.Op), w) return } switch patch.Path { @@ -196,49 +203,50 @@ func (h *Nameservers) PatchNameserverGroupHandler(w http.ResponseWriter, r *http Values: patch.Value, }) default: - http.Error(w, "invalid patch path", http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, "invalid patch path"), w) return } } updatedNSGroup, err := h.accountManager.UpdateNameServerGroup(account.Id, nsGroupID, operations) if err != nil { - toHTTPError(err, w) + util.WriteError(err, w) return } resp := toNameserverGroupResponse(updatedNSGroup) - writeJSONObject(w, &resp) + util.WriteJSONObject(w, &resp) } // DeleteNameserverGroupHandler handles nameserver group deletion request func (h *Nameservers) DeleteNameserverGroupHandler(w http.ResponseWriter, r *http.Request) { - account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + account, _, err := h.accountManager.GetAccountFromToken(claims) if err != nil { - log.Error(err) - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } nsGroupID := mux.Vars(r)["id"] if len(nsGroupID) == 0 { - http.Error(w, "invalid nameserver group ID", http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w) return } err = h.accountManager.DeleteNameServerGroup(account.Id, nsGroupID) if err != nil { - toHTTPError(err, w) + util.WriteError(err, w) return } - writeJSONObject(w, "") + util.WriteJSONObject(w, "") } // GetNameserverGroupHandler handles a nameserver group Get request identified by ID func (h *Nameservers) GetNameserverGroupHandler(w http.ResponseWriter, r *http.Request) { - account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + account, _, err := h.accountManager.GetAccountFromToken(claims) if err != nil { log.Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) @@ -247,19 +255,19 @@ func (h *Nameservers) GetNameserverGroupHandler(w http.ResponseWriter, r *http.R nsGroupID := mux.Vars(r)["id"] if len(nsGroupID) == 0 { - http.Error(w, "invalid nameserver group ID", http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w) return } nsGroup, err := h.accountManager.GetNameServerGroup(account.Id, nsGroupID) if err != nil { - toHTTPError(err, w) + util.WriteError(err, w) return } resp := toNameserverGroupResponse(nsGroup) - writeJSONObject(w, &resp) + util.WriteJSONObject(w, &resp) } diff --git a/management/server/http/nameservers_test.go b/management/server/http/nameservers_test.go index b168a6ee8..037de8155 100644 --- a/management/server/http/nameservers_test.go +++ b/management/server/http/nameservers_test.go @@ -5,9 +5,8 @@ import ( "encoding/json" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/status" "github.com/stretchr/testify/assert" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" "io" "net/http" "net/http/httptest" @@ -62,7 +61,7 @@ func initNameserversTestData() *Nameservers { if nsGroupID == existingNSGroupID { return baseExistingNSGroup.Copy(), nil } - return nil, status.Errorf(codes.NotFound, "nameserver group with ID %s not found", nsGroupID) + return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID) }, CreateNameServerGroupFunc: func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error) { return &nbdns.NameServerGroup{ @@ -83,12 +82,12 @@ func initNameserversTestData() *Nameservers { if nsGroupToSave.ID == existingNSGroupID { return nil } - return status.Errorf(codes.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID) + return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID) }, UpdateNameServerGroupFunc: func(accountID, nsGroupID string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) { nsGroupToUpdate := baseExistingNSGroup.Copy() if nsGroupID != nsGroupToUpdate.ID { - return nil, status.Errorf(codes.NotFound, "nameserver group ID %s no longer exists", nsGroupID) + return nil, status.Errorf(status.NotFound, "nameserver group ID %s no longer exists", nsGroupID) } for _, operation := range operations { switch operation.Type { @@ -110,8 +109,8 @@ func initNameserversTestData() *Nameservers { } return nsGroupToUpdate, nil }, - GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, error) { - return testingNSAccount, nil + GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { + return testingNSAccount, nil, nil }, }, authAudience: "", @@ -181,7 +180,7 @@ func TestNameserversHandlers(t *testing.T) { requestPath: "/api/dns/nameservers", requestBody: bytes.NewBuffer( []byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1000\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")), - expectedStatus: http.StatusBadRequest, + expectedStatus: http.StatusUnprocessableEntity, expectedBody: false, }, { @@ -223,7 +222,7 @@ func TestNameserversHandlers(t *testing.T) { requestPath: "/api/dns/nameservers/" + notFoundNSGroupID, requestBody: bytes.NewBuffer( []byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"100\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")), - expectedStatus: http.StatusBadRequest, + expectedStatus: http.StatusUnprocessableEntity, expectedBody: false, }, { diff --git a/management/server/http/peers.go b/management/server/http/peers.go index 1a7049e01..fe5ff2688 100644 --- a/management/server/http/peers.go +++ b/management/server/http/peers.go @@ -6,8 +6,9 @@ import ( "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" - log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server/status" "net/http" ) @@ -28,50 +29,47 @@ func NewPeers(accountManager server.AccountManager, authAudience string) *Peers func (h *Peers) updatePeer(account *server.Account, peer *server.Peer, w http.ResponseWriter, r *http.Request) { req := &api.PutApiPeersIdJSONBody{} - peerIp := peer.IP err := json.NewDecoder(r.Body).Decode(&req) if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return } update := &server.Peer{Key: peer.Key, SSHEnabled: req.SshEnabled, Name: req.Name} peer, err = h.accountManager.UpdatePeer(account.Id, update) if err != nil { - log.Errorf("failed updating peer %s under account %s %v", peerIp, account.Id, err) - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } - writeJSONObject(w, toPeerResponse(peer, account)) + util.WriteJSONObject(w, toPeerResponse(peer, account)) } func (h *Peers) deletePeer(accountId string, peer *server.Peer, w http.ResponseWriter, r *http.Request) { _, err := h.accountManager.DeletePeer(accountId, peer.Key) if err != nil { - log.Errorf("failed deleteing peer %s, %v", peer.IP, err) - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } - writeJSONObject(w, "") + util.WriteJSONObject(w, "") } func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) { - account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + account, _, err := h.accountManager.GetAccountFromToken(claims) if err != nil { - log.Error(err) - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } vars := mux.Vars(r) peerId := vars["id"] //effectively peer IP address if len(peerId) == 0 { - http.Error(w, "invalid peer Id", http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, "invalid peer ID"), w) return } peer, err := h.accountManager.GetPeerByIP(account.Id, peerId) if err != nil { - http.Error(w, "peer not found", http.StatusNotFound) + util.WriteError(err, w) return } @@ -83,11 +81,11 @@ func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) { h.updatePeer(account, peer, w, r) return case http.MethodGet: - writeJSONObject(w, toPeerResponse(peer, account)) + util.WriteJSONObject(w, toPeerResponse(peer, account)) return default: - http.Error(w, "", http.StatusNotFound) + util.WriteError(status.Errorf(status.NotFound, "unknown METHOD"), w) } } @@ -95,15 +93,16 @@ func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) { func (h *Peers) GetPeers(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: - account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { - log.Error(err) - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } peers, err := h.accountManager.GetPeers(account.Id, user.Id) if err != nil { + util.WriteError(err, w) return } @@ -111,10 +110,10 @@ func (h *Peers) GetPeers(w http.ResponseWriter, r *http.Request) { for _, peer := range peers { respBody = append(respBody, toPeerResponse(peer, account)) } - writeJSONObject(w, respBody) + util.WriteJSONObject(w, respBody) return default: - http.Error(w, "", http.StatusNotFound) + util.WriteError(status.Errorf(status.NotFound, "unknown METHOD"), w) } } diff --git a/management/server/http/peers_test.go b/management/server/http/peers_test.go index 648ac48b4..e7690b976 100644 --- a/management/server/http/peers_test.go +++ b/management/server/http/peers_test.go @@ -22,7 +22,8 @@ func initTestMetaData(peers ...*server.Peer) *Peers { GetPeersFunc: func(accountID, userID string) ([]*server.Peer, error) { return peers, nil }, - GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { + GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { + user := server.NewAdminUser("test_user") return &server.Account{ Id: claims.AccountId, Domain: "hotmail.com", @@ -30,9 +31,9 @@ func initTestMetaData(peers ...*server.Peer) *Peers { "test_peer": peers[0], }, Users: map[string]*server.User{ - "test_user": server.NewAdminUser("test_user"), + "test_user": user, }, - }, nil + }, user, nil }, }, authAudience: "", diff --git a/management/server/http/routes.go b/management/server/http/routes.go index ad99ea602..f85e3ee20 100644 --- a/management/server/http/routes.go +++ b/management/server/http/routes.go @@ -2,15 +2,13 @@ package http import ( "encoding/json" - "fmt" "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/route" - log "github.com/sirupsen/logrus" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" "net/http" "unicode/utf8" ) @@ -33,25 +31,16 @@ func NewRoutes(accountManager server.AccountManager, authAudience string) *Route // GetAllRoutesHandler returns the list of routes for the account func (h *Routes) GetAllRoutesHandler(w http.ResponseWriter, r *http.Request) { - account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { - log.Error(err) - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } routes, err := h.accountManager.ListRoutes(account.Id, user.Id) if err != nil { - log.Error(err) - if e, ok := server.FromError(err); ok { - switch e.Type() { - case server.PermissionDenied: - http.Error(w, e.Error(), http.StatusForbidden) - return - default: - } - } - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } apiRoutes := make([]*api.Route, 0) @@ -59,20 +48,22 @@ func (h *Routes) GetAllRoutesHandler(w http.ResponseWriter, r *http.Request) { apiRoutes = append(apiRoutes, toRouteResponse(account, r)) } - writeJSONObject(w, apiRoutes) + util.WriteJSONObject(w, apiRoutes) } // CreateRouteHandler handles route creation request func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) { - account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + account, _, err := h.accountManager.GetAccountFromToken(claims) if err != nil { - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } var req api.PostApiRoutesJSONRequestBody - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return } @@ -80,8 +71,7 @@ func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) { if req.Peer != "" { peer, err := h.accountManager.GetPeerByIP(account.Id, req.Peer) if err != nil { - log.Error(err) - http.Redirect(w, r, "/", http.StatusUnprocessableEntity) + util.WriteError(err, w) return } peerKey = peer.Key @@ -89,57 +79,60 @@ func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) { _, newPrefix, err := route.ParseNetwork(req.Network) if err != nil { - http.Error(w, fmt.Sprintf("couldn't parse update prefix %s", req.Network), http.StatusBadRequest) + util.WriteError(err, w) return } if utf8.RuneCountInString(req.NetworkId) > route.MaxNetIDChar || req.NetworkId == "" { - http.Error(w, fmt.Sprintf("identifier should be between 1 and %d", route.MaxNetIDChar), http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", + route.MaxNetIDChar), w) return } newRoute, err := h.accountManager.CreateRoute(account.Id, newPrefix.String(), peerKey, req.Description, req.NetworkId, req.Masquerade, req.Metric, req.Enabled) if err != nil { - log.Error(err) - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } resp := toRouteResponse(account, newRoute) - writeJSONObject(w, &resp) + util.WriteJSONObject(w, &resp) } // UpdateRouteHandler handles update to a route identified by a given ID func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) { - account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } vars := mux.Vars(r) routeID := vars["id"] if len(routeID) == 0 { - http.Error(w, "invalid route Id", http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w) return } _, err = h.accountManager.GetRoute(account.Id, routeID, user.Id) if err != nil { - http.Error(w, fmt.Sprintf("couldn't find route for ID %s", routeID), http.StatusNotFound) + util.WriteError(err, w) return } var req api.PutApiRoutesIdJSONRequestBody - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return } prefixType, newPrefix, err := route.ParseNetwork(req.Network) if err != nil { - http.Error(w, fmt.Sprintf("couldn't parse update prefix %s for route ID %s", req.Network, routeID), http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, "couldn't parse update prefix %s for route ID %s", + req.Network, routeID), w) return } @@ -147,15 +140,15 @@ func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) { if req.Peer != "" { peer, err := h.accountManager.GetPeerByIP(account.Id, req.Peer) if err != nil { - log.Error(err) - http.Redirect(w, r, "/", http.StatusUnprocessableEntity) + util.WriteError(err, w) return } peerKey = peer.Key } if utf8.RuneCountInString(req.NetworkId) > route.MaxNetIDChar || req.NetworkId == "" { - http.Error(w, fmt.Sprintf("identifier should be between 1 and %d", route.MaxNetIDChar), http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, + "identifier should be between 1 and %d", route.MaxNetIDChar), w) return } @@ -173,46 +166,46 @@ func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) { err = h.accountManager.SaveRoute(account.Id, newRoute) if err != nil { - log.Errorf("failed updating route \"%s\" under account %s %v", routeID, account.Id, err) - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } resp := toRouteResponse(account, newRoute) - writeJSONObject(w, &resp) + util.WriteJSONObject(w, &resp) } // PatchRouteHandler handles patch updates to a route identified by a given ID func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) { - account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } vars := mux.Vars(r) routeID := vars["id"] if len(routeID) == 0 { - http.Error(w, "invalid route ID", http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w) return } _, err = h.accountManager.GetRoute(account.Id, routeID, user.Id) if err != nil { - log.Error(err) - http.Error(w, fmt.Sprintf("couldn't find route ID %s", routeID), http.StatusNotFound) + util.WriteError(err, w) return } var req api.PatchApiRoutesIdJSONRequestBody - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return } if len(req) == 0 { - http.Error(w, "no patch instruction received", http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, "no patch instruction received"), w) return } @@ -222,8 +215,8 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) { switch patch.Path { case api.RoutePatchOperationPathNetwork: if patch.Op != api.RoutePatchOperationOpReplace { - http.Error(w, fmt.Sprintf("Network field only accepts replace operation, got %s", patch.Op), - http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, + "network field only accepts replace operation, got %s", patch.Op), w) return } operations = append(operations, server.RouteUpdateOperation{ @@ -232,8 +225,8 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) { }) case api.RoutePatchOperationPathDescription: if patch.Op != api.RoutePatchOperationOpReplace { - http.Error(w, fmt.Sprintf("Description field only accepts replace operation, got %s", patch.Op), - http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, + "description field only accepts replace operation, got %s", patch.Op), w) return } operations = append(operations, server.RouteUpdateOperation{ @@ -242,8 +235,8 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) { }) case api.RoutePatchOperationPathNetworkId: if patch.Op != api.RoutePatchOperationOpReplace { - http.Error(w, fmt.Sprintf("Network Identifier field only accepts replace operation, got %s", patch.Op), - http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, + "network Identifier field only accepts replace operation, got %s", patch.Op), w) return } operations = append(operations, server.RouteUpdateOperation{ @@ -252,21 +245,20 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) { }) case api.RoutePatchOperationPathPeer: if patch.Op != api.RoutePatchOperationOpReplace { - http.Error(w, fmt.Sprintf("Peer field only accepts replace operation, got %s", patch.Op), - http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, + "peer field only accepts replace operation, got %s", patch.Op), w) return } if len(patch.Value) > 1 { - http.Error(w, fmt.Sprintf("Value field only accepts 1 value, got %d", len(patch.Value)), - http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, + "value field only accepts 1 value, got %d", len(patch.Value)), w) return } peerValue := patch.Value if patch.Value[0] != "" { peer, err := h.accountManager.GetPeerByIP(account.Id, patch.Value[0]) if err != nil { - log.Error(err) - http.Redirect(w, r, "/", http.StatusUnprocessableEntity) + util.WriteError(err, w) return } peerValue = []string{peer.Key} @@ -277,8 +269,9 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) { }) case api.RoutePatchOperationPathMetric: if patch.Op != api.RoutePatchOperationOpReplace { - http.Error(w, fmt.Sprintf("Metric field only accepts replace operation, got %s", patch.Op), - http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, + "metric field only accepts replace operation, got %s", patch.Op), w) + return } operations = append(operations, server.RouteUpdateOperation{ @@ -287,8 +280,8 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) { }) case api.RoutePatchOperationPathMasquerade: if patch.Op != api.RoutePatchOperationOpReplace { - http.Error(w, fmt.Sprintf("Masquerade field only accepts replace operation, got %s", patch.Op), - http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, + "masquerade field only accepts replace operation, got %s", patch.Op), w) return } operations = append(operations, server.RouteUpdateOperation{ @@ -297,8 +290,8 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) { }) case api.RoutePatchOperationPathEnabled: if patch.Op != api.RoutePatchOperationOpReplace { - http.Error(w, fmt.Sprintf("Enabled field only accepts replace operation, got %s", patch.Op), - http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, + "enabled field only accepts replace operation, got %s", patch.Op), w) return } operations = append(operations, server.RouteUpdateOperation{ @@ -306,90 +299,68 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) { Values: patch.Value, }) default: - http.Error(w, "invalid patch path", http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, "invalid patch path"), w) return } } route, err := h.accountManager.UpdateRoute(account.Id, routeID, operations) - if err != nil { - errStatus, ok := status.FromError(err) - if ok && errStatus.Code() == codes.Internal { - http.Error(w, errStatus.String(), http.StatusInternalServerError) - return - } - - if ok && errStatus.Code() == codes.NotFound { - http.Error(w, errStatus.String(), http.StatusNotFound) - return - } - - if ok && errStatus.Code() == codes.InvalidArgument { - http.Error(w, errStatus.String(), http.StatusBadRequest) - return - } - - log.Errorf("failed updating route %s under account %s %v", routeID, account.Id, err) - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } resp := toRouteResponse(account, route) - writeJSONObject(w, &resp) + util.WriteJSONObject(w, &resp) } // DeleteRouteHandler handles route deletion request func (h *Routes) DeleteRouteHandler(w http.ResponseWriter, r *http.Request) { - account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + account, _, err := h.accountManager.GetAccountFromToken(claims) if err != nil { - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } routeID := mux.Vars(r)["id"] if len(routeID) == 0 { - http.Error(w, "invalid route ID", http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w) return } err = h.accountManager.DeleteRoute(account.Id, routeID) if err != nil { - errStatus, ok := status.FromError(err) - if ok && errStatus.Code() == codes.NotFound { - http.Error(w, fmt.Sprintf("route %s not found under account %s", routeID, account.Id), http.StatusNotFound) - return - } - log.Errorf("failed delete route %s under account %s %v", routeID, account.Id, err) - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } - writeJSONObject(w, "") + util.WriteJSONObject(w, "") } // GetRouteHandler handles a route Get request identified by ID func (h *Routes) GetRouteHandler(w http.ResponseWriter, r *http.Request) { - account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } routeID := mux.Vars(r)["id"] if len(routeID) == 0 { - http.Error(w, "invalid route ID", http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w) return } foundRoute, err := h.accountManager.GetRoute(account.Id, routeID, user.Id) if err != nil { - http.Error(w, "route not found", http.StatusNotFound) + util.WriteError(status.Errorf(status.NotFound, "route not found"), w) return } - writeJSONObject(w, toRouteResponse(account, foundRoute)) + util.WriteJSONObject(w, toRouteResponse(account, foundRoute)) } func toRouteResponse(account *server.Account, serverRoute *route.Route) *api.Route { diff --git a/management/server/http/routes_test.go b/management/server/http/routes_test.go index aaaf2f6b8..cd3822585 100644 --- a/management/server/http/routes_test.go +++ b/management/server/http/routes_test.go @@ -5,9 +5,8 @@ import ( "encoding/json" "fmt" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/route" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" "io" "net/http" "net/http/httptest" @@ -63,7 +62,7 @@ func initRoutesTestData() *Routes { if routeID == existingRouteID { return baseExistingRoute, nil } - return nil, status.Errorf(codes.NotFound, "route with ID %s not found", routeID) + return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID) }, CreateRouteFunc: func(accountID string, network, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error) { networkType, p, _ := route.ParseNetwork(network) @@ -83,13 +82,13 @@ func initRoutesTestData() *Routes { }, DeleteRouteFunc: func(_ string, peerIP string) error { if peerIP != existingRouteID { - return status.Errorf(codes.NotFound, "Peer with ID %s not found", peerIP) + return status.Errorf(status.NotFound, "Peer with ID %s not found", peerIP) } return nil }, GetPeerByIPFunc: func(_ string, peerIP string) (*server.Peer, error) { if peerIP != existingPeerID { - return nil, status.Errorf(codes.NotFound, "Peer with ID %s not found", peerIP) + return nil, status.Errorf(status.NotFound, "Peer with ID %s not found", peerIP) } return &server.Peer{ Key: existingPeerKey, @@ -99,7 +98,7 @@ func initRoutesTestData() *Routes { UpdateRouteFunc: func(_ string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error) { routeToUpdate := baseExistingRoute if routeID != routeToUpdate.ID { - return nil, status.Errorf(codes.NotFound, "route %s no longer exists", routeID) + return nil, status.Errorf(status.NotFound, "route %s no longer exists", routeID) } for _, operation := range operations { switch operation.Type { @@ -123,8 +122,8 @@ func initRoutesTestData() *Routes { } return routeToUpdate, nil }, - GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, error) { - return testingAccount, nil + GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { + return testingAccount, testingAccount.Users["test_user"], nil }, }, authAudience: "", @@ -201,15 +200,15 @@ func TestRoutesHandlers(t *testing.T) { requestType: http.MethodPost, requestPath: "/api/routes", requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\"}", notFoundPeerID)), - expectedStatus: http.StatusUnprocessableEntity, + expectedStatus: http.StatusNotFound, expectedBody: false, }, { - name: "POST Not Invalid Network Identifier", + name: "POST Invalid Network Identifier", requestType: http.MethodPost, requestPath: "/api/routes", requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"12345678901234567890qwertyuiopqwertyuiop1\",\"Peer\":\"%s\"}", existingPeerID)), - expectedStatus: http.StatusBadRequest, + expectedStatus: http.StatusUnprocessableEntity, expectedBody: false, }, { @@ -217,7 +216,7 @@ func TestRoutesHandlers(t *testing.T) { requestType: http.MethodPost, requestPath: "/api/routes", requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/34\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\"}", existingPeerID)), - expectedStatus: http.StatusBadRequest, + expectedStatus: http.StatusUnprocessableEntity, expectedBody: false, }, { @@ -251,7 +250,7 @@ func TestRoutesHandlers(t *testing.T) { requestType: http.MethodPut, requestPath: "/api/routes/" + existingRouteID, requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\"}", notFoundPeerID)), - expectedStatus: http.StatusUnprocessableEntity, + expectedStatus: http.StatusNotFound, expectedBody: false, }, { @@ -259,7 +258,7 @@ func TestRoutesHandlers(t *testing.T) { requestType: http.MethodPut, requestPath: "/api/routes/" + existingRouteID, requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"12345678901234567890qwertyuiopqwertyuiop1\",\"Peer\":\"%s\"}", existingPeerID)), - expectedStatus: http.StatusBadRequest, + expectedStatus: http.StatusUnprocessableEntity, expectedBody: false, }, { @@ -267,7 +266,7 @@ func TestRoutesHandlers(t *testing.T) { requestType: http.MethodPut, requestPath: "/api/routes/" + existingRouteID, requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/34\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\"}", existingPeerID)), - expectedStatus: http.StatusBadRequest, + expectedStatus: http.StatusUnprocessableEntity, expectedBody: false, }, { @@ -312,7 +311,7 @@ func TestRoutesHandlers(t *testing.T) { requestType: http.MethodPatch, requestPath: "/api/routes/" + existingRouteID, requestBody: bytes.NewBufferString(fmt.Sprintf("[{\"op\":\"replace\",\"path\":\"peer\",\"value\":[\"%s\"]}]", notFoundPeerID)), - expectedStatus: http.StatusUnprocessableEntity, + expectedStatus: http.StatusNotFound, expectedBody: false, }, { diff --git a/management/server/http/rules.go b/management/server/http/rules.go index 1842f5557..9f1219185 100644 --- a/management/server/http/rules.go +++ b/management/server/http/rules.go @@ -2,15 +2,13 @@ package http import ( "encoding/json" - "fmt" "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/status" "github.com/rs/xid" - log "github.com/sirupsen/logrus" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" "net/http" ) @@ -31,25 +29,16 @@ func NewRules(accountManager server.AccountManager, authAudience string) *Rules // GetAllRulesHandler list for the account func (h *Rules) GetAllRulesHandler(w http.ResponseWriter, r *http.Request) { - account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { - log.Error(err) - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } accountRules, err := h.accountManager.ListRules(account.Id, user.Id) if err != nil { - log.Error(err) - if e, ok := server.FromError(err); ok { - switch e.Type() { - case server.PermissionDenied: - http.Error(w, e.Error(), http.StatusForbidden) - return - default: - } - } - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } rules := []*api.Rule{} @@ -57,38 +46,39 @@ func (h *Rules) GetAllRulesHandler(w http.ResponseWriter, r *http.Request) { rules = append(rules, toRuleResponse(account, r)) } - writeJSONObject(w, rules) + util.WriteJSONObject(w, rules) } // UpdateRuleHandler handles update to a rule identified by a given ID func (h *Rules) UpdateRuleHandler(w http.ResponseWriter, r *http.Request) { - account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + account, _, err := h.accountManager.GetAccountFromToken(claims) if err != nil { - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } vars := mux.Vars(r) ruleID := vars["id"] if len(ruleID) == 0 { - http.Error(w, "invalid rule Id", http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, "invalid rule ID"), w) return } _, ok := account.Rules[ruleID] if !ok { - http.Error(w, fmt.Sprintf("couldn't find rule id %s", ruleID), http.StatusNotFound) + util.WriteError(status.Errorf(status.NotFound, "couldn't find rule id %s", ruleID), w) return } var req api.PutApiRulesIdJSONRequestBody - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) } if req.Name == "" { - http.Error(w, "Rule name shouldn't be empty", http.StatusUnprocessableEntity) + util.WriteError(status.Errorf(status.InvalidArgument, "rule name shouldn't be empty"), w) return } @@ -115,50 +105,52 @@ func (h *Rules) UpdateRuleHandler(w http.ResponseWriter, r *http.Request) { case server.TrafficFlowBidirectString: rule.Flow = server.TrafficFlowBidirect default: - http.Error(w, "unknown flow type", http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, "unknown flow type"), w) return } - if err := h.accountManager.SaveRule(account.Id, &rule); err != nil { - log.Errorf("failed updating rule \"%s\" under account %s %v", ruleID, account.Id, err) - http.Redirect(w, r, "/", http.StatusInternalServerError) + err = h.accountManager.SaveRule(account.Id, &rule) + if err != nil { + util.WriteError(err, w) return } resp := toRuleResponse(account, &rule) - writeJSONObject(w, &resp) + util.WriteJSONObject(w, &resp) } // PatchRuleHandler handles patch updates to a rule identified by a given ID func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) { - account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + account, _, err := h.accountManager.GetAccountFromToken(claims) if err != nil { - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } vars := mux.Vars(r) ruleID := vars["id"] if len(ruleID) == 0 { - http.Error(w, "invalid rule Id", http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, "invalid rule ID"), w) return } _, ok := account.Rules[ruleID] if !ok { - http.Error(w, fmt.Sprintf("couldn't find rule id %s", ruleID), http.StatusNotFound) + util.WriteError(status.Errorf(status.NotFound, "couldn't find rule ID %s", ruleID), w) return } var req api.PatchApiRulesIdJSONRequestBody - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return } if len(req) == 0 { - http.Error(w, "no patch instruction received", http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, "no patch instruction received"), w) return } @@ -168,12 +160,12 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) { switch patch.Path { case api.RulePatchOperationPathName: if patch.Op != api.RulePatchOperationOpReplace { - http.Error(w, fmt.Sprintf("Name field only accepts replace operation, got %s", patch.Op), - http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, + "name field only accepts replace operation, got %s", patch.Op), w) return } if len(patch.Value) == 0 || patch.Value[0] == "" { - http.Error(w, "Rule name shouldn't be empty", http.StatusUnprocessableEntity) + util.WriteError(status.Errorf(status.InvalidArgument, "rule name shouldn't be empty"), w) return } operations = append(operations, server.RuleUpdateOperation{ @@ -182,8 +174,8 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) { }) case api.RulePatchOperationPathDescription: if patch.Op != api.RulePatchOperationOpReplace { - http.Error(w, fmt.Sprintf("Description field only accepts replace operation, got %s", patch.Op), - http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, + "description field only accepts replace operation, got %s", patch.Op), w) return } operations = append(operations, server.RuleUpdateOperation{ @@ -192,8 +184,8 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) { }) case api.RulePatchOperationPathFlow: if patch.Op != api.RulePatchOperationOpReplace { - http.Error(w, fmt.Sprintf("Flow field only accepts replace operation, got %s", patch.Op), - http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, + "flow field only accepts replace operation, got %s", patch.Op), w) return } operations = append(operations, server.RuleUpdateOperation{ @@ -202,8 +194,8 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) { }) case api.RulePatchOperationPathDisabled: if patch.Op != api.RulePatchOperationOpReplace { - http.Error(w, fmt.Sprintf("Disabled field only accepts replace operation, got %s", patch.Op), - http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, + "disabled field only accepts replace operation, got %s", patch.Op), w) return } operations = append(operations, server.RuleUpdateOperation{ @@ -228,7 +220,8 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) { Values: patch.Value, }) default: - http.Error(w, "invalid operation, \"%s\", for Source field", http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, + "invalid operation \"%s\" on Source field", patch.Op), w) return } case api.RulePatchOperationPathDestinations: @@ -249,11 +242,12 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) { Values: patch.Value, }) default: - http.Error(w, "invalid operation, \"%s\", for Destination field", http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, + "invalid operation \"%s\" on Destination field", patch.Op), w) return } default: - http.Error(w, "invalid patch path", http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, "invalid patch path"), w) return } } @@ -261,48 +255,33 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) { rule, err := h.accountManager.UpdateRule(account.Id, ruleID, operations) if err != nil { - errStatus, ok := status.FromError(err) - if ok && errStatus.Code() == codes.Internal { - http.Error(w, errStatus.String(), http.StatusInternalServerError) - return - } - - if ok && errStatus.Code() == codes.NotFound { - http.Error(w, errStatus.String(), http.StatusNotFound) - return - } - - if ok && errStatus.Code() == codes.InvalidArgument { - http.Error(w, errStatus.String(), http.StatusBadRequest) - return - } - - log.Errorf("failed updating rule %s under account %s %v", ruleID, account.Id, err) - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } resp := toRuleResponse(account, rule) - writeJSONObject(w, &resp) + util.WriteJSONObject(w, &resp) } // CreateRuleHandler handles rule creation request func (h *Rules) CreateRuleHandler(w http.ResponseWriter, r *http.Request) { - account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + account, _, err := h.accountManager.GetAccountFromToken(claims) if err != nil { - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } var req api.PostApiRulesJSONRequestBody - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return } if req.Name == "" { - http.Error(w, "Rule name shouldn't be empty", http.StatusUnprocessableEntity) + util.WriteError(status.Errorf(status.InvalidArgument, "rule name shouldn't be empty"), w) return } @@ -329,50 +308,52 @@ func (h *Rules) CreateRuleHandler(w http.ResponseWriter, r *http.Request) { case server.TrafficFlowBidirectString: rule.Flow = server.TrafficFlowBidirect default: - http.Error(w, "unknown flow type", http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, "unknown flow type"), w) return } - if err := h.accountManager.SaveRule(account.Id, &rule); err != nil { - log.Errorf("failed creating rule \"%s\" under account %s %v", req.Name, account.Id, err) - http.Redirect(w, r, "/", http.StatusInternalServerError) + err = h.accountManager.SaveRule(account.Id, &rule) + if err != nil { + util.WriteError(err, w) return } resp := toRuleResponse(account, &rule) - writeJSONObject(w, &resp) + util.WriteJSONObject(w, &resp) } // DeleteRuleHandler handles rule deletion request func (h *Rules) DeleteRuleHandler(w http.ResponseWriter, r *http.Request) { - account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + account, _, err := h.accountManager.GetAccountFromToken(claims) if err != nil { - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } aID := account.Id rID := mux.Vars(r)["id"] if len(rID) == 0 { - http.Error(w, "invalid rule ID", http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, "invalid rule ID"), w) return } - if err := h.accountManager.DeleteRule(aID, rID); err != nil { - log.Errorf("failed delete rule %s under account %s %v", rID, aID, err) - http.Redirect(w, r, "/", http.StatusInternalServerError) + err = h.accountManager.DeleteRule(aID, rID) + if err != nil { + util.WriteError(err, w) return } - writeJSONObject(w, "") + util.WriteJSONObject(w, "") } // GetRuleHandler handles a group Get request identified by ID func (h *Rules) GetRuleHandler(w http.ResponseWriter, r *http.Request) { - account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } @@ -380,19 +361,19 @@ func (h *Rules) GetRuleHandler(w http.ResponseWriter, r *http.Request) { case http.MethodGet: ruleID := mux.Vars(r)["id"] if len(ruleID) == 0 { - http.Error(w, "invalid rule ID", http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, "invalid rule ID"), w) return } rule, err := h.accountManager.GetRule(account.Id, ruleID, user.Id) if err != nil { - http.Error(w, "rule not found", http.StatusNotFound) + util.WriteError(status.Errorf(status.NotFound, "rule not found"), w) return } - writeJSONObject(w, toRuleResponse(account, rule)) + util.WriteJSONObject(w, toRuleResponse(account, rule)) default: - http.Error(w, "", http.StatusNotFound) + util.WriteError(status.Errorf(status.NotFound, "method not found"), w) } } diff --git a/management/server/http/rules_test.go b/management/server/http/rules_test.go index f6f20f0d5..f5b7a8b98 100644 --- a/management/server/http/rules_test.go +++ b/management/server/http/rules_test.go @@ -66,7 +66,8 @@ func initRulesTestData(rules ...*server.Rule) *Rules { } return &rule, nil }, - GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { + GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { + user := server.NewAdminUser("test_user") return &server.Account{ Id: claims.AccountId, Domain: "hotmail.com", @@ -76,9 +77,9 @@ func initRulesTestData(rules ...*server.Rule) *Rules { "G": {ID: "G"}, }, Users: map[string]*server.User{ - "test_user": server.NewAdminUser("test_user"), + "test_user": user, }, - }, nil + }, user, nil }, }, authAudience: "", @@ -238,7 +239,7 @@ func TestRulesWriteRule(t *testing.T) { requestPath: "/api/rules/id-existed", requestBody: bytes.NewBuffer( []byte(`[{"op":"insert","path":"name","value":[""]}]`)), - expectedStatus: http.StatusBadRequest, + expectedStatus: http.StatusUnprocessableEntity, expectedBody: false, }, { diff --git a/management/server/http/setupkeys.go b/management/server/http/setupkeys.go index a07dacd60..6af797eb0 100644 --- a/management/server/http/setupkeys.go +++ b/management/server/http/setupkeys.go @@ -2,14 +2,12 @@ package http import ( "encoding/json" - "fmt" "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" - log "github.com/sirupsen/logrus" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" + "github.com/netbirdio/netbird/management/server/status" "net/http" "time" ) @@ -31,29 +29,28 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authAudience stri // CreateSetupKeyHandler is a POST requests that creates a new SetupKey func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request) { - account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + account, _, err := h.accountManager.GetAccountFromToken(claims) if err != nil { - log.Error(err) - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } req := &api.PostApiSetupKeysJSONRequestBody{} err = json.NewDecoder(r.Body).Decode(&req) if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return } if req.Name == "" { - http.Error(w, "Setup key name shouldn't be empty", http.StatusUnprocessableEntity) + util.WriteError(status.Errorf(status.InvalidArgument, "setup key name shouldn't be empty"), w) return } if !(server.SetupKeyType(req.Type) == server.SetupKeyReusable || server.SetupKeyType(req.Type) == server.SetupKeyOneOff) { - - http.Error(w, "unknown setup key type "+string(req.Type), http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, "unknown setup key type %s", string(req.Type)), w) return } @@ -62,17 +59,11 @@ func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request if req.AutoGroups == nil { req.AutoGroups = []string{} } - // newExpiresIn := time.Duration(req.ExpiresIn) * time.Second - // newKey.ExpiresAt = time.Now().Add(newExpiresIn) + setupKey, err := h.accountManager.CreateSetupKey(account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn, req.AutoGroups) if err != nil { - errStatus, ok := status.FromError(err) - if ok && errStatus.Code() == codes.NotFound { - http.Error(w, "account not found", http.StatusNotFound) - return - } - http.Error(w, "failed adding setup key", http.StatusInternalServerError) + util.WriteError(err, w) return } @@ -81,29 +72,23 @@ func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request // GetSetupKeyHandler is a GET request to get a SetupKey by ID func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) { - account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { - log.Error(err) - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } vars := mux.Vars(r) keyID := vars["id"] if len(keyID) == 0 { - http.Error(w, "invalid key Id", http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, "invalid key ID"), w) return } key, err := h.accountManager.GetSetupKey(account.Id, user.Id, keyID) if err != nil { - errStatus, ok := status.FromError(err) - if ok && errStatus.Code() == codes.NotFound { - http.Error(w, fmt.Sprintf("setup key %s not found under account %s", keyID, account.Id), http.StatusNotFound) - return - } - log.Errorf("failed getting setup key %s under account %s %v", keyID, account.Id, err) - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } @@ -112,34 +97,34 @@ func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) { // UpdateSetupKeyHandler is a PUT request to update server.SetupKey func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request) { - account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + account, _, err := h.accountManager.GetAccountFromToken(claims) if err != nil { - log.Error(err) - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } vars := mux.Vars(r) keyID := vars["id"] if len(keyID) == 0 { - http.Error(w, "invalid key Id", http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, "invalid key ID"), w) return } req := &api.PutApiSetupKeysIdJSONRequestBody{} err = json.NewDecoder(r.Body).Decode(&req) if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return } if req.Name == "" { - http.Error(w, fmt.Sprintf("setup key name field is invalid: %s", req.Name), http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, "setup key name field is invalid: %s", req.Name), w) return } if req.AutoGroups == nil { - http.Error(w, fmt.Sprintf("setup key AutoGroups field is invalid: %s", req.AutoGroups), http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, "setup key AutoGroups field is invalid"), w) return } @@ -150,16 +135,8 @@ func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request newKey.Id = keyID newKey, err = h.accountManager.SaveSetupKey(account.Id, newKey) - if err != nil { - if e, ok := status.FromError(err); ok { - switch e.Code() { - case codes.NotFound: - http.Error(w, fmt.Sprintf("couldn't find setup key for ID %s", keyID), http.StatusNotFound) - default: - http.Error(w, "failed updating setup key", http.StatusInternalServerError) - } - } + util.WriteError(err, w) return } writeSuccess(w, newKey) @@ -168,25 +145,25 @@ func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request // GetAllSetupKeysHandler is a GET request that returns a list of SetupKey func (h *SetupKeys) GetAllSetupKeysHandler(w http.ResponseWriter, r *http.Request) { - account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { - log.Error(err) - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } setupKeys, err := h.accountManager.ListSetupKeys(account.Id, user.Id) if err != nil { - log.Error(err) - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } + apiSetupKeys := make([]*api.SetupKey, 0) for _, key := range setupKeys { apiSetupKeys = append(apiSetupKeys, toResponseBody(key)) } - writeJSONObject(w, apiSetupKeys) + util.WriteJSONObject(w, apiSetupKeys) } func writeSuccess(w http.ResponseWriter, key *server.SetupKey) { @@ -194,7 +171,7 @@ func writeSuccess(w http.ResponseWriter, key *server.SetupKey) { w.Header().Set("Content-Type", "application/json") err := json.NewEncoder(w).Encode(toResponseBody(key)) if err != nil { - http.Error(w, "failed handling request", http.StatusInternalServerError) + util.WriteError(err, w) return } } diff --git a/management/server/http/setupkeys_test.go b/management/server/http/setupkeys_test.go index 87d7c53b4..fbb8a9f2c 100644 --- a/management/server/http/setupkeys_test.go +++ b/management/server/http/setupkeys_test.go @@ -6,9 +6,8 @@ import ( "fmt" "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/status" "github.com/stretchr/testify/assert" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" "io" "net/http" "net/http/httptest" @@ -32,7 +31,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup user *server.User) *SetupKeys { return &SetupKeys{ accountManager: &mock_server.MockAccountManager{ - GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { + GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { return &server.Account{ Id: testAccountID, Domain: "hotmail.com", @@ -45,7 +44,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup Groups: map[string]*server.Group{ "group-1": {ID: "group-1", Peers: []string{"A", "B"}}, "id-all": {ID: "id-all", Name: "All"}}, - }, nil + }, user, nil }, CreateSetupKeyFunc: func(_ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string) (*server.SetupKey, error) { if keyName == newKey.Name || typ != newKey.Type { @@ -60,7 +59,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup case newKey.Id: return newKey, nil default: - return nil, status.Errorf(codes.NotFound, "key %s not found", keyID) + return nil, status.Errorf(status.NotFound, "key %s not found", keyID) } }, @@ -68,7 +67,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup if key.Id == updatedSetupKey.Id { return updatedSetupKey, nil } - return nil, status.Errorf(codes.NotFound, "key %s not found", key.Id) + return nil, status.Errorf(status.NotFound, "key %s not found", key.Id) }, ListSetupKeysFunc: func(accountID, userID string) ([]*server.SetupKey, error) { diff --git a/management/server/http/users.go b/management/server/http/users.go index f7b227c40..698bb9410 100644 --- a/management/server/http/users.go +++ b/management/server/http/users.go @@ -2,12 +2,10 @@ package http import ( "encoding/json" - "fmt" "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/server/http/api" - log "github.com/sirupsen/logrus" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" + "github.com/netbirdio/netbird/management/server/http/util" + "github.com/netbirdio/netbird/management/server/status" "net/http" "github.com/netbirdio/netbird/management/server" @@ -31,33 +29,34 @@ func NewUserHandler(accountManager server.AccountManager, authAudience string) * // UpdateUser is a PUT requests to update User data func (h *UserHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPut { - http.Error(w, "", http.StatusBadRequest) + util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) + return } - account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + account, _, err := h.accountManager.GetAccountFromToken(claims) if err != nil { - log.Error(err) - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } vars := mux.Vars(r) userID := vars["id"] if len(userID) == 0 { - http.Error(w, "invalid user ID", http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w) return } req := &api.PutApiUsersIdJSONRequestBody{} err = json.NewDecoder(r.Body).Decode(&req) if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return } userRole := server.StrRoleToUserRole(req.Role) if userRole == server.UserRoleUnknown { - http.Error(w, "invalid user role", http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, "invalid user role"), w) return } @@ -67,40 +66,36 @@ func (h *UserHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { AutoGroups: req.AutoGroups, }) if err != nil { - if e, ok := status.FromError(err); ok { - switch e.Code() { - case codes.NotFound: - http.Error(w, fmt.Sprintf("couldn't find a user for ID %s", userID), http.StatusNotFound) - default: - http.Error(w, "failed to update user", http.StatusInternalServerError) - } - } + util.WriteError(err, w) return } - writeJSONObject(w, toUserResponse(newUser)) + util.WriteJSONObject(w, toUserResponse(newUser)) } // CreateUserHandler creates a User in the system with a status "invited" (effectively this is a user invite). func (h *UserHandler) CreateUserHandler(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { - http.Error(w, "", http.StatusNotFound) + util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) + return } - account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + account, _, err := h.accountManager.GetAccountFromToken(claims) if err != nil { - log.Error(err) + util.WriteError(err, w) + return } req := &api.PostApiUsersJSONRequestBody{} err = json.NewDecoder(r.Body).Decode(&req) if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return } if server.StrRoleToUserRole(req.Role) == server.UserRoleUnknown { - http.Error(w, "unknown user role "+req.Role, http.StatusBadRequest) + util.WriteError(status.Errorf(status.InvalidArgument, "unknown user role %s", req.Role), w) return } @@ -111,37 +106,30 @@ func (h *UserHandler) CreateUserHandler(w http.ResponseWriter, r *http.Request) AutoGroups: req.AutoGroups, }) if err != nil { - if e, ok := server.FromError(err); ok { - switch e.Type() { - case server.UserAlreadyExists: - http.Error(w, "You can't invite users with an existing NetBird account.", http.StatusPreconditionFailed) - return - default: - } - } - http.Error(w, "failed to invite", http.StatusInternalServerError) + util.WriteError(err, w) return } - writeJSONObject(w, toUserResponse(newUser)) + util.WriteJSONObject(w, toUserResponse(newUser)) } // GetUsers returns a list of users of the account this user belongs to. // It also gathers additional user data (like email and name) from the IDP manager. func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { - http.Error(w, "", http.StatusBadRequest) + util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) + return } - account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } data, err := h.accountManager.GetUsersFromAccount(account.Id, user.Id) if err != nil { - log.Error(err) - http.Redirect(w, r, "/", http.StatusInternalServerError) + util.WriteError(err, w) return } @@ -150,7 +138,7 @@ func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) { users = append(users, toUserResponse(r)) } - writeJSONObject(w, users) + util.WriteJSONObject(w, users) } func toUserResponse(user *server.UserInfo) *api.User { diff --git a/management/server/http/users_test.go b/management/server/http/users_test.go index e402712c3..806c6152d 100644 --- a/management/server/http/users_test.go +++ b/management/server/http/users_test.go @@ -16,7 +16,7 @@ import ( func initUsers(user ...*server.User) *UserHandler { return &UserHandler{ accountManager: &mock_server.MockAccountManager{ - GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { + GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { users := make(map[string]*server.User, 0) for _, u := range user { users[u.Id] = u @@ -25,7 +25,7 @@ func initUsers(user ...*server.User) *UserHandler { Id: "12345", Domain: "netbird.io", Users: users, - }, nil + }, users[claims.UserId], nil }, GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) { users := make([]*server.UserInfo, 0) @@ -66,7 +66,6 @@ func TestGetUsers(t *testing.T) { expectedResult []*server.User }{ {name: "GetAllUsers", requestType: http.MethodGet, requestPath: "/api/users/", expectedStatus: http.StatusOK, expectedResult: users}, - {name: "WrongRequestMethod", requestType: http.MethodPost, requestPath: "/api/users/", expectedStatus: http.StatusBadRequest}, } for _, tc := range tt { diff --git a/management/server/http/util.go b/management/server/http/util.go deleted file mode 100644 index dbf28fdb8..000000000 --- a/management/server/http/util.go +++ /dev/null @@ -1,97 +0,0 @@ -package http - -import ( - "encoding/json" - "errors" - "fmt" - "github.com/netbirdio/netbird/management/server" - "github.com/netbirdio/netbird/management/server/jwtclaims" - log "github.com/sirupsen/logrus" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - "net/http" - "time" -) - -// writeJSONObject simply writes object to the HTTP reponse in JSON format -func writeJSONObject(w http.ResponseWriter, obj interface{}) { - w.WriteHeader(http.StatusOK) - w.Header().Set("Content-Type", "application/json; charset=UTF-8") - err := json.NewEncoder(w).Encode(obj) - if err != nil { - http.Error(w, "failed handling request", http.StatusInternalServerError) - return - } -} - -// Duration is used strictly for JSON requests/responses due to duration marshalling issues -type Duration struct { - time.Duration -} - -func (d Duration) MarshalJSON() ([]byte, error) { - return json.Marshal(d.String()) -} - -func (d *Duration) UnmarshalJSON(b []byte) error { - var v interface{} - if err := json.Unmarshal(b, &v); err != nil { - return err - } - switch value := v.(type) { - case float64: - d.Duration = time.Duration(value) - return nil - case string: - var err error - d.Duration, err = time.ParseDuration(value) - if err != nil { - return err - } - return nil - default: - return errors.New("invalid duration") - } -} - -func getJWTAccount(accountManager server.AccountManager, - jwtExtractor jwtclaims.ClaimsExtractor, - authAudience string, r *http.Request) (*server.Account, *server.User, error) { - - claims := jwtExtractor.ExtractClaimsFromRequestContext(r, authAudience) - - account, err := accountManager.GetAccountFromToken(claims) - if err != nil { - return nil, nil, fmt.Errorf("failed getting account of a user %s: %v", claims.UserId, err) - } - - user := account.Users[claims.UserId] - if user == nil { - // this is not really possible because we got an account by user ID - return nil, nil, fmt.Errorf("user %s not found", claims.UserId) - } - - return account, user, nil -} - -func toHTTPError(err error, w http.ResponseWriter) { - errStatus, ok := status.FromError(err) - if ok && errStatus.Code() == codes.Internal { - http.Error(w, errStatus.String(), http.StatusInternalServerError) - return - } - - if ok && errStatus.Code() == codes.NotFound { - http.Error(w, errStatus.String(), http.StatusNotFound) - return - } - - if ok && errStatus.Code() == codes.InvalidArgument { - http.Error(w, errStatus.String(), http.StatusBadRequest) - return - } - - unhandledMSG := fmt.Sprintf("got unhandled error code, error: %s", errStatus.String()) - log.Error(unhandledMSG) - http.Error(w, unhandledMSG, http.StatusInternalServerError) -} diff --git a/management/server/http/util/util.go b/management/server/http/util/util.go new file mode 100644 index 000000000..0055511a2 --- /dev/null +++ b/management/server/http/util/util.go @@ -0,0 +1,105 @@ +package util + +import ( + "encoding/json" + "errors" + "fmt" + "github.com/netbirdio/netbird/management/server/status" + log "github.com/sirupsen/logrus" + "net/http" + "time" +) + +// WriteJSONObject simply writes object to the HTTP reponse in JSON format +func WriteJSONObject(w http.ResponseWriter, obj interface{}) { + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json; charset=UTF-8") + err := json.NewEncoder(w).Encode(obj) + if err != nil { + WriteError(err, w) + return + } +} + +// Duration is used strictly for JSON requests/responses due to duration marshalling issues +type Duration struct { + time.Duration +} + +// MarshalJSON marshals the duration +func (d Duration) MarshalJSON() ([]byte, error) { + return json.Marshal(d.String()) +} + +// UnmarshalJSON unmarshals the duration +func (d *Duration) UnmarshalJSON(b []byte) error { + var v interface{} + if err := json.Unmarshal(b, &v); err != nil { + return err + } + switch value := v.(type) { + case float64: + d.Duration = time.Duration(value) + return nil + case string: + var err error + d.Duration, err = time.ParseDuration(value) + if err != nil { + return err + } + return nil + default: + return errors.New("invalid duration") + } +} + +// WriteErrorResponse prepares and writes an error response i nJSON +func WriteErrorResponse(errMsg string, httpStatus int, w http.ResponseWriter) { + type errorResponse struct { + Message string `json:"message"` + Code int `json:"code"` + } + + w.WriteHeader(httpStatus) + w.Header().Set("Content-Type", "application/json; charset=UTF-8") + err := json.NewEncoder(w).Encode(&errorResponse{ + Message: errMsg, + Code: httpStatus, + }) + if err != nil { + http.Error(w, "failed handling request", http.StatusInternalServerError) + } +} + +// WriteError converts an error to an JSON error response. +// If it is known internal error of type server.Error then it sets the messages from the error, a generic message otherwise +func WriteError(err error, w http.ResponseWriter) { + errStatus, ok := status.FromError(err) + httpStatus := http.StatusInternalServerError + msg := "internal server error" + if ok { + switch errStatus.Type() { + case status.UserAlreadyExists: + httpStatus = http.StatusConflict + case status.AlreadyExists: + httpStatus = http.StatusConflict + case status.PreconditionFailed: + httpStatus = http.StatusPreconditionFailed + case status.PermissionDenied: + httpStatus = http.StatusForbidden + case status.NotFound: + httpStatus = http.StatusNotFound + case status.Internal: + httpStatus = http.StatusInternalServerError + case status.InvalidArgument: + httpStatus = http.StatusUnprocessableEntity + default: + } + msg = err.Error() + } else { + unhandledMSG := fmt.Sprintf("got unhandled error code, error: %s", err.Error()) + log.Error(unhandledMSG) + } + + WriteErrorResponse(msg, httpStatus, w) +} diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 4ed231f64..71d460dee 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -59,7 +59,7 @@ type MockAccountManager struct { DeleteNameServerGroupFunc func(accountID, nsGroupID string) error ListNameServerGroupsFunc func(accountID string) ([]*nbdns.NameServerGroup, error) CreateUserFunc func(accountID string, key *server.UserInfo) (*server.UserInfo, error) - GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) + GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) } // GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface @@ -113,7 +113,7 @@ func (am *MockAccountManager) CreateSetupKey( return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented") } -// GetAccountByUserOrAccountId mock implementation of GetAccountByUserOrAccountId from server.AccountManager interface +// GetAccountByUserOrAccountID mock implementation of GetAccountByUserOrAccountID from server.AccountManager interface func (am *MockAccountManager) GetAccountByUserOrAccountID( userId, accountId, domain string, ) (*server.Account, error) { @@ -462,11 +462,12 @@ func (am *MockAccountManager) CreateUser(accountID string, invite *server.UserIn } // GetAccountFromToken mocks GetAccountFromToken of the AccountManager interface -func (am *MockAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { +func (am *MockAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, + error) { if am.GetAccountFromTokenFunc != nil { return am.GetAccountFromTokenFunc(claims) } - return nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented") + return nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented") } // GetPeers mocks GetPeers of the AccountManager interface diff --git a/management/server/nameserver.go b/management/server/nameserver.go index b3dbf333a..4c4190c77 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -3,10 +3,9 @@ package server import ( "github.com/miekg/dns" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/server/status" "github.com/rs/xid" log "github.com/sirupsen/logrus" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" "strconv" "unicode/utf8" ) @@ -66,7 +65,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string) account, err := am.Store.GetAccount(accountID) if err != nil { - return nil, status.Errorf(codes.NotFound, "account not found") + return nil, err } nsGroup, found := account.NameServerGroups[nsGroupID] @@ -74,7 +73,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string) return nsGroup.Copy(), nil } - return nil, status.Errorf(codes.NotFound, "nameserver group with ID %s not found", nsGroupID) + return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID) } // CreateNameServerGroup creates and saves a new nameserver group @@ -85,7 +84,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, d account, err := am.Store.GetAccount(accountID) if err != nil { - return nil, status.Errorf(codes.NotFound, "account not found") + return nil, err } newNSGroup := &nbdns.NameServerGroup{ @@ -119,7 +118,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, d err = am.updateAccountPeers(account) if err != nil { log.Error(err) - return newNSGroup.Copy(), status.Errorf(codes.Unavailable, "failed to update peers after create nameserver %s", name) + return newNSGroup.Copy(), status.Errorf(status.Internal, "failed to update peers after create nameserver %s", name) } return newNSGroup.Copy(), nil @@ -132,12 +131,12 @@ func (am *DefaultAccountManager) SaveNameServerGroup(accountID string, nsGroupTo defer unlock() if nsGroupToSave == nil { - return status.Errorf(codes.InvalidArgument, "nameserver group provided is nil") + return status.Errorf(status.InvalidArgument, "nameserver group provided is nil") } account, err := am.Store.GetAccount(accountID) if err != nil { - return status.Errorf(codes.NotFound, "account not found") + return err } err = validateNameServerGroup(true, nsGroupToSave, account) @@ -156,7 +155,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(accountID string, nsGroupTo err = am.updateAccountPeers(account) if err != nil { log.Error(err) - return status.Errorf(codes.Unavailable, "failed to update peers after update nameserver %s", nsGroupToSave.Name) + return status.Errorf(status.Internal, "failed to update peers after update nameserver %s", nsGroupToSave.Name) } return nil @@ -170,16 +169,16 @@ func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID stri account, err := am.Store.GetAccount(accountID) if err != nil { - return nil, status.Errorf(codes.NotFound, "account not found") + return nil, err } if len(operations) == 0 { - return nil, status.Errorf(codes.InvalidArgument, "operations shouldn't be empty") + return nil, status.Errorf(status.InvalidArgument, "operations shouldn't be empty") } nsGroupToUpdate, ok := account.NameServerGroups[nsGroupID] if !ok { - return nil, status.Errorf(codes.NotFound, "nameserver group ID %s no longer exists", nsGroupID) + return nil, status.Errorf(status.NotFound, "nameserver group ID %s no longer exists", nsGroupID) } newNSGroup := nsGroupToUpdate.Copy() @@ -187,12 +186,12 @@ func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID stri for _, operation := range operations { valuesCount := len(operation.Values) if valuesCount < 1 { - return nil, status.Errorf(codes.InvalidArgument, "operation %s contains invalid number of values, it should be at least 1", operation.Type.String()) + return nil, status.Errorf(status.InvalidArgument, "operation %s contains invalid number of values, it should be at least 1", operation.Type.String()) } for _, value := range operation.Values { if value == "" { - return nil, status.Errorf(codes.InvalidArgument, "operation %s contains invalid empty string value", operation.Type.String()) + return nil, status.Errorf(status.InvalidArgument, "operation %s contains invalid empty string value", operation.Type.String()) } } switch operation.Type { @@ -200,7 +199,7 @@ func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID stri newNSGroup.Description = operation.Values[0] case UpdateNameServerGroupName: if valuesCount > 1 { - return nil, status.Errorf(codes.InvalidArgument, "failed to parse name values, expected 1 value got %d", valuesCount) + return nil, status.Errorf(status.InvalidArgument, "failed to parse name values, expected 1 value got %d", valuesCount) } err = validateNSGroupName(operation.Values[0], nsGroupID, account.NameServerGroups) if err != nil { @@ -230,13 +229,13 @@ func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID stri case UpdateNameServerGroupEnabled: enabled, err := strconv.ParseBool(operation.Values[0]) if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0]) + return nil, status.Errorf(status.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0]) } newNSGroup.Enabled = enabled case UpdateNameServerGroupPrimary: primary, err := strconv.ParseBool(operation.Values[0]) if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "failed to parse primary status %s, not boolean", operation.Values[0]) + return nil, status.Errorf(status.InvalidArgument, "failed to parse primary status %s, not boolean", operation.Values[0]) } newNSGroup.Primary = primary case UpdateNameServerGroupDomains: @@ -259,7 +258,7 @@ func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID stri err = am.updateAccountPeers(account) if err != nil { log.Error(err) - return newNSGroup.Copy(), status.Errorf(codes.Unavailable, "failed to update peers after update nameserver %s", newNSGroup.Name) + return newNSGroup.Copy(), status.Errorf(status.Internal, "failed to update peers after update nameserver %s", newNSGroup.Name) } return newNSGroup.Copy(), nil @@ -273,7 +272,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID stri account, err := am.Store.GetAccount(accountID) if err != nil { - return status.Errorf(codes.NotFound, "account not found") + return err } delete(account.NameServerGroups, nsGroupID) @@ -287,7 +286,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID stri err = am.updateAccountPeers(account) if err != nil { log.Error(err) - return status.Errorf(codes.Unavailable, "failed to update peers after deleting nameserver %s", nsGroupID) + return status.Errorf(status.Internal, "failed to update peers after deleting nameserver %s", nsGroupID) } return nil @@ -301,7 +300,7 @@ func (am *DefaultAccountManager) ListNameServerGroups(accountID string) ([]*nbdn account, err := am.Store.GetAccount(accountID) if err != nil { - return nil, status.Errorf(codes.NotFound, "account not found") + return nil, err } nsGroups := make([]*nbdns.NameServerGroup, 0, len(account.NameServerGroups)) @@ -318,7 +317,7 @@ func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServ nsGroupID = nameserverGroup.ID _, found := account.NameServerGroups[nsGroupID] if !found { - return status.Errorf(codes.NotFound, "nameserver group with ID %s was not found", nsGroupID) + return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupID) } } @@ -347,17 +346,17 @@ func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServ func validateDomainInput(primary bool, domains []string) error { if !primary && len(domains) == 0 { - return status.Errorf(codes.InvalidArgument, "nameserver group primary status is false and domains are empty,"+ + return status.Errorf(status.InvalidArgument, "nameserver group primary status is false and domains are empty,"+ " it should be primary or have at least one domain") } if primary && len(domains) != 0 { - return status.Errorf(codes.InvalidArgument, "nameserver group primary status is true and domains are not empty,"+ + return status.Errorf(status.InvalidArgument, "nameserver group primary status is true and domains are not empty,"+ " you should set either primary or domain") } for _, domain := range domains { _, valid := dns.IsDomainName(domain) if !valid { - return status.Errorf(codes.InvalidArgument, "nameserver group got an invalid domain: %s", domain) + return status.Errorf(status.InvalidArgument, "nameserver group got an invalid domain: %s", domain) } } return nil @@ -365,12 +364,12 @@ func validateDomainInput(primary bool, domains []string) error { func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.NameServerGroup) error { if utf8.RuneCountInString(name) > nbdns.MaxGroupNameChar || name == "" { - return status.Errorf(codes.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar) + return status.Errorf(status.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar) } for _, nsGroup := range nsGroupMap { if name == nsGroup.Name && nsGroup.ID != nsGroupID { - return status.Errorf(codes.InvalidArgument, "a nameserver group with name %s already exist", name) + return status.Errorf(status.InvalidArgument, "a nameserver group with name %s already exist", name) } } @@ -380,19 +379,19 @@ func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.Na func validateNSList(list []nbdns.NameServer) error { nsListLenght := len(list) if nsListLenght == 0 || nsListLenght > 2 { - return status.Errorf(codes.InvalidArgument, "the list of nameservers should be 1 or 2, got %d", len(list)) + return status.Errorf(status.InvalidArgument, "the list of nameservers should be 1 or 2, got %d", len(list)) } return nil } func validateGroups(list []string, groups map[string]*Group) error { if len(list) == 0 { - return status.Errorf(codes.InvalidArgument, "the list of group IDs should not be empty") + return status.Errorf(status.InvalidArgument, "the list of group IDs should not be empty") } for _, id := range list { if id == "" { - return status.Errorf(codes.InvalidArgument, "group ID should not be empty string") + return status.Errorf(status.InvalidArgument, "group ID should not be empty string") } found := false for groupID := range groups { @@ -402,7 +401,7 @@ func validateGroups(list []string, groups map[string]*Group) error { } } if !found { - return status.Errorf(codes.InvalidArgument, "group id %s not found", id) + return status.Errorf(status.InvalidArgument, "group id %s not found", id) } } diff --git a/management/server/network.go b/management/server/network.go index 97cb7a1de..77ff92787 100644 --- a/management/server/network.go +++ b/management/server/network.go @@ -3,6 +3,7 @@ package server import ( "github.com/c-robinson/iplib" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/route" "github.com/rs/xid" "math/rand" @@ -93,7 +94,7 @@ func AllocatePeerIP(ipNet net.IPNet, takenIps []net.IP) (net.IP, error) { ips, _ := generateIPs(&ipNet, takenIPMap) if len(ips) == 0 { - return nil, Errorf(PreconditionFailed, "failed allocating new IP for the ipNet %s - network is out of IPs", ipNet.String()) + return nil, status.Errorf(status.PreconditionFailed, "failed allocating new IP for the ipNet %s - network is out of IPs", ipNet.String()) } // pick a random IP diff --git a/management/server/peer.go b/management/server/peer.go index 7c28d3013..9ecd5dd0f 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -2,6 +2,7 @@ package server import ( nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/server/status" "net" "strings" "time" @@ -9,8 +10,6 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/proto" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" ) // PeerSystemMeta is a metadata of a Peer machine system @@ -162,7 +161,7 @@ func (am *DefaultAccountManager) UpdatePeer(accountID string, update *Peer) (*Pe account, err := am.Store.GetAccount(accountID) if err != nil { - return nil, status.Errorf(codes.NotFound, "account not found") + return nil, err } //TODO Peer.ID migration: we will need to replace search by ID here @@ -208,7 +207,7 @@ func (am *DefaultAccountManager) DeletePeer(accountID string, peerPubKey string) account, err := am.Store.GetAccount(accountID) if err != nil { - return nil, status.Errorf(codes.NotFound, "account not found") + return nil, err } peer, err := account.FindPeerByPubKey(peerPubKey) @@ -258,7 +257,7 @@ func (am *DefaultAccountManager) GetPeerByIP(accountID string, peerIP string) (* account, err := am.Store.GetAccount(accountID) if err != nil { - return nil, status.Errorf(codes.NotFound, "account not found") + return nil, err } for _, peer := range account.Peers { @@ -267,7 +266,7 @@ func (am *DefaultAccountManager) GetPeerByIP(accountID string, peerIP string) (* } } - return nil, status.Errorf(codes.NotFound, "peer with IP %s not found", peerIP) + return nil, status.Errorf(status.NotFound, "peer with IP %s not found", peerIP) } // GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result) @@ -275,7 +274,7 @@ func (am *DefaultAccountManager) GetNetworkMap(peerPubKey string) (*NetworkMap, account, err := am.Store.GetAccountByPeerPubKey(peerPubKey) if err != nil { - return nil, status.Errorf(codes.Internal, "Invalid peer key %s", peerPubKey) + return nil, err } aclPeers := am.getPeersByACL(account, peerPubKey) @@ -306,7 +305,7 @@ func (am *DefaultAccountManager) GetPeerNetwork(peerPubKey string) (*Network, er account, err := am.Store.GetAccountByPeerPubKey(peerPubKey) if err != nil { - return nil, status.Errorf(codes.Internal, "invalid peer key %s", peerPubKey) + return nil, err } return account.Network.Copy(), err @@ -332,7 +331,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (* account, err = am.Store.GetAccountBySetupKey(setupKey) } if err != nil { - return nil, Errorf(AccountNotFound, "failed adding new peer: account not found") + return nil, status.Errorf(status.NotFound, "failed adding new peer: account not found") } unlock := am.Store.AcquireAccountLock(account.Id) @@ -352,7 +351,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (* } if !sk.IsValid() { - return nil, Errorf(PreconditionFailed, "couldn't add peer: setup key is invalid") + return nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid") } account.SetupKeys[sk.Key] = sk.IncrementUsage() @@ -418,7 +417,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (* account.Network.IncSerial() err = am.Store.SaveAccount(account) if err != nil { - return nil, status.Errorf(codes.Internal, "failed adding peer") + return nil, err } return newPeer, nil @@ -563,13 +562,13 @@ func (am *DefaultAccountManager) updateAccountPeers(account *Account) error { for _, peer := range peers { remotePeerNetworkMap, err := am.GetNetworkMap(peer.Key) if err != nil { - return status.Errorf(codes.Internal, "unable to fetch network map for peer %s, error: %v", peer.Key, err) + return err } update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap) err = am.peersUpdateManager.SendUpdate(peer.Key, &UpdateMessage{Update: update}) if err != nil { - return status.Errorf(codes.Internal, "unable to send update for peer %s, error: %v", peer.Key, err) + return err } } diff --git a/management/server/route.go b/management/server/route.go index 9f615f9e0..feef82533 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -2,11 +2,10 @@ package server import ( "github.com/netbirdio/netbird/management/proto" + "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/route" "github.com/rs/xid" log "github.com/sirupsen/logrus" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" "net/netip" "strconv" "unicode/utf8" @@ -66,7 +65,7 @@ func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*r account, err := am.Store.GetAccount(accountID) if err != nil { - return nil, status.Errorf(codes.NotFound, "account not found") + return nil, err } user, err := account.FindUser(userID) @@ -75,7 +74,7 @@ func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*r } if !user.IsAdmin() { - return nil, Errorf(PermissionDenied, "Only administrators can view Network Routes") + return nil, status.Errorf(status.PermissionDenied, "Only administrators can view Network Routes") } wantedRoute, found := account.Routes[routeID] @@ -83,7 +82,7 @@ func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*r return wantedRoute, nil } - return nil, status.Errorf(codes.NotFound, "route with ID %s not found", routeID) + return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID) } // checkPrefixPeerExists checks the combination of prefix and peer id, if it exists returns an error, otherwise returns nil @@ -101,14 +100,14 @@ func (am *DefaultAccountManager) checkPrefixPeerExists(accountID, peer string, p routesWithPrefix := account.GetRoutesByPrefix(prefix) if err != nil { - if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { + if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { return nil } - return status.Errorf(codes.InvalidArgument, "failed to parse prefix %s", prefix.String()) + return status.Errorf(status.InvalidArgument, "failed to parse prefix %s", prefix.String()) } for _, prefixRoute := range routesWithPrefix { if prefixRoute.Peer == peer { - return status.Errorf(codes.AlreadyExists, "failed a route with prefix %s and peer already exist", prefix.String()) + return status.Errorf(status.AlreadyExists, "failed a route with prefix %s and peer already exist", prefix.String()) } } return nil @@ -121,13 +120,13 @@ func (am *DefaultAccountManager) CreateRoute(accountID string, network, peer, de account, err := am.Store.GetAccount(accountID) if err != nil { - return nil, status.Errorf(codes.NotFound, "account not found") + return nil, err } var newRoute route.Route prefixType, newPrefix, err := route.ParseNetwork(network) if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "failed to parse IP %s", network) + return nil, status.Errorf(status.InvalidArgument, "failed to parse IP %s", network) } err = am.checkPrefixPeerExists(accountID, peer, newPrefix) if err != nil { @@ -137,16 +136,16 @@ func (am *DefaultAccountManager) CreateRoute(accountID string, network, peer, de if peer != "" { _, peerExist := account.Peers[peer] if !peerExist { - return nil, status.Errorf(codes.InvalidArgument, "failed to find Peer %s", peer) + return nil, status.Errorf(status.InvalidArgument, "failed to find Peer %s", peer) } } if metric < route.MinMetric || metric > route.MaxMetric { - return nil, status.Errorf(codes.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric) + return nil, status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric) } if utf8.RuneCountInString(netID) > route.MaxNetIDChar || netID == "" { - return nil, status.Errorf(codes.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) + return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) } newRoute.Peer = peer @@ -173,7 +172,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID string, network, peer, de err = am.updateAccountPeers(account) if err != nil { log.Error(err) - return &newRoute, status.Errorf(codes.Unavailable, "failed to update peers after create route %s", newPrefix) + return &newRoute, status.Errorf(status.Internal, "failed to update peers after create route %s", newPrefix) } return &newRoute, nil } @@ -184,30 +183,30 @@ func (am *DefaultAccountManager) SaveRoute(accountID string, routeToSave *route. defer unlock() if routeToSave == nil { - return status.Errorf(codes.InvalidArgument, "route provided is nil") + return status.Errorf(status.InvalidArgument, "route provided is nil") } if !routeToSave.Network.IsValid() { - return status.Errorf(codes.InvalidArgument, "invalid Prefix %s", routeToSave.Network.String()) + return status.Errorf(status.InvalidArgument, "invalid Prefix %s", routeToSave.Network.String()) } if routeToSave.Metric < route.MinMetric || routeToSave.Metric > route.MaxMetric { - return status.Errorf(codes.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric) + return status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric) } if utf8.RuneCountInString(routeToSave.NetID) > route.MaxNetIDChar || routeToSave.NetID == "" { - return status.Errorf(codes.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) + return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) } account, err := am.Store.GetAccount(accountID) if err != nil { - return status.Errorf(codes.NotFound, "account not found") + return err } if routeToSave.Peer != "" { _, peerExist := account.Peers[routeToSave.Peer] if !peerExist { - return status.Errorf(codes.InvalidArgument, "failed to find Peer %s", routeToSave.Peer) + return status.Errorf(status.InvalidArgument, "failed to find Peer %s", routeToSave.Peer) } } @@ -228,12 +227,12 @@ func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operatio account, err := am.Store.GetAccount(accountID) if err != nil { - return nil, status.Errorf(codes.NotFound, "account not found") + return nil, err } routeToUpdate, ok := account.Routes[routeID] if !ok { - return nil, status.Errorf(codes.NotFound, "route %s no longer exists", routeID) + return nil, status.Errorf(status.NotFound, "route %s no longer exists", routeID) } newRoute := routeToUpdate.Copy() @@ -241,7 +240,7 @@ func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operatio for _, operation := range operations { if len(operation.Values) != 1 { - return nil, status.Errorf(codes.InvalidArgument, "operation %s contains invalid number of values, it should be 1", operation.Type.String()) + return nil, status.Errorf(status.InvalidArgument, "operation %s contains invalid number of values, it should be 1", operation.Type.String()) } switch operation.Type { @@ -249,13 +248,13 @@ func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operatio newRoute.Description = operation.Values[0] case UpdateRouteNetworkIdentifier: if utf8.RuneCountInString(operation.Values[0]) > route.MaxNetIDChar || operation.Values[0] == "" { - return nil, status.Errorf(codes.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) + return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) } newRoute.NetID = operation.Values[0] case UpdateRouteNetwork: prefixType, prefix, err := route.ParseNetwork(operation.Values[0]) if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "failed to parse IP %s", operation.Values[0]) + return nil, status.Errorf(status.InvalidArgument, "failed to parse IP %s", operation.Values[0]) } err = am.checkPrefixPeerExists(accountID, routeToUpdate.Peer, prefix) if err != nil { @@ -267,7 +266,7 @@ func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operatio if operation.Values[0] != "" { _, peerExist := account.Peers[operation.Values[0]] if !peerExist { - return nil, status.Errorf(codes.InvalidArgument, "failed to find Peer %s", operation.Values[0]) + return nil, status.Errorf(status.InvalidArgument, "failed to find Peer %s", operation.Values[0]) } } @@ -279,10 +278,10 @@ func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operatio case UpdateRouteMetric: metric, err := strconv.Atoi(operation.Values[0]) if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "failed to parse metric %s, not int", operation.Values[0]) + return nil, status.Errorf(status.InvalidArgument, "failed to parse metric %s, not int", operation.Values[0]) } if metric < route.MinMetric || metric > route.MaxMetric { - return nil, status.Errorf(codes.InvalidArgument, "failed to parse metric %s, value should be %d > N < %d", + return nil, status.Errorf(status.InvalidArgument, "failed to parse metric %s, value should be %d > N < %d", operation.Values[0], route.MinMetric, route.MaxMetric, @@ -292,13 +291,13 @@ func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operatio case UpdateRouteMasquerade: masquerade, err := strconv.ParseBool(operation.Values[0]) if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "failed to parse masquerade %s, not boolean", operation.Values[0]) + return nil, status.Errorf(status.InvalidArgument, "failed to parse masquerade %s, not boolean", operation.Values[0]) } newRoute.Masquerade = masquerade case UpdateRouteEnabled: enabled, err := strconv.ParseBool(operation.Values[0]) if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0]) + return nil, status.Errorf(status.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0]) } newRoute.Enabled = enabled } @@ -313,7 +312,7 @@ func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operatio err = am.updateAccountPeers(account) if err != nil { - return nil, status.Errorf(codes.Internal, "failed to update account peers") + return nil, status.Errorf(status.Internal, "failed to update account peers") } return newRoute, nil } @@ -325,7 +324,7 @@ func (am *DefaultAccountManager) DeleteRoute(accountID, routeID string) error { account, err := am.Store.GetAccount(accountID) if err != nil { - return status.Errorf(codes.NotFound, "account not found") + return err } delete(account.Routes, routeID) @@ -345,7 +344,7 @@ func (am *DefaultAccountManager) ListRoutes(accountID, userID string) ([]*route. account, err := am.Store.GetAccount(accountID) if err != nil { - return nil, status.Errorf(codes.NotFound, "account not found") + return nil, err } user, err := account.FindUser(userID) @@ -354,7 +353,7 @@ func (am *DefaultAccountManager) ListRoutes(accountID, userID string) ([]*route. } if !user.IsAdmin() { - return nil, Errorf(PermissionDenied, "Only administrators can view Network Routes") + return nil, status.Errorf(status.PermissionDenied, "Only administrators can view Network Routes") } routes := make([]*route.Route, 0, len(account.Routes)) diff --git a/management/server/rule.go b/management/server/rule.go index dd0cf5fa9..98c74c02c 100644 --- a/management/server/rule.go +++ b/management/server/rule.go @@ -1,8 +1,7 @@ package server import ( - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" + "github.com/netbirdio/netbird/management/server/status" "strings" ) @@ -95,7 +94,7 @@ func (am *DefaultAccountManager) GetRule(accountID, ruleID, userID string) (*Rul account, err := am.Store.GetAccount(accountID) if err != nil { - return nil, status.Errorf(codes.NotFound, "account not found") + return nil, err } user, err := account.FindUser(userID) @@ -104,7 +103,7 @@ func (am *DefaultAccountManager) GetRule(accountID, ruleID, userID string) (*Rul } if !user.IsAdmin() { - return nil, Errorf(PermissionDenied, "only admins are allowed to view rules") + return nil, status.Errorf(status.PermissionDenied, "only admins are allowed to view rules") } rule, ok := account.Rules[ruleID] @@ -112,7 +111,7 @@ func (am *DefaultAccountManager) GetRule(accountID, ruleID, userID string) (*Rul return rule, nil } - return nil, status.Errorf(codes.NotFound, "rule with ID %s not found", ruleID) + return nil, status.Errorf(status.NotFound, "rule with ID %s not found", ruleID) } // SaveRule of ACL in the store @@ -122,7 +121,7 @@ func (am *DefaultAccountManager) SaveRule(accountID string, rule *Rule) error { account, err := am.Store.GetAccount(accountID) if err != nil { - return status.Errorf(codes.NotFound, "account not found") + return err } account.Rules[rule.ID] = rule @@ -143,12 +142,12 @@ func (am *DefaultAccountManager) UpdateRule(accountID string, ruleID string, account, err := am.Store.GetAccount(accountID) if err != nil { - return nil, status.Errorf(codes.NotFound, "account not found") + return nil, err } ruleToUpdate, ok := account.Rules[ruleID] if !ok { - return nil, status.Errorf(codes.NotFound, "rule %s no longer exists", ruleID) + return nil, status.Errorf(status.NotFound, "rule %s no longer exists", ruleID) } rule := ruleToUpdate.Copy() @@ -161,7 +160,7 @@ func (am *DefaultAccountManager) UpdateRule(accountID string, ruleID string, rule.Description = operation.Values[0] case UpdateRuleFlow: if operation.Values[0] != TrafficFlowBidirectString { - return nil, status.Errorf(codes.InvalidArgument, "failed to parse flow") + return nil, status.Errorf(status.InvalidArgument, "failed to parse flow") } rule.Flow = TrafficFlowBidirect case UpdateRuleStatus: @@ -170,7 +169,7 @@ func (am *DefaultAccountManager) UpdateRule(accountID string, ruleID string, } else if strings.ToLower(operation.Values[0]) == "false" { rule.Disabled = false } else { - return nil, status.Errorf(codes.InvalidArgument, "failed to parse status") + return nil, status.Errorf(status.InvalidArgument, "failed to parse status") } case UpdateSourceGroups: rule.Source = operation.Values @@ -204,7 +203,7 @@ func (am *DefaultAccountManager) UpdateRule(accountID string, ruleID string, err = am.updateAccountPeers(account) if err != nil { - return nil, status.Errorf(codes.Internal, "failed to update account peers") + return nil, status.Errorf(status.Internal, "failed to update account peers") } return rule, nil @@ -217,7 +216,7 @@ func (am *DefaultAccountManager) DeleteRule(accountID, ruleID string) error { account, err := am.Store.GetAccount(accountID) if err != nil { - return status.Errorf(codes.NotFound, "account not found") + return err } delete(account.Rules, ruleID) @@ -237,7 +236,7 @@ func (am *DefaultAccountManager) ListRules(accountID, userID string) ([]*Rule, e account, err := am.Store.GetAccount(accountID) if err != nil { - return nil, status.Errorf(codes.NotFound, "account not found") + return nil, err } user, err := account.FindUser(userID) @@ -246,7 +245,7 @@ func (am *DefaultAccountManager) ListRules(accountID, userID string) ([]*Rule, e } if !user.IsAdmin() { - return nil, Errorf(PermissionDenied, "Only Administrators can view Access Rules") + return nil, status.Errorf(status.PermissionDenied, "Only Administrators can view Access Rules") } rules := make([]*Rule, 0, len(account.Rules)) diff --git a/management/server/setupkey.go b/management/server/setupkey.go index c91af8c79..9d7d36379 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -1,10 +1,8 @@ package server import ( - "fmt" "github.com/google/uuid" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" + "github.com/netbirdio/netbird/management/server/status" "hash/fnv" "strconv" "strings" @@ -183,12 +181,12 @@ func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string account, err := am.Store.GetAccount(accountID) if err != nil { - return nil, status.Errorf(codes.NotFound, "account not found") + return nil, err } for _, group := range autoGroups { if _, ok := account.Groups[group]; !ok { - return nil, fmt.Errorf("group %s doesn't exist", group) + return nil, status.Errorf(status.NotFound, "group %s doesn't exist", group) } } @@ -197,7 +195,7 @@ func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string err = am.Store.SaveAccount(account) if err != nil { - return nil, status.Errorf(codes.Internal, "failed adding account key") + return nil, status.Errorf(status.Internal, "failed adding account key") } return setupKey, nil @@ -212,12 +210,12 @@ func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *Setup defer unlock() if keyToSave == nil { - return nil, status.Errorf(codes.InvalidArgument, "provided setup key to update is nil") + return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil") } account, err := am.Store.GetAccount(accountID) if err != nil { - return nil, status.Errorf(codes.NotFound, "account not found") + return nil, err } var oldKey *SetupKey @@ -228,7 +226,7 @@ func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *Setup } } if oldKey == nil { - return nil, status.Errorf(codes.NotFound, "setup key not found") + return nil, status.Errorf(status.NotFound, "setup key not found") } // only auto groups, revoked status, and name can be updated for now @@ -253,7 +251,7 @@ func (am *DefaultAccountManager) ListSetupKeys(accountID, userID string) ([]*Set defer unlock() account, err := am.Store.GetAccount(accountID) if err != nil { - return nil, status.Errorf(codes.NotFound, "account not found") + return nil, err } user, err := account.FindUser(userID) @@ -282,7 +280,7 @@ func (am *DefaultAccountManager) GetSetupKey(accountID, userID, keyID string) (* account, err := am.Store.GetAccount(accountID) if err != nil { - return nil, status.Errorf(codes.NotFound, "account not found") + return nil, err } user, err := account.FindUser(userID) @@ -298,7 +296,7 @@ func (am *DefaultAccountManager) GetSetupKey(accountID, userID, keyID string) (* } } if foundKey == nil { - return nil, status.Errorf(codes.NotFound, "setup key not found") + return nil, status.Errorf(status.NotFound, "setup key not found") } // the UpdatedAt field was introduced later, so there might be that some keys have a Zero value (e.g, null in the store file) diff --git a/management/server/status/error.go b/management/server/status/error.go new file mode 100644 index 000000000..6d6299449 --- /dev/null +++ b/management/server/status/error.go @@ -0,0 +1,72 @@ +package status + +import ( + "fmt" +) + +const ( + // UserAlreadyExists indicates that user already exists + UserAlreadyExists Type = 1 + + // PreconditionFailed indicates that some pre-condition for the operation hasn't been fulfilled + PreconditionFailed Type = 2 + + // PermissionDenied indicates that user has no permissions to view data + PermissionDenied Type = 3 + + // NotFound indicates that the object wasn't found in the system (or under a given Account) + NotFound Type = 4 + + // Internal indicates some generic internal error + Internal Type = 5 + + // InvalidArgument indicates some generic invalid argument error + InvalidArgument Type = 6 + + // AlreadyExists indicates a generic error when an object already exists in the system + AlreadyExists Type = 7 + + // Unauthorized indicates that user is not authorized + Unauthorized Type = 8 + + // BadRequest indicates that user is not authorized + BadRequest Type = 9 +) + +// Type is a type of the Error +type Type int32 + +// Error is an internal error +type Error struct { + ErrorType Type + Message string +} + +// Type returns the Type of the error +func (e *Error) Type() Type { + return e.ErrorType +} + +// Error is an error string +func (e *Error) Error() string { + return e.Message +} + +// Errorf returns Error(ErrorType, fmt.Sprintf(format, a...)). +func Errorf(errorType Type, format string, a ...interface{}) error { + return &Error{ + ErrorType: errorType, + Message: fmt.Sprintf(format, a...), + } +} + +// FromError returns Error, true if the provided error is of type of Error. nil, false otherwise +func FromError(err error) (s *Error, ok bool) { + if err == nil { + return nil, true + } + if e, ok := err.(*Error); ok { + return e, true + } + return nil, false +} diff --git a/management/server/user.go b/management/server/user.go index a0e4a6870..82775c726 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -3,8 +3,7 @@ package server import ( "fmt" "github.com/netbirdio/netbird/management/server/idp" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" + "github.com/netbirdio/netbird/management/server/status" "strings" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -123,7 +122,7 @@ func (am *DefaultAccountManager) CreateUser(accountID string, invite *UserInfo) defer unlock() if am.idpManager == nil { - return nil, Errorf(PreconditionFailed, "IdP manager must be enabled to send user invites") + return nil, status.Errorf(status.PreconditionFailed, "IdP manager must be enabled to send user invites") } if invite == nil { @@ -132,7 +131,7 @@ func (am *DefaultAccountManager) CreateUser(accountID string, invite *UserInfo) account, err := am.Store.GetAccount(accountID) if err != nil { - return nil, Errorf(AccountNotFound, "account %s doesn't exist", accountID) + return nil, status.Errorf(status.NotFound, "account %s doesn't exist", accountID) } // check if the user is already registered with this email => reject @@ -142,7 +141,7 @@ func (am *DefaultAccountManager) CreateUser(accountID string, invite *UserInfo) } if user != nil { - return nil, Errorf(UserAlreadyExists, "user has an existing account") + return nil, status.Errorf(status.UserAlreadyExists, "user has an existing account") } users, err := am.idpManager.GetUserByEmail(invite.Email) @@ -151,7 +150,7 @@ func (am *DefaultAccountManager) CreateUser(accountID string, invite *UserInfo) } if len(users) > 0 { - return nil, Errorf(UserAlreadyExists, "user has an existing account") + return nil, status.Errorf(status.UserAlreadyExists, "user has an existing account") } idpUser, err := am.idpManager.CreateUser(invite.Email, invite.Name, accountID) @@ -188,25 +187,24 @@ func (am *DefaultAccountManager) SaveUser(accountID string, update *User) (*User defer unlock() if update == nil { - return nil, status.Errorf(codes.InvalidArgument, "provided user update is nil") + return nil, status.Errorf(status.InvalidArgument, "provided user update is nil") } account, err := am.Store.GetAccount(accountID) if err != nil { - return nil, status.Errorf(codes.NotFound, "account not found") + return nil, err } for _, newGroupID := range update.AutoGroups { if _, ok := account.Groups[newGroupID]; !ok { - return nil, - status.Errorf(codes.InvalidArgument, "provided group ID %s in the user %s update doesn't exist", - newGroupID, update.Id) + return nil, status.Errorf(status.InvalidArgument, "provided group ID %s in the user %s update doesn't exist", + newGroupID, update.Id) } } oldUser := account.Users[update.Id] if oldUser == nil { - return nil, status.Errorf(codes.NotFound, "update not found") + return nil, status.Errorf(status.NotFound, "update not found") } // only auto groups, revoked status, and name can be updated for now @@ -226,7 +224,7 @@ func (am *DefaultAccountManager) SaveUser(accountID string, update *User) (*User return nil, err } if userData == nil { - return nil, status.Errorf(codes.NotFound, "user %s not found in the IdP", newUser.Id) + return nil, status.Errorf(status.NotFound, "user %s not found in the IdP", newUser.Id) } return newUser.toUserInfo(userData) } @@ -242,14 +240,14 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userID, domain string) account, err := am.Store.GetAccountByUser(userID) if err != nil { - if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { + if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { account, err = am.newAccount(userID, lowerDomain) if err != nil { return nil, err } err = am.Store.SaveAccount(account) if err != nil { - return nil, status.Errorf(codes.Internal, "failed creating account") + return nil, err } } else { // other error @@ -263,7 +261,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userID, domain string) account.Domain = lowerDomain err = am.Store.SaveAccount(account) if err != nil { - return nil, status.Errorf(codes.Internal, "failed updating account with domain") + return nil, status.Errorf(status.Internal, "failed updating account with domain") } } @@ -272,14 +270,14 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userID, domain string) // IsUserAdmin flag for current user authenticated by JWT token func (am *DefaultAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) { - account, err := am.GetAccountFromToken(claims) + account, _, err := am.GetAccountFromToken(claims) if err != nil { return false, fmt.Errorf("get account: %v", err) } user, ok := account.Users[claims.UserId] if !ok { - return false, fmt.Errorf("no such user") + return false, status.Errorf(status.NotFound, "user not found") } return user.Role == UserRoleAdmin, nil diff --git a/route/route.go b/route/route.go index 6214a6464..4173783df 100644 --- a/route/route.go +++ b/route/route.go @@ -1,8 +1,7 @@ package route import ( - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" + "github.com/netbirdio/netbird/management/server/status" "net/netip" ) @@ -108,13 +107,13 @@ func (r *Route) IsEqual(other *Route) bool { func ParseNetwork(networkString string) (NetworkType, netip.Prefix, error) { prefix, err := netip.ParsePrefix(networkString) if err != nil { - return InvalidNetwork, netip.Prefix{}, err + return InvalidNetwork, netip.Prefix{}, status.Errorf(status.InvalidArgument, "invalid network %s", networkString) } masked := prefix.Masked() if !masked.IsValid() { - return InvalidNetwork, netip.Prefix{}, status.Errorf(codes.InvalidArgument, "invalid range %s", networkString) + return InvalidNetwork, netip.Prefix{}, status.Errorf(status.InvalidArgument, "invalid range %s", networkString) } if masked.Addr().Is6() {