Replace gRPC errors in business logic with internal ones (#558)

This commit is contained in:
Misha Bragin 2022-11-11 20:36:45 +01:00 committed by GitHub
parent 1db4027bea
commit 509d23c7cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
35 changed files with 768 additions and 847 deletions

View File

@ -8,12 +8,11 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/route"
gocache "github.com/patrickmn/go-cache"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"math/rand"
"net"
"net/netip"
@ -52,7 +51,7 @@ type AccountManager interface {
SaveUser(accountID string, key *User) (*UserInfo, error)
GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error)
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, error)
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error)
AccountExists(accountId string) (*bool, error)
GetPeer(peerKey string) (*Peer, error)
@ -265,14 +264,14 @@ func (a *Account) FindPeerByPubKey(peerPubKey string) (*Peer, error) {
}
}
return nil, status.Errorf(codes.NotFound, "peer with the public key %s not found", peerPubKey)
return nil, status.Errorf(status.NotFound, "peer with the public key %s not found", peerPubKey)
}
// FindUser looks for a given user in the Account or returns error if user wasn't found.
func (a *Account) FindUser(userID string) (*User, error) {
user := a.Users[userID]
if user == nil {
return nil, Errorf(UserNotFound, "user %s not found", userID)
return nil, status.Errorf(status.NotFound, "user %s not found", userID)
}
return user, nil
@ -282,7 +281,7 @@ func (a *Account) FindUser(userID string) (*User, error) {
func (a *Account) FindSetupKey(setupKey string) (*SetupKey, error) {
key := a.SetupKeys[setupKey]
if key == nil {
return nil, Errorf(SetupKeyNotFound, "setup key not found")
return nil, status.Errorf(status.NotFound, "setup key not found")
}
return key, nil
@ -458,14 +457,14 @@ func (am *DefaultAccountManager) newAccount(userID, domain string) (*Account, er
if err == nil {
log.Warnf("an account with ID already exists, retrying...")
continue
} else if statusErr.Code() == codes.NotFound {
} else if statusErr.Type() == status.NotFound {
return newAccountWithId(accountId, userID, domain), nil
} else {
return nil, err
}
}
return nil, status.Errorf(codes.Internal, "error while creating new account")
return nil, status.Errorf(status.Internal, "error while creating new account")
}
func (am *DefaultAccountManager) warmupIDPCache() error {
@ -492,7 +491,7 @@ func (am *DefaultAccountManager) GetAccountByUserOrAccountID(userID, accountID,
} else if userID != "" {
account, err := am.GetOrCreateAccountByUser(userID, domain)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found using user id: %s", userID)
return nil, status.Errorf(status.NotFound, "account not found using user id: %s", userID)
}
err = am.addAccountIDToIDPAppMeta(userID, account)
if err != nil {
@ -501,7 +500,7 @@ func (am *DefaultAccountManager) GetAccountByUserOrAccountID(userID, accountID,
return account, nil
}
return nil, status.Errorf(codes.NotFound, "no valid user or account Id provided")
return nil, status.Errorf(status.NotFound, "no valid user or account Id provided")
}
func isNil(i idp.Manager) bool {
@ -531,11 +530,7 @@ func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(userID string, account
}
if err != nil {
return status.Errorf(
codes.Internal,
"updating user's app metadata failed with: %v",
err,
)
return status.Errorf(status.Internal, "updating user's app metadata failed with: %v", err)
}
// refresh cache to reflect the update
_, err = am.refreshCache(account.Id)
@ -662,11 +657,8 @@ func (am *DefaultAccountManager) lookupCache(accountUsers map[string]struct{}, a
}
// updateAccountDomainAttributes updates the account domain attributes and then, saves the account
func (am *DefaultAccountManager) updateAccountDomainAttributes(
account *Account,
claims jwtclaims.AuthorizationClaims,
primaryDomain bool,
) error {
func (am *DefaultAccountManager) updateAccountDomainAttributes(account *Account, claims jwtclaims.AuthorizationClaims,
primaryDomain bool) error {
account.IsDomainPrimaryAccount = primaryDomain
lowerDomain := strings.ToLower(claims.Domain)
@ -681,7 +673,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributes(
err := am.Store.SaveAccount(account)
if err != nil {
return status.Errorf(codes.Internal, "failed saving updated account")
return err
}
return nil
}
@ -723,10 +715,7 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
// handleNewUserAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account,
// otherwise it will create a new account and make it primary account for the domain.
func (am *DefaultAccountManager) handleNewUserAccount(
domainAcc *Account,
claims jwtclaims.AuthorizationClaims,
) (*Account, error) {
func (am *DefaultAccountManager) handleNewUserAccount(domainAcc *Account, claims jwtclaims.AuthorizationClaims) (*Account, error) {
var (
account *Account
err error
@ -738,7 +727,7 @@ func (am *DefaultAccountManager) handleNewUserAccount(
account.Users[claims.UserId] = NewRegularUser(claims.UserId)
err = am.Store.SaveAccount(account)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed saving updated account")
return nil, err
}
} else {
account, err = am.newAccount(claims.UserId, lowerDomain)
@ -773,7 +762,7 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e
}
if user == nil {
return status.Errorf(codes.NotFound, "user %s not found in the IdP", userID)
return status.Errorf(status.NotFound, "user %s not found in the IdP", userID)
}
if user.AppMetadata.WTPendingInvite != nil && *user.AppMetadata.WTPendingInvite {
@ -794,7 +783,7 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e
}
// GetAccountFromToken returns an account associated with this token
func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, error) {
func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error) {
if am.singleAccountMode && am.singleAccountModeDomain != "" {
// This section is mostly related to self-hosted installations.
@ -806,15 +795,21 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
account, err := am.getAccountWithAuthorizationClaims(claims)
if err != nil {
return nil, err
return nil, nil, err
}
user := account.Users[claims.UserId]
if user == nil {
// this is not really possible because we got an account by user ID
return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId)
}
err = am.redeemInvite(account, claims.UserId)
if err != nil {
return nil, err
return nil, nil, err
}
return account, nil
return account, user, nil
}
// getAccountWithAuthorizationClaims retrievs an account using JWT Claims.
@ -857,10 +852,13 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla
// We checked if the domain has a primary account already
domainAccount, err := am.Store.GetAccountByPrivateDomain(claims.Domain)
accStatus, _ := status.FromError(err)
if accStatus.Code() != codes.OK && accStatus.Code() != codes.NotFound {
if err != nil {
// if NotFound we are good to continue, otherwise return error
e, ok := status.FromError(err)
if !ok || e.Type() != status.NotFound {
return nil, err
}
}
account, err := am.Store.GetAccountByUser(claims.UserId)
if err == nil {
@ -869,7 +867,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla
return nil, err
}
return account, nil
} else if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
return am.handleNewUserAccount(domainAccount, claims)
} else {
// other error
@ -891,7 +889,7 @@ func (am *DefaultAccountManager) AccountExists(accountID string) (*bool, error)
var res bool
_, err := am.Store.GetAccount(accountID)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
res = false
return &res, nil
} else {

View File

@ -314,7 +314,7 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
testCase.inputClaims.AccountId = initAccount.Id
}
account, err := manager.GetAccountFromToken(testCase.inputClaims)
account, _, err := manager.GetAccountFromToken(testCase.inputClaims)
require.NoError(t, err, "support function failed")
verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers)
verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy)

View File

@ -1,61 +0,0 @@
package server
import (
"fmt"
)
const (
// UserAlreadyExists indicates that user already exists
UserAlreadyExists ErrorType = iota
// AccountNotFound indicates that specified account hasn't been found
AccountNotFound
// PreconditionFailed indicates that some pre-condition for the operation hasn't been fulfilled
PreconditionFailed
// UserNotFound indicates that user wasn't found in the system (or under a given Account)
UserNotFound
// PermissionDenied indicates that user has no permissions to view data
PermissionDenied
// SetupKeyNotFound indicates that the setup key wasn't found in the system (or under a given Account)
SetupKeyNotFound
)
// ErrorType is a type of the Error
type ErrorType int32
// Error is an internal error
type Error struct {
errorType ErrorType
message string
}
// Type returns the Type of the error
func (e *Error) Type() ErrorType {
return e.errorType
}
// Error is an error string
func (e *Error) Error() string {
return e.message
}
// Errorf returns Error(errorType, fmt.Sprintf(format, a...)).
func Errorf(errorType ErrorType, format string, a ...interface{}) error {
return &Error{
errorType: errorType,
message: fmt.Sprintf(format, a...),
}
}
// FromError returns Error, true if the provided error is of type of Error. nil, false otherwise
func FromError(err error) (s *Error, ok bool) {
if err == nil {
return nil, true
}
if e, ok := err.(*Error); ok {
return e, true
}
return nil, false
}

View File

@ -1,6 +1,7 @@
package server
import (
"github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus"
"os"
"path/filepath"
@ -8,9 +9,6 @@ import (
"sync"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/util"
)
@ -192,10 +190,7 @@ func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) {
accountID, accountIDFound := s.PrivateDomain2AccountID[strings.ToLower(domain)]
if !accountIDFound {
return nil, status.Errorf(
codes.NotFound,
"account not found: provided domain is not registered or is not private",
)
return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
}
account, err := s.getAccount(accountID)
@ -213,7 +208,7 @@ func (s *FileStore) GetAccountBySetupKey(setupKey string) (*Account, error) {
accountID, accountIDFound := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
if !accountIDFound {
return nil, status.Errorf(codes.NotFound, "account not found: provided setup key doesn't exists")
return nil, status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists")
}
account, err := s.getAccount(accountID)
@ -239,7 +234,7 @@ func (s *FileStore) GetAllAccounts() (all []*Account) {
func (s *FileStore) getAccount(accountID string) (*Account, error) {
account, accountFound := s.Accounts[accountID]
if !accountFound {
return nil, status.Errorf(codes.NotFound, "account not found")
return nil, status.Errorf(status.NotFound, "account not found")
}
return account, nil
@ -265,7 +260,7 @@ func (s *FileStore) GetAccountByUser(userID string) (*Account, error) {
accountID, accountIDFound := s.UserID2AccountID[userID]
if !accountIDFound {
return nil, status.Errorf(codes.NotFound, "account not found")
return nil, status.Errorf(status.NotFound, "account not found")
}
account, err := s.getAccount(accountID)
@ -283,7 +278,7 @@ func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
accountID, accountIDFound := s.PeerKeyID2AccountID[peerKey]
if !accountIDFound {
return nil, status.Errorf(codes.NotFound, "Provided peer key doesn't exists %s", peerKey)
return nil, status.Errorf(status.NotFound, "provided peer key doesn't exists %s", peerKey)
}
account, err := s.getAccount(accountID)
@ -322,7 +317,7 @@ func (s *FileStore) SavePeerStatus(accountID, peerKey string, peerStatus PeerSta
peer := account.Peers[peerKey]
if peer == nil {
return status.Errorf(codes.NotFound, "peer %s not found", peerKey)
return status.Errorf(status.NotFound, "peer %s not found", peerKey)
}
peer.Status = &peerStatus

View File

@ -1,9 +1,6 @@
package server
import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
import "github.com/netbirdio/netbird/management/server/status"
// Group of the peers for ACL
type Group struct {
@ -53,7 +50,7 @@ func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, er
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
return nil, err
}
group, ok := account.Groups[groupID]
@ -61,7 +58,7 @@ func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, er
return group, nil
}
return nil, status.Errorf(codes.NotFound, "group with ID %s not found", groupID)
return nil, status.Errorf(status.NotFound, "group with ID %s not found", groupID)
}
// SaveGroup object of the peers
@ -72,7 +69,7 @@ func (am *DefaultAccountManager) SaveGroup(accountID string, group *Group) error
account, err := am.Store.GetAccount(accountID)
if err != nil {
return status.Errorf(codes.NotFound, "account not found")
return err
}
account.Groups[group.ID] = group
@ -94,12 +91,12 @@ func (am *DefaultAccountManager) UpdateGroup(accountID string,
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
return nil, err
}
groupToUpdate, ok := account.Groups[groupID]
if !ok {
return nil, status.Errorf(codes.NotFound, "group %s no longer exists", groupID)
return nil, status.Errorf(status.NotFound, "group with ID %s no longer exists", groupID)
}
group := groupToUpdate.Copy()
@ -130,7 +127,7 @@ func (am *DefaultAccountManager) UpdateGroup(accountID string,
err = am.updateAccountPeers(account)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to update account peers")
return nil, err
}
return group, nil
@ -144,7 +141,7 @@ func (am *DefaultAccountManager) DeleteGroup(accountID, groupID string) error {
account, err := am.Store.GetAccount(accountID)
if err != nil {
return status.Errorf(codes.NotFound, "account not found")
return err
}
delete(account.Groups, groupID)
@ -165,7 +162,7 @@ func (am *DefaultAccountManager) ListGroups(accountID string) ([]*Group, error)
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
return nil, err
}
groups := make([]*Group, 0, len(account.Groups))
@ -184,12 +181,12 @@ func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerKey string
account, err := am.Store.GetAccount(accountID)
if err != nil {
return status.Errorf(codes.NotFound, "account not found")
return err
}
group, ok := account.Groups[groupID]
if !ok {
return status.Errorf(codes.NotFound, "group with ID %s not found", groupID)
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
}
add := true
@ -219,12 +216,12 @@ func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerKey str
account, err := am.Store.GetAccount(accountID)
if err != nil {
return status.Errorf(codes.NotFound, "account not found")
return err
}
group, ok := account.Groups[groupID]
if !ok {
return status.Errorf(codes.NotFound, "group with ID %s not found", groupID)
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
}
account.Network.IncSerial()
@ -232,7 +229,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerKey str
if itemID == peerKey {
group.Peers = append(group.Peers[:i], group.Peers[i+1:]...)
if err := am.Store.SaveAccount(account); err != nil {
return status.Errorf(codes.Internal, "can't save account")
return err
}
}
}
@ -248,12 +245,12 @@ func (am *DefaultAccountManager) GroupListPeers(accountID, groupID string) ([]*P
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
return nil, status.Errorf(status.NotFound, "account not found")
}
group, ok := account.Groups[groupID]
if !ok {
return nil, status.Errorf(codes.NotFound, "group with ID %s not found", groupID)
return nil, status.Errorf(status.NotFound, "group with ID %s not found", groupID)
}
peers := make([]*Peer, 0, len(account.Groups))

View File

@ -14,6 +14,7 @@ import (
"github.com/golang/protobuf/ptypes/timestamp"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/proto"
internalStatus "github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
@ -200,10 +201,6 @@ func (s *GRPCServer) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest)
return nil, status.Errorf(codes.Internal, "invalid jwt token, err: %v", err)
}
claims := jwtclaims.ExtractClaimsWithToken(token, s.config.HttpConfig.AuthAudience)
_, err = s.accountManager.GetAccountFromToken(claims)
if err != nil {
return nil, status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err)
}
userID = claims.UserId
} else {
log.Debugln("using setup key to register peer")
@ -237,14 +234,12 @@ func (s *GRPCServer) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest)
},
})
if err != nil {
if e, ok := FromError(err); ok {
if e, ok := internalStatus.FromError(err); ok {
switch e.Type() {
case PreconditionFailed:
return nil, status.Errorf(codes.FailedPrecondition, e.message)
case AccountNotFound:
case SetupKeyNotFound:
case UserNotFound:
return nil, status.Errorf(codes.NotFound, e.message)
case internalStatus.PreconditionFailed:
return nil, status.Errorf(codes.FailedPrecondition, e.Message)
case internalStatus.NotFound:
return nil, status.Errorf(codes.NotFound, e.Message)
default:
}
}
@ -301,7 +296,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
peer, err := s.accountManager.GetPeer(peerKey.String())
if err != nil {
if errStatus, ok := status.FromError(err); ok && errStatus.Code() == codes.NotFound {
if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound {
// peer doesn't exist -> check if setup key was provided
if loginReq.GetJwtToken() == "" && loginReq.GetSetupKey() == "" {
// absent setup key or jwt -> permission denied
@ -387,7 +382,6 @@ func ToResponseProto(configProto Protocol) proto.HostConfig_Protocol {
case TCP:
return proto.HostConfig_TCP
default:
// mbragin: todo something better?
panic(fmt.Errorf("unexpected config protocol type %v", configProto))
}
}

View File

@ -2,10 +2,9 @@ package http
import (
"encoding/json"
"fmt"
"github.com/netbirdio/netbird/management/server/http/api"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/status"
"net/http"
"github.com/netbirdio/netbird/management/server"
@ -33,7 +32,8 @@ func NewGroups(accountManager server.AccountManager, authAudience string) *Group
// GetAllGroupsHandler list for the account
func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@ -45,52 +45,54 @@ func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) {
groups = append(groups, toGroupResponse(account, g))
}
writeJSONObject(w, groups)
util.WriteJSONObject(w, groups)
}
// UpdateGroupHandler handles update to a group identified by a given ID
func (h *Groups) UpdateGroupHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
groupID, ok := vars["id"]
if !ok {
http.Error(w, "group ID field is missing", http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument, "group ID field is missing"), w)
return
}
if len(groupID) == 0 {
http.Error(w, "group ID can't be empty", http.StatusUnprocessableEntity)
util.WriteError(status.Errorf(status.InvalidArgument, "group ID can't be empty"), w)
return
}
_, ok = account.Groups[groupID]
if !ok {
http.Error(w, fmt.Sprintf("couldn't find group with ID %s", groupID), http.StatusNotFound)
util.WriteError(status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w)
return
}
allGroup, err := account.GetGroupAll()
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
util.WriteError(err, w)
return
}
if allGroup.ID == groupID {
http.Error(w, "updating group ALL is not allowed", http.StatusMethodNotAllowed)
util.WriteError(status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w)
return
}
var req api.PutApiGroupsIdJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
if *req.Name == "" {
http.Error(w, "group name shouldn't be empty", http.StatusUnprocessableEntity)
util.WriteError(status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w)
return
}
@ -102,53 +104,55 @@ func (h *Groups) UpdateGroupHandler(w http.ResponseWriter, r *http.Request) {
if err := h.accountManager.SaveGroup(account.Id, &group); err != nil {
log.Errorf("failed updating group %s under account %s %v", groupID, account.Id, err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
writeJSONObject(w, toGroupResponse(account, &group))
util.WriteJSONObject(w, toGroupResponse(account, &group))
}
// PatchGroupHandler handles patch updates to a group identified by a given ID
func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
groupID := vars["id"]
if len(groupID) == 0 {
http.Error(w, "invalid group Id", http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w)
return
}
_, ok := account.Groups[groupID]
if !ok {
http.Error(w, fmt.Sprintf("couldn't find group id %s", groupID), http.StatusNotFound)
util.WriteError(status.Errorf(status.NotFound, "couldn't find group ID %s", groupID), w)
return
}
allGroup, err := account.GetGroupAll()
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
util.WriteError(err, w)
return
}
if allGroup.ID == groupID {
http.Error(w, "updating group ALL is not allowed", http.StatusMethodNotAllowed)
util.WriteError(status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w)
return
}
var req api.PatchApiGroupsIdJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
if len(req) == 0 {
http.Error(w, "no patch instruction received", http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument, "no patch instruction received"), w)
return
}
@ -158,13 +162,13 @@ func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) {
switch patch.Path {
case api.GroupPatchOperationPathName:
if patch.Op != api.GroupPatchOperationOpReplace {
http.Error(w, fmt.Sprintf("Name field only accepts replace operation, got %s", patch.Op),
http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument,
"name field only accepts replace operation, got %s", patch.Op), w)
return
}
if len(patch.Value) == 0 || patch.Value[0] == "" {
http.Error(w, "Group name shouldn't be empty", http.StatusUnprocessableEntity)
util.WriteError(status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w)
return
}
@ -193,53 +197,43 @@ func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) {
Values: peerKeys,
})
default:
http.Error(w, "invalid operation, \"%s\", for Peers field", http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument,
"invalid operation, \"%v\", for Peers field", patch.Op), w)
return
}
default:
http.Error(w, "invalid patch path", http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid patch path"), w)
return
}
}
group, err := h.accountManager.UpdateGroup(account.Id, groupID, operations)
if err != nil {
errStatus, ok := status.FromError(err)
if ok && errStatus.Code() == codes.Internal {
http.Error(w, errStatus.String(), http.StatusInternalServerError)
util.WriteError(err, w)
return
}
if ok && errStatus.Code() == codes.NotFound {
http.Error(w, errStatus.String(), http.StatusNotFound)
return
}
log.Errorf("failed updating group %s under account %s %v", groupID, account.Id, err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
writeJSONObject(w, toGroupResponse(account, group))
util.WriteJSONObject(w, toGroupResponse(account, group))
}
// CreateGroupHandler handles group creation request
func (h *Groups) CreateGroupHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
var req api.PostApiGroupsJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
if req.Name == "" {
http.Error(w, "Group name shouldn't be empty", http.StatusUnprocessableEntity)
util.WriteError(status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w)
return
}
@ -249,55 +243,57 @@ func (h *Groups) CreateGroupHandler(w http.ResponseWriter, r *http.Request) {
Peers: peerIPsToKeys(account, req.Peers),
}
if err := h.accountManager.SaveGroup(account.Id, &group); err != nil {
log.Errorf("failed creating group \"%s\" under account %s %v", req.Name, account.Id, err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
err = h.accountManager.SaveGroup(account.Id, &group)
if err != nil {
util.WriteError(err, w)
return
}
writeJSONObject(w, toGroupResponse(account, &group))
util.WriteJSONObject(w, toGroupResponse(account, &group))
}
// DeleteGroupHandler handles group deletion request
func (h *Groups) DeleteGroupHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
aID := account.Id
groupID := mux.Vars(r)["id"]
if len(groupID) == 0 {
http.Error(w, "invalid group ID", http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w)
return
}
allGroup, err := account.GetGroupAll()
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
util.WriteError(err, w)
return
}
if allGroup.ID == groupID {
http.Error(w, "deleting group ALL is not allowed", http.StatusMethodNotAllowed)
util.WriteError(status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed"), w)
return
}
if err := h.accountManager.DeleteGroup(aID, groupID); err != nil {
log.Errorf("failed delete group %s under account %s %v", groupID, aID, err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
err = h.accountManager.DeleteGroup(aID, groupID)
if err != nil {
util.WriteError(err, w)
return
}
writeJSONObject(w, "")
util.WriteJSONObject(w, "")
}
// GetGroupHandler returns a group
func (h *Groups) GetGroupHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
@ -305,19 +301,22 @@ func (h *Groups) GetGroupHandler(w http.ResponseWriter, r *http.Request) {
case http.MethodGet:
groupID := mux.Vars(r)["id"]
if len(groupID) == 0 {
http.Error(w, "invalid group ID", http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w)
return
}
group, err := h.accountManager.GetGroup(account.Id, groupID)
if err != nil {
http.Error(w, "group not found", http.StatusNotFound)
util.WriteError(err, w)
return
}
writeJSONObject(w, toGroupResponse(account, group))
util.WriteJSONObject(w, toGroupResponse(account, group))
default:
http.Error(w, "", http.StatusNotFound)
if err != nil {
util.WriteError(status.Errorf(status.NotFound, "HTTP method not found"), w)
return
}
}
}

View File

@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/status"
"io"
"net"
"net/http"
@ -36,7 +37,7 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *Groups {
},
GetGroupFunc: func(_, groupID string) (*server.Group, error) {
if groupID != "idofthegroup" {
return nil, fmt.Errorf("not found")
return nil, status.Errorf(status.NotFound, "not found")
}
return &server.Group{
ID: "idofthegroup",
@ -67,7 +68,7 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *Groups {
}
return nil, fmt.Errorf("peer not found")
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",
@ -78,7 +79,7 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *Groups {
Groups: map[string]*server.Group{
"id-existed": {ID: "id-existed", Peers: []string{"A", "B"}},
"id-all": {ID: "id-all", Name: "All"}},
}, nil
}, user, nil
},
},
authAudience: "",
@ -223,7 +224,7 @@ func TestWriteGroup(t *testing.T) {
requestPath: "/api/groups/id-all",
requestBody: bytes.NewBuffer(
[]byte(`{"Name":"super"}`)),
expectedStatus: http.StatusMethodNotAllowed,
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
@ -244,7 +245,7 @@ func TestWriteGroup(t *testing.T) {
requestPath: "/api/groups/id-existed",
requestBody: bytes.NewBuffer(
[]byte(`[{"op":"insert","path":"name","value":[""]}]`)),
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{

View File

@ -1,7 +1,8 @@
package middleware
import (
"fmt"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/status"
"net/http"
"github.com/netbirdio/netbird/management/server/jwtclaims"
@ -33,14 +34,15 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
ok, err := a.isUserAdmin(jwtClaims)
if err != nil {
http.Error(w, fmt.Sprintf("error get user from JWT: %v", err), http.StatusUnauthorized)
util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w)
return
}
if !ok {
switch r.Method {
case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut:
http.Error(w, "user is not admin", http.StatusForbidden)
util.WriteError(status.Errorf(status.PermissionDenied, "only admin can perform this operation"), w)
return
}
}

View File

@ -66,7 +66,6 @@ func getPemKeys(keysLocation string) (*Jwks, error) {
var jwks = &Jwks{}
err = json.NewDecoder(resp.Body).Decode(jwks)
if err != nil {
return jwks, err
}

View File

@ -5,6 +5,8 @@ import (
"errors"
"fmt"
"github.com/golang-jwt/jwt"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/status"
"log"
"net/http"
"strings"
@ -57,7 +59,7 @@ type JWTMiddleware struct {
}
func OnError(w http.ResponseWriter, r *http.Request, err string) {
http.Error(w, err, http.StatusUnauthorized)
util.WriteError(status.Errorf(status.Unauthorized, ""), w)
}
// New constructs a new Secure instance with supplied options.

View File

@ -7,7 +7,9 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus"
"net/http"
)
@ -30,7 +32,8 @@ func NewNameservers(accountManager server.AccountManager, authAudience string) *
// GetAllNameserversHandler returns the list of nameserver groups for the account
func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@ -39,7 +42,7 @@ func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Re
nsGroups, err := h.accountManager.ListNameServerGroups(account.Id)
if err != nil {
toHTTPError(err, w)
util.WriteError(err, w)
return
}
@ -48,64 +51,67 @@ func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Re
apiNameservers = append(apiNameservers, toNameserverGroupResponse(r))
}
writeJSONObject(w, apiNameservers)
util.WriteJSONObject(w, apiNameservers)
}
// CreateNameserverGroupHandler handles nameserver group creation request
func (h *Nameservers) CreateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
var req api.PostApiDnsNameserversJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
nsList, err := toServerNSList(req.Nameservers)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid NS servers format"), w)
return
}
nsGroup, err := h.accountManager.CreateNameServerGroup(account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled)
if err != nil {
toHTTPError(err, w)
util.WriteError(err, w)
return
}
resp := toNameserverGroupResponse(nsGroup)
writeJSONObject(w, &resp)
util.WriteJSONObject(w, &resp)
}
// UpdateNameserverGroupHandler handles update to a nameserver group identified by a given ID
func (h *Nameservers) UpdateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
nsGroupID := mux.Vars(r)["id"]
if len(nsGroupID) == 0 {
http.Error(w, "invalid nameserver group ID", http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
return
}
var req api.PutApiDnsNameserversIdJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
nsList, err := toServerNSList(req.Nameservers)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid NS servers format"), w)
return
}
@ -122,41 +128,42 @@ func (h *Nameservers) UpdateNameserverGroupHandler(w http.ResponseWriter, r *htt
err = h.accountManager.SaveNameServerGroup(account.Id, updatedNSGroup)
if err != nil {
toHTTPError(err, w)
util.WriteError(err, w)
return
}
resp := toNameserverGroupResponse(updatedNSGroup)
writeJSONObject(w, &resp)
util.WriteJSONObject(w, &resp)
}
// PatchNameserverGroupHandler handles patch updates to a nameserver group identified by a given ID
func (h *Nameservers) PatchNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
nsGroupID := mux.Vars(r)["id"]
if len(nsGroupID) == 0 {
http.Error(w, "invalid nameserver group ID", http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
return
}
var req api.PatchApiDnsNameserversIdJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
var operations []server.NameServerGroupUpdateOperation
for _, patch := range req {
if patch.Op != api.NameserverGroupPatchOperationOpReplace {
http.Error(w, fmt.Sprintf("nameserver groups only accepts replace operations, got %s", patch.Op),
http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument,
"nameserver groups only accepts replace operations, got %s", patch.Op), w)
return
}
switch patch.Path {
@ -196,49 +203,50 @@ func (h *Nameservers) PatchNameserverGroupHandler(w http.ResponseWriter, r *http
Values: patch.Value,
})
default:
http.Error(w, "invalid patch path", http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid patch path"), w)
return
}
}
updatedNSGroup, err := h.accountManager.UpdateNameServerGroup(account.Id, nsGroupID, operations)
if err != nil {
toHTTPError(err, w)
util.WriteError(err, w)
return
}
resp := toNameserverGroupResponse(updatedNSGroup)
writeJSONObject(w, &resp)
util.WriteJSONObject(w, &resp)
}
// DeleteNameserverGroupHandler handles nameserver group deletion request
func (h *Nameservers) DeleteNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
nsGroupID := mux.Vars(r)["id"]
if len(nsGroupID) == 0 {
http.Error(w, "invalid nameserver group ID", http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
return
}
err = h.accountManager.DeleteNameServerGroup(account.Id, nsGroupID)
if err != nil {
toHTTPError(err, w)
util.WriteError(err, w)
return
}
writeJSONObject(w, "")
util.WriteJSONObject(w, "")
}
// GetNameserverGroupHandler handles a nameserver group Get request identified by ID
func (h *Nameservers) GetNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@ -247,19 +255,19 @@ func (h *Nameservers) GetNameserverGroupHandler(w http.ResponseWriter, r *http.R
nsGroupID := mux.Vars(r)["id"]
if len(nsGroupID) == 0 {
http.Error(w, "invalid nameserver group ID", http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
return
}
nsGroup, err := h.accountManager.GetNameServerGroup(account.Id, nsGroupID)
if err != nil {
toHTTPError(err, w)
util.WriteError(err, w)
return
}
resp := toNameserverGroupResponse(nsGroup)
writeJSONObject(w, &resp)
util.WriteJSONObject(w, &resp)
}

View File

@ -5,9 +5,8 @@ import (
"encoding/json"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/status"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"io"
"net/http"
"net/http/httptest"
@ -62,7 +61,7 @@ func initNameserversTestData() *Nameservers {
if nsGroupID == existingNSGroupID {
return baseExistingNSGroup.Copy(), nil
}
return nil, status.Errorf(codes.NotFound, "nameserver group with ID %s not found", nsGroupID)
return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID)
},
CreateNameServerGroupFunc: func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error) {
return &nbdns.NameServerGroup{
@ -83,12 +82,12 @@ func initNameserversTestData() *Nameservers {
if nsGroupToSave.ID == existingNSGroupID {
return nil
}
return status.Errorf(codes.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID)
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID)
},
UpdateNameServerGroupFunc: func(accountID, nsGroupID string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) {
nsGroupToUpdate := baseExistingNSGroup.Copy()
if nsGroupID != nsGroupToUpdate.ID {
return nil, status.Errorf(codes.NotFound, "nameserver group ID %s no longer exists", nsGroupID)
return nil, status.Errorf(status.NotFound, "nameserver group ID %s no longer exists", nsGroupID)
}
for _, operation := range operations {
switch operation.Type {
@ -110,8 +109,8 @@ func initNameserversTestData() *Nameservers {
}
return nsGroupToUpdate, nil
},
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, error) {
return testingNSAccount, nil
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return testingNSAccount, nil, nil
},
},
authAudience: "",
@ -181,7 +180,7 @@ func TestNameserversHandlers(t *testing.T) {
requestPath: "/api/dns/nameservers",
requestBody: bytes.NewBuffer(
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1000\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")),
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
@ -223,7 +222,7 @@ func TestNameserversHandlers(t *testing.T) {
requestPath: "/api/dns/nameservers/" + notFoundNSGroupID,
requestBody: bytes.NewBuffer(
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"100\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")),
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{

View File

@ -6,8 +6,9 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/status"
"net/http"
)
@ -28,50 +29,47 @@ func NewPeers(accountManager server.AccountManager, authAudience string) *Peers
func (h *Peers) updatePeer(account *server.Account, peer *server.Peer, w http.ResponseWriter, r *http.Request) {
req := &api.PutApiPeersIdJSONBody{}
peerIp := peer.IP
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
update := &server.Peer{Key: peer.Key, SSHEnabled: req.SshEnabled, Name: req.Name}
peer, err = h.accountManager.UpdatePeer(account.Id, update)
if err != nil {
log.Errorf("failed updating peer %s under account %s %v", peerIp, account.Id, err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
writeJSONObject(w, toPeerResponse(peer, account))
util.WriteJSONObject(w, toPeerResponse(peer, account))
}
func (h *Peers) deletePeer(accountId string, peer *server.Peer, w http.ResponseWriter, r *http.Request) {
_, err := h.accountManager.DeletePeer(accountId, peer.Key)
if err != nil {
log.Errorf("failed deleteing peer %s, %v", peer.IP, err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
writeJSONObject(w, "")
util.WriteJSONObject(w, "")
}
func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
peerId := vars["id"] //effectively peer IP address
if len(peerId) == 0 {
http.Error(w, "invalid peer Id", http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
return
}
peer, err := h.accountManager.GetPeerByIP(account.Id, peerId)
if err != nil {
http.Error(w, "peer not found", http.StatusNotFound)
util.WriteError(err, w)
return
}
@ -83,11 +81,11 @@ func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
h.updatePeer(account, peer, w, r)
return
case http.MethodGet:
writeJSONObject(w, toPeerResponse(peer, account))
util.WriteJSONObject(w, toPeerResponse(peer, account))
return
default:
http.Error(w, "", http.StatusNotFound)
util.WriteError(status.Errorf(status.NotFound, "unknown METHOD"), w)
}
}
@ -95,15 +93,16 @@ func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
func (h *Peers) GetPeers(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
peers, err := h.accountManager.GetPeers(account.Id, user.Id)
if err != nil {
util.WriteError(err, w)
return
}
@ -111,10 +110,10 @@ func (h *Peers) GetPeers(w http.ResponseWriter, r *http.Request) {
for _, peer := range peers {
respBody = append(respBody, toPeerResponse(peer, account))
}
writeJSONObject(w, respBody)
util.WriteJSONObject(w, respBody)
return
default:
http.Error(w, "", http.StatusNotFound)
util.WriteError(status.Errorf(status.NotFound, "unknown METHOD"), w)
}
}

View File

@ -22,7 +22,8 @@ func initTestMetaData(peers ...*server.Peer) *Peers {
GetPeersFunc: func(accountID, userID string) ([]*server.Peer, error) {
return peers, nil
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
return &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",
@ -30,9 +31,9 @@ func initTestMetaData(peers ...*server.Peer) *Peers {
"test_peer": peers[0],
},
Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"),
"test_user": user,
},
}, nil
}, user, nil
},
},
authAudience: "",

View File

@ -2,15 +2,13 @@ package http
import (
"encoding/json"
"fmt"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/route"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"net/http"
"unicode/utf8"
)
@ -33,25 +31,16 @@ func NewRoutes(accountManager server.AccountManager, authAudience string) *Route
// GetAllRoutesHandler returns the list of routes for the account
func (h *Routes) GetAllRoutesHandler(w http.ResponseWriter, r *http.Request) {
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
routes, err := h.accountManager.ListRoutes(account.Id, user.Id)
if err != nil {
log.Error(err)
if e, ok := server.FromError(err); ok {
switch e.Type() {
case server.PermissionDenied:
http.Error(w, e.Error(), http.StatusForbidden)
return
default:
}
}
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
apiRoutes := make([]*api.Route, 0)
@ -59,20 +48,22 @@ func (h *Routes) GetAllRoutesHandler(w http.ResponseWriter, r *http.Request) {
apiRoutes = append(apiRoutes, toRouteResponse(account, r))
}
writeJSONObject(w, apiRoutes)
util.WriteJSONObject(w, apiRoutes)
}
// CreateRouteHandler handles route creation request
func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
var req api.PostApiRoutesJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
@ -80,8 +71,7 @@ func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) {
if req.Peer != "" {
peer, err := h.accountManager.GetPeerByIP(account.Id, req.Peer)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusUnprocessableEntity)
util.WriteError(err, w)
return
}
peerKey = peer.Key
@ -89,57 +79,60 @@ func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) {
_, newPrefix, err := route.ParseNetwork(req.Network)
if err != nil {
http.Error(w, fmt.Sprintf("couldn't parse update prefix %s", req.Network), http.StatusBadRequest)
util.WriteError(err, w)
return
}
if utf8.RuneCountInString(req.NetworkId) > route.MaxNetIDChar || req.NetworkId == "" {
http.Error(w, fmt.Sprintf("identifier should be between 1 and %d", route.MaxNetIDChar), http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d",
route.MaxNetIDChar), w)
return
}
newRoute, err := h.accountManager.CreateRoute(account.Id, newPrefix.String(), peerKey, req.Description, req.NetworkId, req.Masquerade, req.Metric, req.Enabled)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
resp := toRouteResponse(account, newRoute)
writeJSONObject(w, &resp)
util.WriteJSONObject(w, &resp)
}
// UpdateRouteHandler handles update to a route identified by a given ID
func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
routeID := vars["id"]
if len(routeID) == 0 {
http.Error(w, "invalid route Id", http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
return
}
_, err = h.accountManager.GetRoute(account.Id, routeID, user.Id)
if err != nil {
http.Error(w, fmt.Sprintf("couldn't find route for ID %s", routeID), http.StatusNotFound)
util.WriteError(err, w)
return
}
var req api.PutApiRoutesIdJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
prefixType, newPrefix, err := route.ParseNetwork(req.Network)
if err != nil {
http.Error(w, fmt.Sprintf("couldn't parse update prefix %s for route ID %s", req.Network, routeID), http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument, "couldn't parse update prefix %s for route ID %s",
req.Network, routeID), w)
return
}
@ -147,15 +140,15 @@ func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
if req.Peer != "" {
peer, err := h.accountManager.GetPeerByIP(account.Id, req.Peer)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusUnprocessableEntity)
util.WriteError(err, w)
return
}
peerKey = peer.Key
}
if utf8.RuneCountInString(req.NetworkId) > route.MaxNetIDChar || req.NetworkId == "" {
http.Error(w, fmt.Sprintf("identifier should be between 1 and %d", route.MaxNetIDChar), http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument,
"identifier should be between 1 and %d", route.MaxNetIDChar), w)
return
}
@ -173,46 +166,46 @@ func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
err = h.accountManager.SaveRoute(account.Id, newRoute)
if err != nil {
log.Errorf("failed updating route \"%s\" under account %s %v", routeID, account.Id, err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
resp := toRouteResponse(account, newRoute)
writeJSONObject(w, &resp)
util.WriteJSONObject(w, &resp)
}
// PatchRouteHandler handles patch updates to a route identified by a given ID
func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
routeID := vars["id"]
if len(routeID) == 0 {
http.Error(w, "invalid route ID", http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
return
}
_, err = h.accountManager.GetRoute(account.Id, routeID, user.Id)
if err != nil {
log.Error(err)
http.Error(w, fmt.Sprintf("couldn't find route ID %s", routeID), http.StatusNotFound)
util.WriteError(err, w)
return
}
var req api.PatchApiRoutesIdJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
if len(req) == 0 {
http.Error(w, "no patch instruction received", http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument, "no patch instruction received"), w)
return
}
@ -222,8 +215,8 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
switch patch.Path {
case api.RoutePatchOperationPathNetwork:
if patch.Op != api.RoutePatchOperationOpReplace {
http.Error(w, fmt.Sprintf("Network field only accepts replace operation, got %s", patch.Op),
http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument,
"network field only accepts replace operation, got %s", patch.Op), w)
return
}
operations = append(operations, server.RouteUpdateOperation{
@ -232,8 +225,8 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
})
case api.RoutePatchOperationPathDescription:
if patch.Op != api.RoutePatchOperationOpReplace {
http.Error(w, fmt.Sprintf("Description field only accepts replace operation, got %s", patch.Op),
http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument,
"description field only accepts replace operation, got %s", patch.Op), w)
return
}
operations = append(operations, server.RouteUpdateOperation{
@ -242,8 +235,8 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
})
case api.RoutePatchOperationPathNetworkId:
if patch.Op != api.RoutePatchOperationOpReplace {
http.Error(w, fmt.Sprintf("Network Identifier field only accepts replace operation, got %s", patch.Op),
http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument,
"network Identifier field only accepts replace operation, got %s", patch.Op), w)
return
}
operations = append(operations, server.RouteUpdateOperation{
@ -252,21 +245,20 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
})
case api.RoutePatchOperationPathPeer:
if patch.Op != api.RoutePatchOperationOpReplace {
http.Error(w, fmt.Sprintf("Peer field only accepts replace operation, got %s", patch.Op),
http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument,
"peer field only accepts replace operation, got %s", patch.Op), w)
return
}
if len(patch.Value) > 1 {
http.Error(w, fmt.Sprintf("Value field only accepts 1 value, got %d", len(patch.Value)),
http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument,
"value field only accepts 1 value, got %d", len(patch.Value)), w)
return
}
peerValue := patch.Value
if patch.Value[0] != "" {
peer, err := h.accountManager.GetPeerByIP(account.Id, patch.Value[0])
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusUnprocessableEntity)
util.WriteError(err, w)
return
}
peerValue = []string{peer.Key}
@ -277,8 +269,9 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
})
case api.RoutePatchOperationPathMetric:
if patch.Op != api.RoutePatchOperationOpReplace {
http.Error(w, fmt.Sprintf("Metric field only accepts replace operation, got %s", patch.Op),
http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument,
"metric field only accepts replace operation, got %s", patch.Op), w)
return
}
operations = append(operations, server.RouteUpdateOperation{
@ -287,8 +280,8 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
})
case api.RoutePatchOperationPathMasquerade:
if patch.Op != api.RoutePatchOperationOpReplace {
http.Error(w, fmt.Sprintf("Masquerade field only accepts replace operation, got %s", patch.Op),
http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument,
"masquerade field only accepts replace operation, got %s", patch.Op), w)
return
}
operations = append(operations, server.RouteUpdateOperation{
@ -297,8 +290,8 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
})
case api.RoutePatchOperationPathEnabled:
if patch.Op != api.RoutePatchOperationOpReplace {
http.Error(w, fmt.Sprintf("Enabled field only accepts replace operation, got %s", patch.Op),
http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument,
"enabled field only accepts replace operation, got %s", patch.Op), w)
return
}
operations = append(operations, server.RouteUpdateOperation{
@ -306,90 +299,68 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
Values: patch.Value,
})
default:
http.Error(w, "invalid patch path", http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid patch path"), w)
return
}
}
route, err := h.accountManager.UpdateRoute(account.Id, routeID, operations)
if err != nil {
errStatus, ok := status.FromError(err)
if ok && errStatus.Code() == codes.Internal {
http.Error(w, errStatus.String(), http.StatusInternalServerError)
return
}
if ok && errStatus.Code() == codes.NotFound {
http.Error(w, errStatus.String(), http.StatusNotFound)
return
}
if ok && errStatus.Code() == codes.InvalidArgument {
http.Error(w, errStatus.String(), http.StatusBadRequest)
return
}
log.Errorf("failed updating route %s under account %s %v", routeID, account.Id, err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
resp := toRouteResponse(account, route)
writeJSONObject(w, &resp)
util.WriteJSONObject(w, &resp)
}
// DeleteRouteHandler handles route deletion request
func (h *Routes) DeleteRouteHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
routeID := mux.Vars(r)["id"]
if len(routeID) == 0 {
http.Error(w, "invalid route ID", http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
return
}
err = h.accountManager.DeleteRoute(account.Id, routeID)
if err != nil {
errStatus, ok := status.FromError(err)
if ok && errStatus.Code() == codes.NotFound {
http.Error(w, fmt.Sprintf("route %s not found under account %s", routeID, account.Id), http.StatusNotFound)
return
}
log.Errorf("failed delete route %s under account %s %v", routeID, account.Id, err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
writeJSONObject(w, "")
util.WriteJSONObject(w, "")
}
// GetRouteHandler handles a route Get request identified by ID
func (h *Routes) GetRouteHandler(w http.ResponseWriter, r *http.Request) {
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
routeID := mux.Vars(r)["id"]
if len(routeID) == 0 {
http.Error(w, "invalid route ID", http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
return
}
foundRoute, err := h.accountManager.GetRoute(account.Id, routeID, user.Id)
if err != nil {
http.Error(w, "route not found", http.StatusNotFound)
util.WriteError(status.Errorf(status.NotFound, "route not found"), w)
return
}
writeJSONObject(w, toRouteResponse(account, foundRoute))
util.WriteJSONObject(w, toRouteResponse(account, foundRoute))
}
func toRouteResponse(account *server.Account, serverRoute *route.Route) *api.Route {

View File

@ -5,9 +5,8 @@ import (
"encoding/json"
"fmt"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/route"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"io"
"net/http"
"net/http/httptest"
@ -63,7 +62,7 @@ func initRoutesTestData() *Routes {
if routeID == existingRouteID {
return baseExistingRoute, nil
}
return nil, status.Errorf(codes.NotFound, "route with ID %s not found", routeID)
return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
},
CreateRouteFunc: func(accountID string, network, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error) {
networkType, p, _ := route.ParseNetwork(network)
@ -83,13 +82,13 @@ func initRoutesTestData() *Routes {
},
DeleteRouteFunc: func(_ string, peerIP string) error {
if peerIP != existingRouteID {
return status.Errorf(codes.NotFound, "Peer with ID %s not found", peerIP)
return status.Errorf(status.NotFound, "Peer with ID %s not found", peerIP)
}
return nil
},
GetPeerByIPFunc: func(_ string, peerIP string) (*server.Peer, error) {
if peerIP != existingPeerID {
return nil, status.Errorf(codes.NotFound, "Peer with ID %s not found", peerIP)
return nil, status.Errorf(status.NotFound, "Peer with ID %s not found", peerIP)
}
return &server.Peer{
Key: existingPeerKey,
@ -99,7 +98,7 @@ func initRoutesTestData() *Routes {
UpdateRouteFunc: func(_ string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error) {
routeToUpdate := baseExistingRoute
if routeID != routeToUpdate.ID {
return nil, status.Errorf(codes.NotFound, "route %s no longer exists", routeID)
return nil, status.Errorf(status.NotFound, "route %s no longer exists", routeID)
}
for _, operation := range operations {
switch operation.Type {
@ -123,8 +122,8 @@ func initRoutesTestData() *Routes {
}
return routeToUpdate, nil
},
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, error) {
return testingAccount, nil
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return testingAccount, testingAccount.Users["test_user"], nil
},
},
authAudience: "",
@ -201,15 +200,15 @@ func TestRoutesHandlers(t *testing.T) {
requestType: http.MethodPost,
requestPath: "/api/routes",
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\"}", notFoundPeerID)),
expectedStatus: http.StatusUnprocessableEntity,
expectedStatus: http.StatusNotFound,
expectedBody: false,
},
{
name: "POST Not Invalid Network Identifier",
name: "POST Invalid Network Identifier",
requestType: http.MethodPost,
requestPath: "/api/routes",
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"12345678901234567890qwertyuiopqwertyuiop1\",\"Peer\":\"%s\"}", existingPeerID)),
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
@ -217,7 +216,7 @@ func TestRoutesHandlers(t *testing.T) {
requestType: http.MethodPost,
requestPath: "/api/routes",
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/34\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\"}", existingPeerID)),
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
@ -251,7 +250,7 @@ func TestRoutesHandlers(t *testing.T) {
requestType: http.MethodPut,
requestPath: "/api/routes/" + existingRouteID,
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\"}", notFoundPeerID)),
expectedStatus: http.StatusUnprocessableEntity,
expectedStatus: http.StatusNotFound,
expectedBody: false,
},
{
@ -259,7 +258,7 @@ func TestRoutesHandlers(t *testing.T) {
requestType: http.MethodPut,
requestPath: "/api/routes/" + existingRouteID,
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"12345678901234567890qwertyuiopqwertyuiop1\",\"Peer\":\"%s\"}", existingPeerID)),
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
@ -267,7 +266,7 @@ func TestRoutesHandlers(t *testing.T) {
requestType: http.MethodPut,
requestPath: "/api/routes/" + existingRouteID,
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/34\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\"}", existingPeerID)),
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
@ -312,7 +311,7 @@ func TestRoutesHandlers(t *testing.T) {
requestType: http.MethodPatch,
requestPath: "/api/routes/" + existingRouteID,
requestBody: bytes.NewBufferString(fmt.Sprintf("[{\"op\":\"replace\",\"path\":\"peer\",\"value\":[\"%s\"]}]", notFoundPeerID)),
expectedStatus: http.StatusUnprocessableEntity,
expectedStatus: http.StatusNotFound,
expectedBody: false,
},
{

View File

@ -2,15 +2,13 @@ package http
import (
"encoding/json"
"fmt"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"net/http"
)
@ -31,25 +29,16 @@ func NewRules(accountManager server.AccountManager, authAudience string) *Rules
// GetAllRulesHandler list for the account
func (h *Rules) GetAllRulesHandler(w http.ResponseWriter, r *http.Request) {
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
accountRules, err := h.accountManager.ListRules(account.Id, user.Id)
if err != nil {
log.Error(err)
if e, ok := server.FromError(err); ok {
switch e.Type() {
case server.PermissionDenied:
http.Error(w, e.Error(), http.StatusForbidden)
return
default:
}
}
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
rules := []*api.Rule{}
@ -57,38 +46,39 @@ func (h *Rules) GetAllRulesHandler(w http.ResponseWriter, r *http.Request) {
rules = append(rules, toRuleResponse(account, r))
}
writeJSONObject(w, rules)
util.WriteJSONObject(w, rules)
}
// UpdateRuleHandler handles update to a rule identified by a given ID
func (h *Rules) UpdateRuleHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
ruleID := vars["id"]
if len(ruleID) == 0 {
http.Error(w, "invalid rule Id", http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid rule ID"), w)
return
}
_, ok := account.Rules[ruleID]
if !ok {
http.Error(w, fmt.Sprintf("couldn't find rule id %s", ruleID), http.StatusNotFound)
util.WriteError(status.Errorf(status.NotFound, "couldn't find rule id %s", ruleID), w)
return
}
var req api.PutApiRulesIdJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
}
if req.Name == "" {
http.Error(w, "Rule name shouldn't be empty", http.StatusUnprocessableEntity)
util.WriteError(status.Errorf(status.InvalidArgument, "rule name shouldn't be empty"), w)
return
}
@ -115,50 +105,52 @@ func (h *Rules) UpdateRuleHandler(w http.ResponseWriter, r *http.Request) {
case server.TrafficFlowBidirectString:
rule.Flow = server.TrafficFlowBidirect
default:
http.Error(w, "unknown flow type", http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument, "unknown flow type"), w)
return
}
if err := h.accountManager.SaveRule(account.Id, &rule); err != nil {
log.Errorf("failed updating rule \"%s\" under account %s %v", ruleID, account.Id, err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
err = h.accountManager.SaveRule(account.Id, &rule)
if err != nil {
util.WriteError(err, w)
return
}
resp := toRuleResponse(account, &rule)
writeJSONObject(w, &resp)
util.WriteJSONObject(w, &resp)
}
// PatchRuleHandler handles patch updates to a rule identified by a given ID
func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
ruleID := vars["id"]
if len(ruleID) == 0 {
http.Error(w, "invalid rule Id", http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid rule ID"), w)
return
}
_, ok := account.Rules[ruleID]
if !ok {
http.Error(w, fmt.Sprintf("couldn't find rule id %s", ruleID), http.StatusNotFound)
util.WriteError(status.Errorf(status.NotFound, "couldn't find rule ID %s", ruleID), w)
return
}
var req api.PatchApiRulesIdJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
if len(req) == 0 {
http.Error(w, "no patch instruction received", http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument, "no patch instruction received"), w)
return
}
@ -168,12 +160,12 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
switch patch.Path {
case api.RulePatchOperationPathName:
if patch.Op != api.RulePatchOperationOpReplace {
http.Error(w, fmt.Sprintf("Name field only accepts replace operation, got %s", patch.Op),
http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument,
"name field only accepts replace operation, got %s", patch.Op), w)
return
}
if len(patch.Value) == 0 || patch.Value[0] == "" {
http.Error(w, "Rule name shouldn't be empty", http.StatusUnprocessableEntity)
util.WriteError(status.Errorf(status.InvalidArgument, "rule name shouldn't be empty"), w)
return
}
operations = append(operations, server.RuleUpdateOperation{
@ -182,8 +174,8 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
})
case api.RulePatchOperationPathDescription:
if patch.Op != api.RulePatchOperationOpReplace {
http.Error(w, fmt.Sprintf("Description field only accepts replace operation, got %s", patch.Op),
http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument,
"description field only accepts replace operation, got %s", patch.Op), w)
return
}
operations = append(operations, server.RuleUpdateOperation{
@ -192,8 +184,8 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
})
case api.RulePatchOperationPathFlow:
if patch.Op != api.RulePatchOperationOpReplace {
http.Error(w, fmt.Sprintf("Flow field only accepts replace operation, got %s", patch.Op),
http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument,
"flow field only accepts replace operation, got %s", patch.Op), w)
return
}
operations = append(operations, server.RuleUpdateOperation{
@ -202,8 +194,8 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
})
case api.RulePatchOperationPathDisabled:
if patch.Op != api.RulePatchOperationOpReplace {
http.Error(w, fmt.Sprintf("Disabled field only accepts replace operation, got %s", patch.Op),
http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument,
"disabled field only accepts replace operation, got %s", patch.Op), w)
return
}
operations = append(operations, server.RuleUpdateOperation{
@ -228,7 +220,8 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
Values: patch.Value,
})
default:
http.Error(w, "invalid operation, \"%s\", for Source field", http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument,
"invalid operation \"%s\" on Source field", patch.Op), w)
return
}
case api.RulePatchOperationPathDestinations:
@ -249,11 +242,12 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
Values: patch.Value,
})
default:
http.Error(w, "invalid operation, \"%s\", for Destination field", http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument,
"invalid operation \"%s\" on Destination field", patch.Op), w)
return
}
default:
http.Error(w, "invalid patch path", http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid patch path"), w)
return
}
}
@ -261,48 +255,33 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
rule, err := h.accountManager.UpdateRule(account.Id, ruleID, operations)
if err != nil {
errStatus, ok := status.FromError(err)
if ok && errStatus.Code() == codes.Internal {
http.Error(w, errStatus.String(), http.StatusInternalServerError)
return
}
if ok && errStatus.Code() == codes.NotFound {
http.Error(w, errStatus.String(), http.StatusNotFound)
return
}
if ok && errStatus.Code() == codes.InvalidArgument {
http.Error(w, errStatus.String(), http.StatusBadRequest)
return
}
log.Errorf("failed updating rule %s under account %s %v", ruleID, account.Id, err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
resp := toRuleResponse(account, rule)
writeJSONObject(w, &resp)
util.WriteJSONObject(w, &resp)
}
// CreateRuleHandler handles rule creation request
func (h *Rules) CreateRuleHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
var req api.PostApiRulesJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
if req.Name == "" {
http.Error(w, "Rule name shouldn't be empty", http.StatusUnprocessableEntity)
util.WriteError(status.Errorf(status.InvalidArgument, "rule name shouldn't be empty"), w)
return
}
@ -329,50 +308,52 @@ func (h *Rules) CreateRuleHandler(w http.ResponseWriter, r *http.Request) {
case server.TrafficFlowBidirectString:
rule.Flow = server.TrafficFlowBidirect
default:
http.Error(w, "unknown flow type", http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument, "unknown flow type"), w)
return
}
if err := h.accountManager.SaveRule(account.Id, &rule); err != nil {
log.Errorf("failed creating rule \"%s\" under account %s %v", req.Name, account.Id, err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
err = h.accountManager.SaveRule(account.Id, &rule)
if err != nil {
util.WriteError(err, w)
return
}
resp := toRuleResponse(account, &rule)
writeJSONObject(w, &resp)
util.WriteJSONObject(w, &resp)
}
// DeleteRuleHandler handles rule deletion request
func (h *Rules) DeleteRuleHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
aID := account.Id
rID := mux.Vars(r)["id"]
if len(rID) == 0 {
http.Error(w, "invalid rule ID", http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid rule ID"), w)
return
}
if err := h.accountManager.DeleteRule(aID, rID); err != nil {
log.Errorf("failed delete rule %s under account %s %v", rID, aID, err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
err = h.accountManager.DeleteRule(aID, rID)
if err != nil {
util.WriteError(err, w)
return
}
writeJSONObject(w, "")
util.WriteJSONObject(w, "")
}
// GetRuleHandler handles a group Get request identified by ID
func (h *Rules) GetRuleHandler(w http.ResponseWriter, r *http.Request) {
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
@ -380,19 +361,19 @@ func (h *Rules) GetRuleHandler(w http.ResponseWriter, r *http.Request) {
case http.MethodGet:
ruleID := mux.Vars(r)["id"]
if len(ruleID) == 0 {
http.Error(w, "invalid rule ID", http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid rule ID"), w)
return
}
rule, err := h.accountManager.GetRule(account.Id, ruleID, user.Id)
if err != nil {
http.Error(w, "rule not found", http.StatusNotFound)
util.WriteError(status.Errorf(status.NotFound, "rule not found"), w)
return
}
writeJSONObject(w, toRuleResponse(account, rule))
util.WriteJSONObject(w, toRuleResponse(account, rule))
default:
http.Error(w, "", http.StatusNotFound)
util.WriteError(status.Errorf(status.NotFound, "method not found"), w)
}
}

View File

@ -66,7 +66,8 @@ func initRulesTestData(rules ...*server.Rule) *Rules {
}
return &rule, nil
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
return &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",
@ -76,9 +77,9 @@ func initRulesTestData(rules ...*server.Rule) *Rules {
"G": {ID: "G"},
},
Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"),
"test_user": user,
},
}, nil
}, user, nil
},
},
authAudience: "",
@ -238,7 +239,7 @@ func TestRulesWriteRule(t *testing.T) {
requestPath: "/api/rules/id-existed",
requestBody: bytes.NewBuffer(
[]byte(`[{"op":"insert","path":"name","value":[""]}]`)),
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{

View File

@ -2,14 +2,12 @@ package http
import (
"encoding/json"
"fmt"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/management/server/status"
"net/http"
"time"
)
@ -31,29 +29,28 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authAudience stri
// CreateSetupKeyHandler is a POST requests that creates a new SetupKey
func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
req := &api.PostApiSetupKeysJSONRequestBody{}
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
if req.Name == "" {
http.Error(w, "Setup key name shouldn't be empty", http.StatusUnprocessableEntity)
util.WriteError(status.Errorf(status.InvalidArgument, "setup key name shouldn't be empty"), w)
return
}
if !(server.SetupKeyType(req.Type) == server.SetupKeyReusable ||
server.SetupKeyType(req.Type) == server.SetupKeyOneOff) {
http.Error(w, "unknown setup key type "+string(req.Type), http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument, "unknown setup key type %s", string(req.Type)), w)
return
}
@ -62,17 +59,11 @@ func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request
if req.AutoGroups == nil {
req.AutoGroups = []string{}
}
// newExpiresIn := time.Duration(req.ExpiresIn) * time.Second
// newKey.ExpiresAt = time.Now().Add(newExpiresIn)
setupKey, err := h.accountManager.CreateSetupKey(account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn,
req.AutoGroups)
if err != nil {
errStatus, ok := status.FromError(err)
if ok && errStatus.Code() == codes.NotFound {
http.Error(w, "account not found", http.StatusNotFound)
return
}
http.Error(w, "failed adding setup key", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
@ -81,29 +72,23 @@ func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request
// GetSetupKeyHandler is a GET request to get a SetupKey by ID
func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
keyID := vars["id"]
if len(keyID) == 0 {
http.Error(w, "invalid key Id", http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid key ID"), w)
return
}
key, err := h.accountManager.GetSetupKey(account.Id, user.Id, keyID)
if err != nil {
errStatus, ok := status.FromError(err)
if ok && errStatus.Code() == codes.NotFound {
http.Error(w, fmt.Sprintf("setup key %s not found under account %s", keyID, account.Id), http.StatusNotFound)
return
}
log.Errorf("failed getting setup key %s under account %s %v", keyID, account.Id, err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
@ -112,34 +97,34 @@ func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
// UpdateSetupKeyHandler is a PUT request to update server.SetupKey
func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
keyID := vars["id"]
if len(keyID) == 0 {
http.Error(w, "invalid key Id", http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid key ID"), w)
return
}
req := &api.PutApiSetupKeysIdJSONRequestBody{}
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
if req.Name == "" {
http.Error(w, fmt.Sprintf("setup key name field is invalid: %s", req.Name), http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument, "setup key name field is invalid: %s", req.Name), w)
return
}
if req.AutoGroups == nil {
http.Error(w, fmt.Sprintf("setup key AutoGroups field is invalid: %s", req.AutoGroups), http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument, "setup key AutoGroups field is invalid"), w)
return
}
@ -150,16 +135,8 @@ func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request
newKey.Id = keyID
newKey, err = h.accountManager.SaveSetupKey(account.Id, newKey)
if err != nil {
if e, ok := status.FromError(err); ok {
switch e.Code() {
case codes.NotFound:
http.Error(w, fmt.Sprintf("couldn't find setup key for ID %s", keyID), http.StatusNotFound)
default:
http.Error(w, "failed updating setup key", http.StatusInternalServerError)
}
}
util.WriteError(err, w)
return
}
writeSuccess(w, newKey)
@ -168,25 +145,25 @@ func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request
// GetAllSetupKeysHandler is a GET request that returns a list of SetupKey
func (h *SetupKeys) GetAllSetupKeysHandler(w http.ResponseWriter, r *http.Request) {
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
setupKeys, err := h.accountManager.ListSetupKeys(account.Id, user.Id)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
apiSetupKeys := make([]*api.SetupKey, 0)
for _, key := range setupKeys {
apiSetupKeys = append(apiSetupKeys, toResponseBody(key))
}
writeJSONObject(w, apiSetupKeys)
util.WriteJSONObject(w, apiSetupKeys)
}
func writeSuccess(w http.ResponseWriter, key *server.SetupKey) {
@ -194,7 +171,7 @@ func writeSuccess(w http.ResponseWriter, key *server.SetupKey) {
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(toResponseBody(key))
if err != nil {
http.Error(w, "failed handling request", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
}

View File

@ -6,9 +6,8 @@ import (
"fmt"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/status"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"io"
"net/http"
"net/http/httptest"
@ -32,7 +31,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
user *server.User) *SetupKeys {
return &SetupKeys{
accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return &server.Account{
Id: testAccountID,
Domain: "hotmail.com",
@ -45,7 +44,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
Groups: map[string]*server.Group{
"group-1": {ID: "group-1", Peers: []string{"A", "B"}},
"id-all": {ID: "id-all", Name: "All"}},
}, nil
}, user, nil
},
CreateSetupKeyFunc: func(_ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string) (*server.SetupKey, error) {
if keyName == newKey.Name || typ != newKey.Type {
@ -60,7 +59,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
case newKey.Id:
return newKey, nil
default:
return nil, status.Errorf(codes.NotFound, "key %s not found", keyID)
return nil, status.Errorf(status.NotFound, "key %s not found", keyID)
}
},
@ -68,7 +67,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
if key.Id == updatedSetupKey.Id {
return updatedSetupKey, nil
}
return nil, status.Errorf(codes.NotFound, "key %s not found", key.Id)
return nil, status.Errorf(status.NotFound, "key %s not found", key.Id)
},
ListSetupKeysFunc: func(accountID, userID string) ([]*server.SetupKey, error) {

View File

@ -2,12 +2,10 @@ package http
import (
"encoding/json"
"fmt"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/http/api"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/status"
"net/http"
"github.com/netbirdio/netbird/management/server"
@ -31,33 +29,34 @@ func NewUserHandler(accountManager server.AccountManager, authAudience string) *
// UpdateUser is a PUT requests to update User data
func (h *UserHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPut {
http.Error(w, "", http.StatusBadRequest)
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
userID := vars["id"]
if len(userID) == 0 {
http.Error(w, "invalid user ID", http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}
req := &api.PutApiUsersIdJSONRequestBody{}
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
userRole := server.StrRoleToUserRole(req.Role)
if userRole == server.UserRoleUnknown {
http.Error(w, "invalid user role", http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user role"), w)
return
}
@ -67,40 +66,36 @@ func (h *UserHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
AutoGroups: req.AutoGroups,
})
if err != nil {
if e, ok := status.FromError(err); ok {
switch e.Code() {
case codes.NotFound:
http.Error(w, fmt.Sprintf("couldn't find a user for ID %s", userID), http.StatusNotFound)
default:
http.Error(w, "failed to update user", http.StatusInternalServerError)
}
}
util.WriteError(err, w)
return
}
writeJSONObject(w, toUserResponse(newUser))
util.WriteJSONObject(w, toUserResponse(newUser))
}
// CreateUserHandler creates a User in the system with a status "invited" (effectively this is a user invite).
func (h *UserHandler) CreateUserHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "", http.StatusNotFound)
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
log.Error(err)
util.WriteError(err, w)
return
}
req := &api.PostApiUsersJSONRequestBody{}
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
if server.StrRoleToUserRole(req.Role) == server.UserRoleUnknown {
http.Error(w, "unknown user role "+req.Role, http.StatusBadRequest)
util.WriteError(status.Errorf(status.InvalidArgument, "unknown user role %s", req.Role), w)
return
}
@ -111,37 +106,30 @@ func (h *UserHandler) CreateUserHandler(w http.ResponseWriter, r *http.Request)
AutoGroups: req.AutoGroups,
})
if err != nil {
if e, ok := server.FromError(err); ok {
switch e.Type() {
case server.UserAlreadyExists:
http.Error(w, "You can't invite users with an existing NetBird account.", http.StatusPreconditionFailed)
return
default:
}
}
http.Error(w, "failed to invite", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
writeJSONObject(w, toUserResponse(newUser))
util.WriteJSONObject(w, toUserResponse(newUser))
}
// GetUsers returns a list of users of the account this user belongs to.
// It also gathers additional user data (like email and name) from the IDP manager.
func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "", http.StatusBadRequest)
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
data, err := h.accountManager.GetUsersFromAccount(account.Id, user.Id)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
util.WriteError(err, w)
return
}
@ -150,7 +138,7 @@ func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) {
users = append(users, toUserResponse(r))
}
writeJSONObject(w, users)
util.WriteJSONObject(w, users)
}
func toUserResponse(user *server.UserInfo) *api.User {

View File

@ -16,7 +16,7 @@ import (
func initUsers(user ...*server.User) *UserHandler {
return &UserHandler{
accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
users := make(map[string]*server.User, 0)
for _, u := range user {
users[u.Id] = u
@ -25,7 +25,7 @@ func initUsers(user ...*server.User) *UserHandler {
Id: "12345",
Domain: "netbird.io",
Users: users,
}, nil
}, users[claims.UserId], nil
},
GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) {
users := make([]*server.UserInfo, 0)
@ -66,7 +66,6 @@ func TestGetUsers(t *testing.T) {
expectedResult []*server.User
}{
{name: "GetAllUsers", requestType: http.MethodGet, requestPath: "/api/users/", expectedStatus: http.StatusOK, expectedResult: users},
{name: "WrongRequestMethod", requestType: http.MethodPost, requestPath: "/api/users/", expectedStatus: http.StatusBadRequest},
}
for _, tc := range tt {

View File

@ -1,97 +0,0 @@
package http
import (
"encoding/json"
"errors"
"fmt"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/jwtclaims"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"net/http"
"time"
)
// writeJSONObject simply writes object to the HTTP reponse in JSON format
func writeJSONObject(w http.ResponseWriter, obj interface{}) {
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
err := json.NewEncoder(w).Encode(obj)
if err != nil {
http.Error(w, "failed handling request", http.StatusInternalServerError)
return
}
}
// Duration is used strictly for JSON requests/responses due to duration marshalling issues
type Duration struct {
time.Duration
}
func (d Duration) MarshalJSON() ([]byte, error) {
return json.Marshal(d.String())
}
func (d *Duration) UnmarshalJSON(b []byte) error {
var v interface{}
if err := json.Unmarshal(b, &v); err != nil {
return err
}
switch value := v.(type) {
case float64:
d.Duration = time.Duration(value)
return nil
case string:
var err error
d.Duration, err = time.ParseDuration(value)
if err != nil {
return err
}
return nil
default:
return errors.New("invalid duration")
}
}
func getJWTAccount(accountManager server.AccountManager,
jwtExtractor jwtclaims.ClaimsExtractor,
authAudience string, r *http.Request) (*server.Account, *server.User, error) {
claims := jwtExtractor.ExtractClaimsFromRequestContext(r, authAudience)
account, err := accountManager.GetAccountFromToken(claims)
if err != nil {
return nil, nil, fmt.Errorf("failed getting account of a user %s: %v", claims.UserId, err)
}
user := account.Users[claims.UserId]
if user == nil {
// this is not really possible because we got an account by user ID
return nil, nil, fmt.Errorf("user %s not found", claims.UserId)
}
return account, user, nil
}
func toHTTPError(err error, w http.ResponseWriter) {
errStatus, ok := status.FromError(err)
if ok && errStatus.Code() == codes.Internal {
http.Error(w, errStatus.String(), http.StatusInternalServerError)
return
}
if ok && errStatus.Code() == codes.NotFound {
http.Error(w, errStatus.String(), http.StatusNotFound)
return
}
if ok && errStatus.Code() == codes.InvalidArgument {
http.Error(w, errStatus.String(), http.StatusBadRequest)
return
}
unhandledMSG := fmt.Sprintf("got unhandled error code, error: %s", errStatus.String())
log.Error(unhandledMSG)
http.Error(w, unhandledMSG, http.StatusInternalServerError)
}

View File

@ -0,0 +1,105 @@
package util
import (
"encoding/json"
"errors"
"fmt"
"github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus"
"net/http"
"time"
)
// WriteJSONObject simply writes object to the HTTP reponse in JSON format
func WriteJSONObject(w http.ResponseWriter, obj interface{}) {
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
err := json.NewEncoder(w).Encode(obj)
if err != nil {
WriteError(err, w)
return
}
}
// Duration is used strictly for JSON requests/responses due to duration marshalling issues
type Duration struct {
time.Duration
}
// MarshalJSON marshals the duration
func (d Duration) MarshalJSON() ([]byte, error) {
return json.Marshal(d.String())
}
// UnmarshalJSON unmarshals the duration
func (d *Duration) UnmarshalJSON(b []byte) error {
var v interface{}
if err := json.Unmarshal(b, &v); err != nil {
return err
}
switch value := v.(type) {
case float64:
d.Duration = time.Duration(value)
return nil
case string:
var err error
d.Duration, err = time.ParseDuration(value)
if err != nil {
return err
}
return nil
default:
return errors.New("invalid duration")
}
}
// WriteErrorResponse prepares and writes an error response i nJSON
func WriteErrorResponse(errMsg string, httpStatus int, w http.ResponseWriter) {
type errorResponse struct {
Message string `json:"message"`
Code int `json:"code"`
}
w.WriteHeader(httpStatus)
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
err := json.NewEncoder(w).Encode(&errorResponse{
Message: errMsg,
Code: httpStatus,
})
if err != nil {
http.Error(w, "failed handling request", http.StatusInternalServerError)
}
}
// WriteError converts an error to an JSON error response.
// If it is known internal error of type server.Error then it sets the messages from the error, a generic message otherwise
func WriteError(err error, w http.ResponseWriter) {
errStatus, ok := status.FromError(err)
httpStatus := http.StatusInternalServerError
msg := "internal server error"
if ok {
switch errStatus.Type() {
case status.UserAlreadyExists:
httpStatus = http.StatusConflict
case status.AlreadyExists:
httpStatus = http.StatusConflict
case status.PreconditionFailed:
httpStatus = http.StatusPreconditionFailed
case status.PermissionDenied:
httpStatus = http.StatusForbidden
case status.NotFound:
httpStatus = http.StatusNotFound
case status.Internal:
httpStatus = http.StatusInternalServerError
case status.InvalidArgument:
httpStatus = http.StatusUnprocessableEntity
default:
}
msg = err.Error()
} else {
unhandledMSG := fmt.Sprintf("got unhandled error code, error: %s", err.Error())
log.Error(unhandledMSG)
}
WriteErrorResponse(msg, httpStatus, w)
}

View File

@ -59,7 +59,7 @@ type MockAccountManager struct {
DeleteNameServerGroupFunc func(accountID, nsGroupID string) error
ListNameServerGroupsFunc func(accountID string) ([]*nbdns.NameServerGroup, error)
CreateUserFunc func(accountID string, key *server.UserInfo) (*server.UserInfo, error)
GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, error)
GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)
}
// GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface
@ -113,7 +113,7 @@ func (am *MockAccountManager) CreateSetupKey(
return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented")
}
// GetAccountByUserOrAccountId mock implementation of GetAccountByUserOrAccountId from server.AccountManager interface
// GetAccountByUserOrAccountID mock implementation of GetAccountByUserOrAccountID from server.AccountManager interface
func (am *MockAccountManager) GetAccountByUserOrAccountID(
userId, accountId, domain string,
) (*server.Account, error) {
@ -462,11 +462,12 @@ func (am *MockAccountManager) CreateUser(accountID string, invite *server.UserIn
}
// GetAccountFromToken mocks GetAccountFromToken of the AccountManager interface
func (am *MockAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
func (am *MockAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User,
error) {
if am.GetAccountFromTokenFunc != nil {
return am.GetAccountFromTokenFunc(claims)
}
return nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented")
return nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented")
}
// GetPeers mocks GetPeers of the AccountManager interface

View File

@ -3,10 +3,9 @@ package server
import (
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/status"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"strconv"
"unicode/utf8"
)
@ -66,7 +65,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string)
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
return nil, err
}
nsGroup, found := account.NameServerGroups[nsGroupID]
@ -74,7 +73,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string)
return nsGroup.Copy(), nil
}
return nil, status.Errorf(codes.NotFound, "nameserver group with ID %s not found", nsGroupID)
return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID)
}
// CreateNameServerGroup creates and saves a new nameserver group
@ -85,7 +84,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, d
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
return nil, err
}
newNSGroup := &nbdns.NameServerGroup{
@ -119,7 +118,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, d
err = am.updateAccountPeers(account)
if err != nil {
log.Error(err)
return newNSGroup.Copy(), status.Errorf(codes.Unavailable, "failed to update peers after create nameserver %s", name)
return newNSGroup.Copy(), status.Errorf(status.Internal, "failed to update peers after create nameserver %s", name)
}
return newNSGroup.Copy(), nil
@ -132,12 +131,12 @@ func (am *DefaultAccountManager) SaveNameServerGroup(accountID string, nsGroupTo
defer unlock()
if nsGroupToSave == nil {
return status.Errorf(codes.InvalidArgument, "nameserver group provided is nil")
return status.Errorf(status.InvalidArgument, "nameserver group provided is nil")
}
account, err := am.Store.GetAccount(accountID)
if err != nil {
return status.Errorf(codes.NotFound, "account not found")
return err
}
err = validateNameServerGroup(true, nsGroupToSave, account)
@ -156,7 +155,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(accountID string, nsGroupTo
err = am.updateAccountPeers(account)
if err != nil {
log.Error(err)
return status.Errorf(codes.Unavailable, "failed to update peers after update nameserver %s", nsGroupToSave.Name)
return status.Errorf(status.Internal, "failed to update peers after update nameserver %s", nsGroupToSave.Name)
}
return nil
@ -170,16 +169,16 @@ func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID stri
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
return nil, err
}
if len(operations) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "operations shouldn't be empty")
return nil, status.Errorf(status.InvalidArgument, "operations shouldn't be empty")
}
nsGroupToUpdate, ok := account.NameServerGroups[nsGroupID]
if !ok {
return nil, status.Errorf(codes.NotFound, "nameserver group ID %s no longer exists", nsGroupID)
return nil, status.Errorf(status.NotFound, "nameserver group ID %s no longer exists", nsGroupID)
}
newNSGroup := nsGroupToUpdate.Copy()
@ -187,12 +186,12 @@ func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID stri
for _, operation := range operations {
valuesCount := len(operation.Values)
if valuesCount < 1 {
return nil, status.Errorf(codes.InvalidArgument, "operation %s contains invalid number of values, it should be at least 1", operation.Type.String())
return nil, status.Errorf(status.InvalidArgument, "operation %s contains invalid number of values, it should be at least 1", operation.Type.String())
}
for _, value := range operation.Values {
if value == "" {
return nil, status.Errorf(codes.InvalidArgument, "operation %s contains invalid empty string value", operation.Type.String())
return nil, status.Errorf(status.InvalidArgument, "operation %s contains invalid empty string value", operation.Type.String())
}
}
switch operation.Type {
@ -200,7 +199,7 @@ func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID stri
newNSGroup.Description = operation.Values[0]
case UpdateNameServerGroupName:
if valuesCount > 1 {
return nil, status.Errorf(codes.InvalidArgument, "failed to parse name values, expected 1 value got %d", valuesCount)
return nil, status.Errorf(status.InvalidArgument, "failed to parse name values, expected 1 value got %d", valuesCount)
}
err = validateNSGroupName(operation.Values[0], nsGroupID, account.NameServerGroups)
if err != nil {
@ -230,13 +229,13 @@ func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID stri
case UpdateNameServerGroupEnabled:
enabled, err := strconv.ParseBool(operation.Values[0])
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0])
return nil, status.Errorf(status.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0])
}
newNSGroup.Enabled = enabled
case UpdateNameServerGroupPrimary:
primary, err := strconv.ParseBool(operation.Values[0])
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "failed to parse primary status %s, not boolean", operation.Values[0])
return nil, status.Errorf(status.InvalidArgument, "failed to parse primary status %s, not boolean", operation.Values[0])
}
newNSGroup.Primary = primary
case UpdateNameServerGroupDomains:
@ -259,7 +258,7 @@ func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID stri
err = am.updateAccountPeers(account)
if err != nil {
log.Error(err)
return newNSGroup.Copy(), status.Errorf(codes.Unavailable, "failed to update peers after update nameserver %s", newNSGroup.Name)
return newNSGroup.Copy(), status.Errorf(status.Internal, "failed to update peers after update nameserver %s", newNSGroup.Name)
}
return newNSGroup.Copy(), nil
@ -273,7 +272,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID stri
account, err := am.Store.GetAccount(accountID)
if err != nil {
return status.Errorf(codes.NotFound, "account not found")
return err
}
delete(account.NameServerGroups, nsGroupID)
@ -287,7 +286,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID stri
err = am.updateAccountPeers(account)
if err != nil {
log.Error(err)
return status.Errorf(codes.Unavailable, "failed to update peers after deleting nameserver %s", nsGroupID)
return status.Errorf(status.Internal, "failed to update peers after deleting nameserver %s", nsGroupID)
}
return nil
@ -301,7 +300,7 @@ func (am *DefaultAccountManager) ListNameServerGroups(accountID string) ([]*nbdn
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
return nil, err
}
nsGroups := make([]*nbdns.NameServerGroup, 0, len(account.NameServerGroups))
@ -318,7 +317,7 @@ func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServ
nsGroupID = nameserverGroup.ID
_, found := account.NameServerGroups[nsGroupID]
if !found {
return status.Errorf(codes.NotFound, "nameserver group with ID %s was not found", nsGroupID)
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupID)
}
}
@ -347,17 +346,17 @@ func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServ
func validateDomainInput(primary bool, domains []string) error {
if !primary && len(domains) == 0 {
return status.Errorf(codes.InvalidArgument, "nameserver group primary status is false and domains are empty,"+
return status.Errorf(status.InvalidArgument, "nameserver group primary status is false and domains are empty,"+
" it should be primary or have at least one domain")
}
if primary && len(domains) != 0 {
return status.Errorf(codes.InvalidArgument, "nameserver group primary status is true and domains are not empty,"+
return status.Errorf(status.InvalidArgument, "nameserver group primary status is true and domains are not empty,"+
" you should set either primary or domain")
}
for _, domain := range domains {
_, valid := dns.IsDomainName(domain)
if !valid {
return status.Errorf(codes.InvalidArgument, "nameserver group got an invalid domain: %s", domain)
return status.Errorf(status.InvalidArgument, "nameserver group got an invalid domain: %s", domain)
}
}
return nil
@ -365,12 +364,12 @@ func validateDomainInput(primary bool, domains []string) error {
func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.NameServerGroup) error {
if utf8.RuneCountInString(name) > nbdns.MaxGroupNameChar || name == "" {
return status.Errorf(codes.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar)
return status.Errorf(status.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar)
}
for _, nsGroup := range nsGroupMap {
if name == nsGroup.Name && nsGroup.ID != nsGroupID {
return status.Errorf(codes.InvalidArgument, "a nameserver group with name %s already exist", name)
return status.Errorf(status.InvalidArgument, "a nameserver group with name %s already exist", name)
}
}
@ -380,19 +379,19 @@ func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.Na
func validateNSList(list []nbdns.NameServer) error {
nsListLenght := len(list)
if nsListLenght == 0 || nsListLenght > 2 {
return status.Errorf(codes.InvalidArgument, "the list of nameservers should be 1 or 2, got %d", len(list))
return status.Errorf(status.InvalidArgument, "the list of nameservers should be 1 or 2, got %d", len(list))
}
return nil
}
func validateGroups(list []string, groups map[string]*Group) error {
if len(list) == 0 {
return status.Errorf(codes.InvalidArgument, "the list of group IDs should not be empty")
return status.Errorf(status.InvalidArgument, "the list of group IDs should not be empty")
}
for _, id := range list {
if id == "" {
return status.Errorf(codes.InvalidArgument, "group ID should not be empty string")
return status.Errorf(status.InvalidArgument, "group ID should not be empty string")
}
found := false
for groupID := range groups {
@ -402,7 +401,7 @@ func validateGroups(list []string, groups map[string]*Group) error {
}
}
if !found {
return status.Errorf(codes.InvalidArgument, "group id %s not found", id)
return status.Errorf(status.InvalidArgument, "group id %s not found", id)
}
}

View File

@ -3,6 +3,7 @@ package server
import (
"github.com/c-robinson/iplib"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/route"
"github.com/rs/xid"
"math/rand"
@ -93,7 +94,7 @@ func AllocatePeerIP(ipNet net.IPNet, takenIps []net.IP) (net.IP, error) {
ips, _ := generateIPs(&ipNet, takenIPMap)
if len(ips) == 0 {
return nil, Errorf(PreconditionFailed, "failed allocating new IP for the ipNet %s - network is out of IPs", ipNet.String())
return nil, status.Errorf(status.PreconditionFailed, "failed allocating new IP for the ipNet %s - network is out of IPs", ipNet.String())
}
// pick a random IP

View File

@ -2,6 +2,7 @@ package server
import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/status"
"net"
"strings"
"time"
@ -9,8 +10,6 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/proto"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// PeerSystemMeta is a metadata of a Peer machine system
@ -162,7 +161,7 @@ func (am *DefaultAccountManager) UpdatePeer(accountID string, update *Peer) (*Pe
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
return nil, err
}
//TODO Peer.ID migration: we will need to replace search by ID here
@ -208,7 +207,7 @@ func (am *DefaultAccountManager) DeletePeer(accountID string, peerPubKey string)
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
return nil, err
}
peer, err := account.FindPeerByPubKey(peerPubKey)
@ -258,7 +257,7 @@ func (am *DefaultAccountManager) GetPeerByIP(accountID string, peerIP string) (*
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
return nil, err
}
for _, peer := range account.Peers {
@ -267,7 +266,7 @@ func (am *DefaultAccountManager) GetPeerByIP(accountID string, peerIP string) (*
}
}
return nil, status.Errorf(codes.NotFound, "peer with IP %s not found", peerIP)
return nil, status.Errorf(status.NotFound, "peer with IP %s not found", peerIP)
}
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)
@ -275,7 +274,7 @@ func (am *DefaultAccountManager) GetNetworkMap(peerPubKey string) (*NetworkMap,
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
if err != nil {
return nil, status.Errorf(codes.Internal, "Invalid peer key %s", peerPubKey)
return nil, err
}
aclPeers := am.getPeersByACL(account, peerPubKey)
@ -306,7 +305,7 @@ func (am *DefaultAccountManager) GetPeerNetwork(peerPubKey string) (*Network, er
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
if err != nil {
return nil, status.Errorf(codes.Internal, "invalid peer key %s", peerPubKey)
return nil, err
}
return account.Network.Copy(), err
@ -332,7 +331,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (*
account, err = am.Store.GetAccountBySetupKey(setupKey)
}
if err != nil {
return nil, Errorf(AccountNotFound, "failed adding new peer: account not found")
return nil, status.Errorf(status.NotFound, "failed adding new peer: account not found")
}
unlock := am.Store.AcquireAccountLock(account.Id)
@ -352,7 +351,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (*
}
if !sk.IsValid() {
return nil, Errorf(PreconditionFailed, "couldn't add peer: setup key is invalid")
return nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
}
account.SetupKeys[sk.Key] = sk.IncrementUsage()
@ -418,7 +417,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (*
account.Network.IncSerial()
err = am.Store.SaveAccount(account)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed adding peer")
return nil, err
}
return newPeer, nil
@ -563,13 +562,13 @@ func (am *DefaultAccountManager) updateAccountPeers(account *Account) error {
for _, peer := range peers {
remotePeerNetworkMap, err := am.GetNetworkMap(peer.Key)
if err != nil {
return status.Errorf(codes.Internal, "unable to fetch network map for peer %s, error: %v", peer.Key, err)
return err
}
update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap)
err = am.peersUpdateManager.SendUpdate(peer.Key, &UpdateMessage{Update: update})
if err != nil {
return status.Errorf(codes.Internal, "unable to send update for peer %s, error: %v", peer.Key, err)
return err
}
}

View File

@ -2,11 +2,10 @@ package server
import (
"github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/route"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"net/netip"
"strconv"
"unicode/utf8"
@ -66,7 +65,7 @@ func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*r
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
return nil, err
}
user, err := account.FindUser(userID)
@ -75,7 +74,7 @@ func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*r
}
if !user.IsAdmin() {
return nil, Errorf(PermissionDenied, "Only administrators can view Network Routes")
return nil, status.Errorf(status.PermissionDenied, "Only administrators can view Network Routes")
}
wantedRoute, found := account.Routes[routeID]
@ -83,7 +82,7 @@ func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*r
return wantedRoute, nil
}
return nil, status.Errorf(codes.NotFound, "route with ID %s not found", routeID)
return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
}
// checkPrefixPeerExists checks the combination of prefix and peer id, if it exists returns an error, otherwise returns nil
@ -101,14 +100,14 @@ func (am *DefaultAccountManager) checkPrefixPeerExists(accountID, peer string, p
routesWithPrefix := account.GetRoutesByPrefix(prefix)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
return nil
}
return status.Errorf(codes.InvalidArgument, "failed to parse prefix %s", prefix.String())
return status.Errorf(status.InvalidArgument, "failed to parse prefix %s", prefix.String())
}
for _, prefixRoute := range routesWithPrefix {
if prefixRoute.Peer == peer {
return status.Errorf(codes.AlreadyExists, "failed a route with prefix %s and peer already exist", prefix.String())
return status.Errorf(status.AlreadyExists, "failed a route with prefix %s and peer already exist", prefix.String())
}
}
return nil
@ -121,13 +120,13 @@ func (am *DefaultAccountManager) CreateRoute(accountID string, network, peer, de
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
return nil, err
}
var newRoute route.Route
prefixType, newPrefix, err := route.ParseNetwork(network)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "failed to parse IP %s", network)
return nil, status.Errorf(status.InvalidArgument, "failed to parse IP %s", network)
}
err = am.checkPrefixPeerExists(accountID, peer, newPrefix)
if err != nil {
@ -137,16 +136,16 @@ func (am *DefaultAccountManager) CreateRoute(accountID string, network, peer, de
if peer != "" {
_, peerExist := account.Peers[peer]
if !peerExist {
return nil, status.Errorf(codes.InvalidArgument, "failed to find Peer %s", peer)
return nil, status.Errorf(status.InvalidArgument, "failed to find Peer %s", peer)
}
}
if metric < route.MinMetric || metric > route.MaxMetric {
return nil, status.Errorf(codes.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
return nil, status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
}
if utf8.RuneCountInString(netID) > route.MaxNetIDChar || netID == "" {
return nil, status.Errorf(codes.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
}
newRoute.Peer = peer
@ -173,7 +172,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID string, network, peer, de
err = am.updateAccountPeers(account)
if err != nil {
log.Error(err)
return &newRoute, status.Errorf(codes.Unavailable, "failed to update peers after create route %s", newPrefix)
return &newRoute, status.Errorf(status.Internal, "failed to update peers after create route %s", newPrefix)
}
return &newRoute, nil
}
@ -184,30 +183,30 @@ func (am *DefaultAccountManager) SaveRoute(accountID string, routeToSave *route.
defer unlock()
if routeToSave == nil {
return status.Errorf(codes.InvalidArgument, "route provided is nil")
return status.Errorf(status.InvalidArgument, "route provided is nil")
}
if !routeToSave.Network.IsValid() {
return status.Errorf(codes.InvalidArgument, "invalid Prefix %s", routeToSave.Network.String())
return status.Errorf(status.InvalidArgument, "invalid Prefix %s", routeToSave.Network.String())
}
if routeToSave.Metric < route.MinMetric || routeToSave.Metric > route.MaxMetric {
return status.Errorf(codes.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
return status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
}
if utf8.RuneCountInString(routeToSave.NetID) > route.MaxNetIDChar || routeToSave.NetID == "" {
return status.Errorf(codes.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
}
account, err := am.Store.GetAccount(accountID)
if err != nil {
return status.Errorf(codes.NotFound, "account not found")
return err
}
if routeToSave.Peer != "" {
_, peerExist := account.Peers[routeToSave.Peer]
if !peerExist {
return status.Errorf(codes.InvalidArgument, "failed to find Peer %s", routeToSave.Peer)
return status.Errorf(status.InvalidArgument, "failed to find Peer %s", routeToSave.Peer)
}
}
@ -228,12 +227,12 @@ func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operatio
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
return nil, err
}
routeToUpdate, ok := account.Routes[routeID]
if !ok {
return nil, status.Errorf(codes.NotFound, "route %s no longer exists", routeID)
return nil, status.Errorf(status.NotFound, "route %s no longer exists", routeID)
}
newRoute := routeToUpdate.Copy()
@ -241,7 +240,7 @@ func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operatio
for _, operation := range operations {
if len(operation.Values) != 1 {
return nil, status.Errorf(codes.InvalidArgument, "operation %s contains invalid number of values, it should be 1", operation.Type.String())
return nil, status.Errorf(status.InvalidArgument, "operation %s contains invalid number of values, it should be 1", operation.Type.String())
}
switch operation.Type {
@ -249,13 +248,13 @@ func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operatio
newRoute.Description = operation.Values[0]
case UpdateRouteNetworkIdentifier:
if utf8.RuneCountInString(operation.Values[0]) > route.MaxNetIDChar || operation.Values[0] == "" {
return nil, status.Errorf(codes.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
}
newRoute.NetID = operation.Values[0]
case UpdateRouteNetwork:
prefixType, prefix, err := route.ParseNetwork(operation.Values[0])
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "failed to parse IP %s", operation.Values[0])
return nil, status.Errorf(status.InvalidArgument, "failed to parse IP %s", operation.Values[0])
}
err = am.checkPrefixPeerExists(accountID, routeToUpdate.Peer, prefix)
if err != nil {
@ -267,7 +266,7 @@ func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operatio
if operation.Values[0] != "" {
_, peerExist := account.Peers[operation.Values[0]]
if !peerExist {
return nil, status.Errorf(codes.InvalidArgument, "failed to find Peer %s", operation.Values[0])
return nil, status.Errorf(status.InvalidArgument, "failed to find Peer %s", operation.Values[0])
}
}
@ -279,10 +278,10 @@ func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operatio
case UpdateRouteMetric:
metric, err := strconv.Atoi(operation.Values[0])
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "failed to parse metric %s, not int", operation.Values[0])
return nil, status.Errorf(status.InvalidArgument, "failed to parse metric %s, not int", operation.Values[0])
}
if metric < route.MinMetric || metric > route.MaxMetric {
return nil, status.Errorf(codes.InvalidArgument, "failed to parse metric %s, value should be %d > N < %d",
return nil, status.Errorf(status.InvalidArgument, "failed to parse metric %s, value should be %d > N < %d",
operation.Values[0],
route.MinMetric,
route.MaxMetric,
@ -292,13 +291,13 @@ func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operatio
case UpdateRouteMasquerade:
masquerade, err := strconv.ParseBool(operation.Values[0])
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "failed to parse masquerade %s, not boolean", operation.Values[0])
return nil, status.Errorf(status.InvalidArgument, "failed to parse masquerade %s, not boolean", operation.Values[0])
}
newRoute.Masquerade = masquerade
case UpdateRouteEnabled:
enabled, err := strconv.ParseBool(operation.Values[0])
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0])
return nil, status.Errorf(status.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0])
}
newRoute.Enabled = enabled
}
@ -313,7 +312,7 @@ func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operatio
err = am.updateAccountPeers(account)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to update account peers")
return nil, status.Errorf(status.Internal, "failed to update account peers")
}
return newRoute, nil
}
@ -325,7 +324,7 @@ func (am *DefaultAccountManager) DeleteRoute(accountID, routeID string) error {
account, err := am.Store.GetAccount(accountID)
if err != nil {
return status.Errorf(codes.NotFound, "account not found")
return err
}
delete(account.Routes, routeID)
@ -345,7 +344,7 @@ func (am *DefaultAccountManager) ListRoutes(accountID, userID string) ([]*route.
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
return nil, err
}
user, err := account.FindUser(userID)
@ -354,7 +353,7 @@ func (am *DefaultAccountManager) ListRoutes(accountID, userID string) ([]*route.
}
if !user.IsAdmin() {
return nil, Errorf(PermissionDenied, "Only administrators can view Network Routes")
return nil, status.Errorf(status.PermissionDenied, "Only administrators can view Network Routes")
}
routes := make([]*route.Route, 0, len(account.Routes))

View File

@ -1,8 +1,7 @@
package server
import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/management/server/status"
"strings"
)
@ -95,7 +94,7 @@ func (am *DefaultAccountManager) GetRule(accountID, ruleID, userID string) (*Rul
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
return nil, err
}
user, err := account.FindUser(userID)
@ -104,7 +103,7 @@ func (am *DefaultAccountManager) GetRule(accountID, ruleID, userID string) (*Rul
}
if !user.IsAdmin() {
return nil, Errorf(PermissionDenied, "only admins are allowed to view rules")
return nil, status.Errorf(status.PermissionDenied, "only admins are allowed to view rules")
}
rule, ok := account.Rules[ruleID]
@ -112,7 +111,7 @@ func (am *DefaultAccountManager) GetRule(accountID, ruleID, userID string) (*Rul
return rule, nil
}
return nil, status.Errorf(codes.NotFound, "rule with ID %s not found", ruleID)
return nil, status.Errorf(status.NotFound, "rule with ID %s not found", ruleID)
}
// SaveRule of ACL in the store
@ -122,7 +121,7 @@ func (am *DefaultAccountManager) SaveRule(accountID string, rule *Rule) error {
account, err := am.Store.GetAccount(accountID)
if err != nil {
return status.Errorf(codes.NotFound, "account not found")
return err
}
account.Rules[rule.ID] = rule
@ -143,12 +142,12 @@ func (am *DefaultAccountManager) UpdateRule(accountID string, ruleID string,
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
return nil, err
}
ruleToUpdate, ok := account.Rules[ruleID]
if !ok {
return nil, status.Errorf(codes.NotFound, "rule %s no longer exists", ruleID)
return nil, status.Errorf(status.NotFound, "rule %s no longer exists", ruleID)
}
rule := ruleToUpdate.Copy()
@ -161,7 +160,7 @@ func (am *DefaultAccountManager) UpdateRule(accountID string, ruleID string,
rule.Description = operation.Values[0]
case UpdateRuleFlow:
if operation.Values[0] != TrafficFlowBidirectString {
return nil, status.Errorf(codes.InvalidArgument, "failed to parse flow")
return nil, status.Errorf(status.InvalidArgument, "failed to parse flow")
}
rule.Flow = TrafficFlowBidirect
case UpdateRuleStatus:
@ -170,7 +169,7 @@ func (am *DefaultAccountManager) UpdateRule(accountID string, ruleID string,
} else if strings.ToLower(operation.Values[0]) == "false" {
rule.Disabled = false
} else {
return nil, status.Errorf(codes.InvalidArgument, "failed to parse status")
return nil, status.Errorf(status.InvalidArgument, "failed to parse status")
}
case UpdateSourceGroups:
rule.Source = operation.Values
@ -204,7 +203,7 @@ func (am *DefaultAccountManager) UpdateRule(accountID string, ruleID string,
err = am.updateAccountPeers(account)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to update account peers")
return nil, status.Errorf(status.Internal, "failed to update account peers")
}
return rule, nil
@ -217,7 +216,7 @@ func (am *DefaultAccountManager) DeleteRule(accountID, ruleID string) error {
account, err := am.Store.GetAccount(accountID)
if err != nil {
return status.Errorf(codes.NotFound, "account not found")
return err
}
delete(account.Rules, ruleID)
@ -237,7 +236,7 @@ func (am *DefaultAccountManager) ListRules(accountID, userID string) ([]*Rule, e
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
return nil, err
}
user, err := account.FindUser(userID)
@ -246,7 +245,7 @@ func (am *DefaultAccountManager) ListRules(accountID, userID string) ([]*Rule, e
}
if !user.IsAdmin() {
return nil, Errorf(PermissionDenied, "Only Administrators can view Access Rules")
return nil, status.Errorf(status.PermissionDenied, "Only Administrators can view Access Rules")
}
rules := make([]*Rule, 0, len(account.Rules))

View File

@ -1,10 +1,8 @@
package server
import (
"fmt"
"github.com/google/uuid"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/management/server/status"
"hash/fnv"
"strconv"
"strings"
@ -183,12 +181,12 @@ func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
return nil, err
}
for _, group := range autoGroups {
if _, ok := account.Groups[group]; !ok {
return nil, fmt.Errorf("group %s doesn't exist", group)
return nil, status.Errorf(status.NotFound, "group %s doesn't exist", group)
}
}
@ -197,7 +195,7 @@ func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string
err = am.Store.SaveAccount(account)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed adding account key")
return nil, status.Errorf(status.Internal, "failed adding account key")
}
return setupKey, nil
@ -212,12 +210,12 @@ func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *Setup
defer unlock()
if keyToSave == nil {
return nil, status.Errorf(codes.InvalidArgument, "provided setup key to update is nil")
return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil")
}
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
return nil, err
}
var oldKey *SetupKey
@ -228,7 +226,7 @@ func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *Setup
}
}
if oldKey == nil {
return nil, status.Errorf(codes.NotFound, "setup key not found")
return nil, status.Errorf(status.NotFound, "setup key not found")
}
// only auto groups, revoked status, and name can be updated for now
@ -253,7 +251,7 @@ func (am *DefaultAccountManager) ListSetupKeys(accountID, userID string) ([]*Set
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
return nil, err
}
user, err := account.FindUser(userID)
@ -282,7 +280,7 @@ func (am *DefaultAccountManager) GetSetupKey(accountID, userID, keyID string) (*
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
return nil, err
}
user, err := account.FindUser(userID)
@ -298,7 +296,7 @@ func (am *DefaultAccountManager) GetSetupKey(accountID, userID, keyID string) (*
}
}
if foundKey == nil {
return nil, status.Errorf(codes.NotFound, "setup key not found")
return nil, status.Errorf(status.NotFound, "setup key not found")
}
// the UpdatedAt field was introduced later, so there might be that some keys have a Zero value (e.g, null in the store file)

View File

@ -0,0 +1,72 @@
package status
import (
"fmt"
)
const (
// UserAlreadyExists indicates that user already exists
UserAlreadyExists Type = 1
// PreconditionFailed indicates that some pre-condition for the operation hasn't been fulfilled
PreconditionFailed Type = 2
// PermissionDenied indicates that user has no permissions to view data
PermissionDenied Type = 3
// NotFound indicates that the object wasn't found in the system (or under a given Account)
NotFound Type = 4
// Internal indicates some generic internal error
Internal Type = 5
// InvalidArgument indicates some generic invalid argument error
InvalidArgument Type = 6
// AlreadyExists indicates a generic error when an object already exists in the system
AlreadyExists Type = 7
// Unauthorized indicates that user is not authorized
Unauthorized Type = 8
// BadRequest indicates that user is not authorized
BadRequest Type = 9
)
// Type is a type of the Error
type Type int32
// Error is an internal error
type Error struct {
ErrorType Type
Message string
}
// Type returns the Type of the error
func (e *Error) Type() Type {
return e.ErrorType
}
// Error is an error string
func (e *Error) Error() string {
return e.Message
}
// Errorf returns Error(ErrorType, fmt.Sprintf(format, a...)).
func Errorf(errorType Type, format string, a ...interface{}) error {
return &Error{
ErrorType: errorType,
Message: fmt.Sprintf(format, a...),
}
}
// FromError returns Error, true if the provided error is of type of Error. nil, false otherwise
func FromError(err error) (s *Error, ok bool) {
if err == nil {
return nil, true
}
if e, ok := err.(*Error); ok {
return e, true
}
return nil, false
}

View File

@ -3,8 +3,7 @@ package server
import (
"fmt"
"github.com/netbirdio/netbird/management/server/idp"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/management/server/status"
"strings"
"github.com/netbirdio/netbird/management/server/jwtclaims"
@ -123,7 +122,7 @@ func (am *DefaultAccountManager) CreateUser(accountID string, invite *UserInfo)
defer unlock()
if am.idpManager == nil {
return nil, Errorf(PreconditionFailed, "IdP manager must be enabled to send user invites")
return nil, status.Errorf(status.PreconditionFailed, "IdP manager must be enabled to send user invites")
}
if invite == nil {
@ -132,7 +131,7 @@ func (am *DefaultAccountManager) CreateUser(accountID string, invite *UserInfo)
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, Errorf(AccountNotFound, "account %s doesn't exist", accountID)
return nil, status.Errorf(status.NotFound, "account %s doesn't exist", accountID)
}
// check if the user is already registered with this email => reject
@ -142,7 +141,7 @@ func (am *DefaultAccountManager) CreateUser(accountID string, invite *UserInfo)
}
if user != nil {
return nil, Errorf(UserAlreadyExists, "user has an existing account")
return nil, status.Errorf(status.UserAlreadyExists, "user has an existing account")
}
users, err := am.idpManager.GetUserByEmail(invite.Email)
@ -151,7 +150,7 @@ func (am *DefaultAccountManager) CreateUser(accountID string, invite *UserInfo)
}
if len(users) > 0 {
return nil, Errorf(UserAlreadyExists, "user has an existing account")
return nil, status.Errorf(status.UserAlreadyExists, "user has an existing account")
}
idpUser, err := am.idpManager.CreateUser(invite.Email, invite.Name, accountID)
@ -188,25 +187,24 @@ func (am *DefaultAccountManager) SaveUser(accountID string, update *User) (*User
defer unlock()
if update == nil {
return nil, status.Errorf(codes.InvalidArgument, "provided user update is nil")
return nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
}
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
return nil, err
}
for _, newGroupID := range update.AutoGroups {
if _, ok := account.Groups[newGroupID]; !ok {
return nil,
status.Errorf(codes.InvalidArgument, "provided group ID %s in the user %s update doesn't exist",
return nil, status.Errorf(status.InvalidArgument, "provided group ID %s in the user %s update doesn't exist",
newGroupID, update.Id)
}
}
oldUser := account.Users[update.Id]
if oldUser == nil {
return nil, status.Errorf(codes.NotFound, "update not found")
return nil, status.Errorf(status.NotFound, "update not found")
}
// only auto groups, revoked status, and name can be updated for now
@ -226,7 +224,7 @@ func (am *DefaultAccountManager) SaveUser(accountID string, update *User) (*User
return nil, err
}
if userData == nil {
return nil, status.Errorf(codes.NotFound, "user %s not found in the IdP", newUser.Id)
return nil, status.Errorf(status.NotFound, "user %s not found in the IdP", newUser.Id)
}
return newUser.toUserInfo(userData)
}
@ -242,14 +240,14 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userID, domain string)
account, err := am.Store.GetAccountByUser(userID)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
account, err = am.newAccount(userID, lowerDomain)
if err != nil {
return nil, err
}
err = am.Store.SaveAccount(account)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed creating account")
return nil, err
}
} else {
// other error
@ -263,7 +261,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userID, domain string)
account.Domain = lowerDomain
err = am.Store.SaveAccount(account)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed updating account with domain")
return nil, status.Errorf(status.Internal, "failed updating account with domain")
}
}
@ -272,14 +270,14 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userID, domain string)
// IsUserAdmin flag for current user authenticated by JWT token
func (am *DefaultAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) {
account, err := am.GetAccountFromToken(claims)
account, _, err := am.GetAccountFromToken(claims)
if err != nil {
return false, fmt.Errorf("get account: %v", err)
}
user, ok := account.Users[claims.UserId]
if !ok {
return false, fmt.Errorf("no such user")
return false, status.Errorf(status.NotFound, "user not found")
}
return user.Role == UserRoleAdmin, nil

View File

@ -1,8 +1,7 @@
package route
import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/management/server/status"
"net/netip"
)
@ -108,13 +107,13 @@ func (r *Route) IsEqual(other *Route) bool {
func ParseNetwork(networkString string) (NetworkType, netip.Prefix, error) {
prefix, err := netip.ParsePrefix(networkString)
if err != nil {
return InvalidNetwork, netip.Prefix{}, err
return InvalidNetwork, netip.Prefix{}, status.Errorf(status.InvalidArgument, "invalid network %s", networkString)
}
masked := prefix.Masked()
if !masked.IsValid() {
return InvalidNetwork, netip.Prefix{}, status.Errorf(codes.InvalidArgument, "invalid range %s", networkString)
return InvalidNetwork, netip.Prefix{}, status.Errorf(status.InvalidArgument, "invalid range %s", networkString)
}
if masked.Addr().Is6() {