Replace gRPC errors in business logic with internal ones (#558)

This commit is contained in:
Misha Bragin 2022-11-11 20:36:45 +01:00 committed by GitHub
parent 1db4027bea
commit 509d23c7cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
35 changed files with 768 additions and 847 deletions

View File

@ -8,12 +8,11 @@ import (
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
gocache "github.com/patrickmn/go-cache" gocache "github.com/patrickmn/go-cache"
"github.com/rs/xid" "github.com/rs/xid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"math/rand" "math/rand"
"net" "net"
"net/netip" "net/netip"
@ -52,7 +51,7 @@ type AccountManager interface {
SaveUser(accountID string, key *User) (*UserInfo, error) SaveUser(accountID string, key *User) (*UserInfo, error)
GetSetupKey(accountID, userID, keyID string) (*SetupKey, error) GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error) GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error)
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, error) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error)
AccountExists(accountId string) (*bool, error) AccountExists(accountId string) (*bool, error)
GetPeer(peerKey string) (*Peer, error) GetPeer(peerKey string) (*Peer, error)
@ -265,14 +264,14 @@ func (a *Account) FindPeerByPubKey(peerPubKey string) (*Peer, error) {
} }
} }
return nil, status.Errorf(codes.NotFound, "peer with the public key %s not found", peerPubKey) return nil, status.Errorf(status.NotFound, "peer with the public key %s not found", peerPubKey)
} }
// FindUser looks for a given user in the Account or returns error if user wasn't found. // FindUser looks for a given user in the Account or returns error if user wasn't found.
func (a *Account) FindUser(userID string) (*User, error) { func (a *Account) FindUser(userID string) (*User, error) {
user := a.Users[userID] user := a.Users[userID]
if user == nil { if user == nil {
return nil, Errorf(UserNotFound, "user %s not found", userID) return nil, status.Errorf(status.NotFound, "user %s not found", userID)
} }
return user, nil return user, nil
@ -282,7 +281,7 @@ func (a *Account) FindUser(userID string) (*User, error) {
func (a *Account) FindSetupKey(setupKey string) (*SetupKey, error) { func (a *Account) FindSetupKey(setupKey string) (*SetupKey, error) {
key := a.SetupKeys[setupKey] key := a.SetupKeys[setupKey]
if key == nil { if key == nil {
return nil, Errorf(SetupKeyNotFound, "setup key not found") return nil, status.Errorf(status.NotFound, "setup key not found")
} }
return key, nil return key, nil
@ -458,14 +457,14 @@ func (am *DefaultAccountManager) newAccount(userID, domain string) (*Account, er
if err == nil { if err == nil {
log.Warnf("an account with ID already exists, retrying...") log.Warnf("an account with ID already exists, retrying...")
continue continue
} else if statusErr.Code() == codes.NotFound { } else if statusErr.Type() == status.NotFound {
return newAccountWithId(accountId, userID, domain), nil return newAccountWithId(accountId, userID, domain), nil
} else { } else {
return nil, err return nil, err
} }
} }
return nil, status.Errorf(codes.Internal, "error while creating new account") return nil, status.Errorf(status.Internal, "error while creating new account")
} }
func (am *DefaultAccountManager) warmupIDPCache() error { func (am *DefaultAccountManager) warmupIDPCache() error {
@ -492,7 +491,7 @@ func (am *DefaultAccountManager) GetAccountByUserOrAccountID(userID, accountID,
} else if userID != "" { } else if userID != "" {
account, err := am.GetOrCreateAccountByUser(userID, domain) account, err := am.GetOrCreateAccountByUser(userID, domain)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found using user id: %s", userID) return nil, status.Errorf(status.NotFound, "account not found using user id: %s", userID)
} }
err = am.addAccountIDToIDPAppMeta(userID, account) err = am.addAccountIDToIDPAppMeta(userID, account)
if err != nil { if err != nil {
@ -501,7 +500,7 @@ func (am *DefaultAccountManager) GetAccountByUserOrAccountID(userID, accountID,
return account, nil return account, nil
} }
return nil, status.Errorf(codes.NotFound, "no valid user or account Id provided") return nil, status.Errorf(status.NotFound, "no valid user or account Id provided")
} }
func isNil(i idp.Manager) bool { func isNil(i idp.Manager) bool {
@ -531,11 +530,7 @@ func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(userID string, account
} }
if err != nil { if err != nil {
return status.Errorf( return status.Errorf(status.Internal, "updating user's app metadata failed with: %v", err)
codes.Internal,
"updating user's app metadata failed with: %v",
err,
)
} }
// refresh cache to reflect the update // refresh cache to reflect the update
_, err = am.refreshCache(account.Id) _, err = am.refreshCache(account.Id)
@ -662,11 +657,8 @@ func (am *DefaultAccountManager) lookupCache(accountUsers map[string]struct{}, a
} }
// updateAccountDomainAttributes updates the account domain attributes and then, saves the account // updateAccountDomainAttributes updates the account domain attributes and then, saves the account
func (am *DefaultAccountManager) updateAccountDomainAttributes( func (am *DefaultAccountManager) updateAccountDomainAttributes(account *Account, claims jwtclaims.AuthorizationClaims,
account *Account, primaryDomain bool) error {
claims jwtclaims.AuthorizationClaims,
primaryDomain bool,
) error {
account.IsDomainPrimaryAccount = primaryDomain account.IsDomainPrimaryAccount = primaryDomain
lowerDomain := strings.ToLower(claims.Domain) lowerDomain := strings.ToLower(claims.Domain)
@ -681,7 +673,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributes(
err := am.Store.SaveAccount(account) err := am.Store.SaveAccount(account)
if err != nil { if err != nil {
return status.Errorf(codes.Internal, "failed saving updated account") return err
} }
return nil return nil
} }
@ -723,10 +715,7 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
// handleNewUserAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account, // handleNewUserAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account,
// otherwise it will create a new account and make it primary account for the domain. // otherwise it will create a new account and make it primary account for the domain.
func (am *DefaultAccountManager) handleNewUserAccount( func (am *DefaultAccountManager) handleNewUserAccount(domainAcc *Account, claims jwtclaims.AuthorizationClaims) (*Account, error) {
domainAcc *Account,
claims jwtclaims.AuthorizationClaims,
) (*Account, error) {
var ( var (
account *Account account *Account
err error err error
@ -738,7 +727,7 @@ func (am *DefaultAccountManager) handleNewUserAccount(
account.Users[claims.UserId] = NewRegularUser(claims.UserId) account.Users[claims.UserId] = NewRegularUser(claims.UserId)
err = am.Store.SaveAccount(account) err = am.Store.SaveAccount(account)
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "failed saving updated account") return nil, err
} }
} else { } else {
account, err = am.newAccount(claims.UserId, lowerDomain) account, err = am.newAccount(claims.UserId, lowerDomain)
@ -773,7 +762,7 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e
} }
if user == nil { if user == nil {
return status.Errorf(codes.NotFound, "user %s not found in the IdP", userID) return status.Errorf(status.NotFound, "user %s not found in the IdP", userID)
} }
if user.AppMetadata.WTPendingInvite != nil && *user.AppMetadata.WTPendingInvite { if user.AppMetadata.WTPendingInvite != nil && *user.AppMetadata.WTPendingInvite {
@ -794,7 +783,7 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e
} }
// GetAccountFromToken returns an account associated with this token // GetAccountFromToken returns an account associated with this token
func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, error) { func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error) {
if am.singleAccountMode && am.singleAccountModeDomain != "" { if am.singleAccountMode && am.singleAccountModeDomain != "" {
// This section is mostly related to self-hosted installations. // This section is mostly related to self-hosted installations.
@ -806,15 +795,21 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
account, err := am.getAccountWithAuthorizationClaims(claims) account, err := am.getAccountWithAuthorizationClaims(claims)
if err != nil { if err != nil {
return nil, err return nil, nil, err
}
user := account.Users[claims.UserId]
if user == nil {
// this is not really possible because we got an account by user ID
return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId)
} }
err = am.redeemInvite(account, claims.UserId) err = am.redeemInvite(account, claims.UserId)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
return account, nil return account, user, nil
} }
// getAccountWithAuthorizationClaims retrievs an account using JWT Claims. // getAccountWithAuthorizationClaims retrievs an account using JWT Claims.
@ -857,9 +852,12 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla
// We checked if the domain has a primary account already // We checked if the domain has a primary account already
domainAccount, err := am.Store.GetAccountByPrivateDomain(claims.Domain) domainAccount, err := am.Store.GetAccountByPrivateDomain(claims.Domain)
accStatus, _ := status.FromError(err) if err != nil {
if accStatus.Code() != codes.OK && accStatus.Code() != codes.NotFound { // if NotFound we are good to continue, otherwise return error
return nil, err e, ok := status.FromError(err)
if !ok || e.Type() != status.NotFound {
return nil, err
}
} }
account, err := am.Store.GetAccountByUser(claims.UserId) account, err := am.Store.GetAccountByUser(claims.UserId)
@ -869,7 +867,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla
return nil, err return nil, err
} }
return account, nil return account, nil
} else if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { } else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
return am.handleNewUserAccount(domainAccount, claims) return am.handleNewUserAccount(domainAccount, claims)
} else { } else {
// other error // other error
@ -891,7 +889,7 @@ func (am *DefaultAccountManager) AccountExists(accountID string) (*bool, error)
var res bool var res bool
_, err := am.Store.GetAccount(accountID) _, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
res = false res = false
return &res, nil return &res, nil
} else { } else {

View File

@ -314,7 +314,7 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
testCase.inputClaims.AccountId = initAccount.Id testCase.inputClaims.AccountId = initAccount.Id
} }
account, err := manager.GetAccountFromToken(testCase.inputClaims) account, _, err := manager.GetAccountFromToken(testCase.inputClaims)
require.NoError(t, err, "support function failed") require.NoError(t, err, "support function failed")
verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers) verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers)
verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy) verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy)

View File

@ -1,61 +0,0 @@
package server
import (
"fmt"
)
const (
// UserAlreadyExists indicates that user already exists
UserAlreadyExists ErrorType = iota
// AccountNotFound indicates that specified account hasn't been found
AccountNotFound
// PreconditionFailed indicates that some pre-condition for the operation hasn't been fulfilled
PreconditionFailed
// UserNotFound indicates that user wasn't found in the system (or under a given Account)
UserNotFound
// PermissionDenied indicates that user has no permissions to view data
PermissionDenied
// SetupKeyNotFound indicates that the setup key wasn't found in the system (or under a given Account)
SetupKeyNotFound
)
// ErrorType is a type of the Error
type ErrorType int32
// Error is an internal error
type Error struct {
errorType ErrorType
message string
}
// Type returns the Type of the error
func (e *Error) Type() ErrorType {
return e.errorType
}
// Error is an error string
func (e *Error) Error() string {
return e.message
}
// Errorf returns Error(errorType, fmt.Sprintf(format, a...)).
func Errorf(errorType ErrorType, format string, a ...interface{}) error {
return &Error{
errorType: errorType,
message: fmt.Sprintf(format, a...),
}
}
// FromError returns Error, true if the provided error is of type of Error. nil, false otherwise
func FromError(err error) (s *Error, ok bool) {
if err == nil {
return nil, true
}
if e, ok := err.(*Error); ok {
return e, true
}
return nil, false
}

View File

@ -1,6 +1,7 @@
package server package server
import ( import (
"github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"os" "os"
"path/filepath" "path/filepath"
@ -8,9 +9,6 @@ import (
"sync" "sync"
"time" "time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@ -192,10 +190,7 @@ func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) {
accountID, accountIDFound := s.PrivateDomain2AccountID[strings.ToLower(domain)] accountID, accountIDFound := s.PrivateDomain2AccountID[strings.ToLower(domain)]
if !accountIDFound { if !accountIDFound {
return nil, status.Errorf( return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
codes.NotFound,
"account not found: provided domain is not registered or is not private",
)
} }
account, err := s.getAccount(accountID) account, err := s.getAccount(accountID)
@ -213,7 +208,7 @@ func (s *FileStore) GetAccountBySetupKey(setupKey string) (*Account, error) {
accountID, accountIDFound := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)] accountID, accountIDFound := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
if !accountIDFound { if !accountIDFound {
return nil, status.Errorf(codes.NotFound, "account not found: provided setup key doesn't exists") return nil, status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists")
} }
account, err := s.getAccount(accountID) account, err := s.getAccount(accountID)
@ -239,7 +234,7 @@ func (s *FileStore) GetAllAccounts() (all []*Account) {
func (s *FileStore) getAccount(accountID string) (*Account, error) { func (s *FileStore) getAccount(accountID string) (*Account, error) {
account, accountFound := s.Accounts[accountID] account, accountFound := s.Accounts[accountID]
if !accountFound { if !accountFound {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, status.Errorf(status.NotFound, "account not found")
} }
return account, nil return account, nil
@ -265,7 +260,7 @@ func (s *FileStore) GetAccountByUser(userID string) (*Account, error) {
accountID, accountIDFound := s.UserID2AccountID[userID] accountID, accountIDFound := s.UserID2AccountID[userID]
if !accountIDFound { if !accountIDFound {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, status.Errorf(status.NotFound, "account not found")
} }
account, err := s.getAccount(accountID) account, err := s.getAccount(accountID)
@ -283,7 +278,7 @@ func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
accountID, accountIDFound := s.PeerKeyID2AccountID[peerKey] accountID, accountIDFound := s.PeerKeyID2AccountID[peerKey]
if !accountIDFound { if !accountIDFound {
return nil, status.Errorf(codes.NotFound, "Provided peer key doesn't exists %s", peerKey) return nil, status.Errorf(status.NotFound, "provided peer key doesn't exists %s", peerKey)
} }
account, err := s.getAccount(accountID) account, err := s.getAccount(accountID)
@ -322,7 +317,7 @@ func (s *FileStore) SavePeerStatus(accountID, peerKey string, peerStatus PeerSta
peer := account.Peers[peerKey] peer := account.Peers[peerKey]
if peer == nil { if peer == nil {
return status.Errorf(codes.NotFound, "peer %s not found", peerKey) return status.Errorf(status.NotFound, "peer %s not found", peerKey)
} }
peer.Status = &peerStatus peer.Status = &peerStatus

View File

@ -1,9 +1,6 @@
package server package server
import ( import "github.com/netbirdio/netbird/management/server/status"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// Group of the peers for ACL // Group of the peers for ACL
type Group struct { type Group struct {
@ -53,7 +50,7 @@ func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, er
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, err
} }
group, ok := account.Groups[groupID] group, ok := account.Groups[groupID]
@ -61,7 +58,7 @@ func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, er
return group, nil return group, nil
} }
return nil, status.Errorf(codes.NotFound, "group with ID %s not found", groupID) return nil, status.Errorf(status.NotFound, "group with ID %s not found", groupID)
} }
// SaveGroup object of the peers // SaveGroup object of the peers
@ -72,7 +69,7 @@ func (am *DefaultAccountManager) SaveGroup(accountID string, group *Group) error
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return status.Errorf(codes.NotFound, "account not found") return err
} }
account.Groups[group.ID] = group account.Groups[group.ID] = group
@ -94,12 +91,12 @@ func (am *DefaultAccountManager) UpdateGroup(accountID string,
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, err
} }
groupToUpdate, ok := account.Groups[groupID] groupToUpdate, ok := account.Groups[groupID]
if !ok { if !ok {
return nil, status.Errorf(codes.NotFound, "group %s no longer exists", groupID) return nil, status.Errorf(status.NotFound, "group with ID %s no longer exists", groupID)
} }
group := groupToUpdate.Copy() group := groupToUpdate.Copy()
@ -130,7 +127,7 @@ func (am *DefaultAccountManager) UpdateGroup(accountID string,
err = am.updateAccountPeers(account) err = am.updateAccountPeers(account)
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "failed to update account peers") return nil, err
} }
return group, nil return group, nil
@ -144,7 +141,7 @@ func (am *DefaultAccountManager) DeleteGroup(accountID, groupID string) error {
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return status.Errorf(codes.NotFound, "account not found") return err
} }
delete(account.Groups, groupID) delete(account.Groups, groupID)
@ -165,7 +162,7 @@ func (am *DefaultAccountManager) ListGroups(accountID string) ([]*Group, error)
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, err
} }
groups := make([]*Group, 0, len(account.Groups)) groups := make([]*Group, 0, len(account.Groups))
@ -184,12 +181,12 @@ func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerKey string
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return status.Errorf(codes.NotFound, "account not found") return err
} }
group, ok := account.Groups[groupID] group, ok := account.Groups[groupID]
if !ok { if !ok {
return status.Errorf(codes.NotFound, "group with ID %s not found", groupID) return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
} }
add := true add := true
@ -219,12 +216,12 @@ func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerKey str
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return status.Errorf(codes.NotFound, "account not found") return err
} }
group, ok := account.Groups[groupID] group, ok := account.Groups[groupID]
if !ok { if !ok {
return status.Errorf(codes.NotFound, "group with ID %s not found", groupID) return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
} }
account.Network.IncSerial() account.Network.IncSerial()
@ -232,7 +229,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerKey str
if itemID == peerKey { if itemID == peerKey {
group.Peers = append(group.Peers[:i], group.Peers[i+1:]...) group.Peers = append(group.Peers[:i], group.Peers[i+1:]...)
if err := am.Store.SaveAccount(account); err != nil { if err := am.Store.SaveAccount(account); err != nil {
return status.Errorf(codes.Internal, "can't save account") return err
} }
} }
} }
@ -248,12 +245,12 @@ func (am *DefaultAccountManager) GroupListPeers(accountID, groupID string) ([]*P
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, status.Errorf(status.NotFound, "account not found")
} }
group, ok := account.Groups[groupID] group, ok := account.Groups[groupID]
if !ok { if !ok {
return nil, status.Errorf(codes.NotFound, "group with ID %s not found", groupID) return nil, status.Errorf(status.NotFound, "group with ID %s not found", groupID)
} }
peers := make([]*Peer, 0, len(account.Groups)) peers := make([]*Peer, 0, len(account.Groups))

View File

@ -14,6 +14,7 @@ import (
"github.com/golang/protobuf/ptypes/timestamp" "github.com/golang/protobuf/ptypes/timestamp"
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/proto"
internalStatus "github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
@ -200,10 +201,6 @@ func (s *GRPCServer) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest)
return nil, status.Errorf(codes.Internal, "invalid jwt token, err: %v", err) return nil, status.Errorf(codes.Internal, "invalid jwt token, err: %v", err)
} }
claims := jwtclaims.ExtractClaimsWithToken(token, s.config.HttpConfig.AuthAudience) claims := jwtclaims.ExtractClaimsWithToken(token, s.config.HttpConfig.AuthAudience)
_, err = s.accountManager.GetAccountFromToken(claims)
if err != nil {
return nil, status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err)
}
userID = claims.UserId userID = claims.UserId
} else { } else {
log.Debugln("using setup key to register peer") log.Debugln("using setup key to register peer")
@ -237,14 +234,12 @@ func (s *GRPCServer) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest)
}, },
}) })
if err != nil { if err != nil {
if e, ok := FromError(err); ok { if e, ok := internalStatus.FromError(err); ok {
switch e.Type() { switch e.Type() {
case PreconditionFailed: case internalStatus.PreconditionFailed:
return nil, status.Errorf(codes.FailedPrecondition, e.message) return nil, status.Errorf(codes.FailedPrecondition, e.Message)
case AccountNotFound: case internalStatus.NotFound:
case SetupKeyNotFound: return nil, status.Errorf(codes.NotFound, e.Message)
case UserNotFound:
return nil, status.Errorf(codes.NotFound, e.message)
default: default:
} }
} }
@ -301,7 +296,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
peer, err := s.accountManager.GetPeer(peerKey.String()) peer, err := s.accountManager.GetPeer(peerKey.String())
if err != nil { if err != nil {
if errStatus, ok := status.FromError(err); ok && errStatus.Code() == codes.NotFound { if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound {
// peer doesn't exist -> check if setup key was provided // peer doesn't exist -> check if setup key was provided
if loginReq.GetJwtToken() == "" && loginReq.GetSetupKey() == "" { if loginReq.GetJwtToken() == "" && loginReq.GetSetupKey() == "" {
// absent setup key or jwt -> permission denied // absent setup key or jwt -> permission denied
@ -387,7 +382,6 @@ func ToResponseProto(configProto Protocol) proto.HostConfig_Protocol {
case TCP: case TCP:
return proto.HostConfig_TCP return proto.HostConfig_TCP
default: default:
// mbragin: todo something better?
panic(fmt.Errorf("unexpected config protocol type %v", configProto)) panic(fmt.Errorf("unexpected config protocol type %v", configProto))
} }
} }

View File

@ -2,10 +2,9 @@ package http
import ( import (
"encoding/json" "encoding/json"
"fmt"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"google.golang.org/grpc/codes" "github.com/netbirdio/netbird/management/server/http/util"
"google.golang.org/grpc/status" "github.com/netbirdio/netbird/management/server/status"
"net/http" "net/http"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
@ -33,7 +32,8 @@ func NewGroups(accountManager server.AccountManager, authAudience string) *Group
// GetAllGroupsHandler list for the account // GetAllGroupsHandler list for the account
func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) { func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
@ -45,52 +45,54 @@ func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) {
groups = append(groups, toGroupResponse(account, g)) groups = append(groups, toGroupResponse(account, g))
} }
writeJSONObject(w, groups) util.WriteJSONObject(w, groups)
} }
// UpdateGroupHandler handles update to a group identified by a given ID // UpdateGroupHandler handles update to a group identified by a given ID
func (h *Groups) UpdateGroupHandler(w http.ResponseWriter, r *http.Request) { func (h *Groups) UpdateGroupHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
groupID, ok := vars["id"] groupID, ok := vars["id"]
if !ok { if !ok {
http.Error(w, "group ID field is missing", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "group ID field is missing"), w)
return return
} }
if len(groupID) == 0 { if len(groupID) == 0 {
http.Error(w, "group ID can't be empty", http.StatusUnprocessableEntity) util.WriteError(status.Errorf(status.InvalidArgument, "group ID can't be empty"), w)
return return
} }
_, ok = account.Groups[groupID] _, ok = account.Groups[groupID]
if !ok { if !ok {
http.Error(w, fmt.Sprintf("couldn't find group with ID %s", groupID), http.StatusNotFound) util.WriteError(status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w)
return return
} }
allGroup, err := account.GetGroupAll() allGroup, err := account.GetGroupAll()
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
if allGroup.ID == groupID { if allGroup.ID == groupID {
http.Error(w, "updating group ALL is not allowed", http.StatusMethodNotAllowed) util.WriteError(status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w)
return return
} }
var req api.PutApiGroupsIdJSONRequestBody var req api.PutApiGroupsIdJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { err = json.NewDecoder(r.Body).Decode(&req)
http.Error(w, err.Error(), http.StatusBadRequest) if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
} }
if *req.Name == "" { if *req.Name == "" {
http.Error(w, "group name shouldn't be empty", http.StatusUnprocessableEntity) util.WriteError(status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w)
return return
} }
@ -102,53 +104,55 @@ func (h *Groups) UpdateGroupHandler(w http.ResponseWriter, r *http.Request) {
if err := h.accountManager.SaveGroup(account.Id, &group); err != nil { if err := h.accountManager.SaveGroup(account.Id, &group); err != nil {
log.Errorf("failed updating group %s under account %s %v", groupID, account.Id, err) log.Errorf("failed updating group %s under account %s %v", groupID, account.Id, err)
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
writeJSONObject(w, toGroupResponse(account, &group)) util.WriteJSONObject(w, toGroupResponse(account, &group))
} }
// PatchGroupHandler handles patch updates to a group identified by a given ID // PatchGroupHandler handles patch updates to a group identified by a given ID
func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) { func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
groupID := vars["id"] groupID := vars["id"]
if len(groupID) == 0 { if len(groupID) == 0 {
http.Error(w, "invalid group Id", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w)
return return
} }
_, ok := account.Groups[groupID] _, ok := account.Groups[groupID]
if !ok { if !ok {
http.Error(w, fmt.Sprintf("couldn't find group id %s", groupID), http.StatusNotFound) util.WriteError(status.Errorf(status.NotFound, "couldn't find group ID %s", groupID), w)
return return
} }
allGroup, err := account.GetGroupAll() allGroup, err := account.GetGroupAll()
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
if allGroup.ID == groupID { if allGroup.ID == groupID {
http.Error(w, "updating group ALL is not allowed", http.StatusMethodNotAllowed) util.WriteError(status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w)
return return
} }
var req api.PatchApiGroupsIdJSONRequestBody var req api.PatchApiGroupsIdJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { err = json.NewDecoder(r.Body).Decode(&req)
http.Error(w, err.Error(), http.StatusBadRequest) if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
} }
if len(req) == 0 { if len(req) == 0 {
http.Error(w, "no patch instruction received", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "no patch instruction received"), w)
return return
} }
@ -158,13 +162,13 @@ func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) {
switch patch.Path { switch patch.Path {
case api.GroupPatchOperationPathName: case api.GroupPatchOperationPathName:
if patch.Op != api.GroupPatchOperationOpReplace { if patch.Op != api.GroupPatchOperationOpReplace {
http.Error(w, fmt.Sprintf("Name field only accepts replace operation, got %s", patch.Op), util.WriteError(status.Errorf(status.InvalidArgument,
http.StatusBadRequest) "name field only accepts replace operation, got %s", patch.Op), w)
return return
} }
if len(patch.Value) == 0 || patch.Value[0] == "" { if len(patch.Value) == 0 || patch.Value[0] == "" {
http.Error(w, "Group name shouldn't be empty", http.StatusUnprocessableEntity) util.WriteError(status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w)
return return
} }
@ -193,53 +197,43 @@ func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) {
Values: peerKeys, Values: peerKeys,
}) })
default: default:
http.Error(w, "invalid operation, \"%s\", for Peers field", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument,
"invalid operation, \"%v\", for Peers field", patch.Op), w)
return return
} }
default: default:
http.Error(w, "invalid patch path", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid patch path"), w)
return return
} }
} }
group, err := h.accountManager.UpdateGroup(account.Id, groupID, operations) group, err := h.accountManager.UpdateGroup(account.Id, groupID, operations)
if err != nil { if err != nil {
errStatus, ok := status.FromError(err) util.WriteError(err, w)
if ok && errStatus.Code() == codes.Internal {
http.Error(w, errStatus.String(), http.StatusInternalServerError)
return
}
if ok && errStatus.Code() == codes.NotFound {
http.Error(w, errStatus.String(), http.StatusNotFound)
return
}
log.Errorf("failed updating group %s under account %s %v", groupID, account.Id, err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
writeJSONObject(w, toGroupResponse(account, group)) util.WriteJSONObject(w, toGroupResponse(account, group))
} }
// CreateGroupHandler handles group creation request // CreateGroupHandler handles group creation request
func (h *Groups) CreateGroupHandler(w http.ResponseWriter, r *http.Request) { func (h *Groups) CreateGroupHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
var req api.PostApiGroupsJSONRequestBody var req api.PostApiGroupsJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { err = json.NewDecoder(r.Body).Decode(&req)
http.Error(w, err.Error(), http.StatusBadRequest) if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
} }
if req.Name == "" { if req.Name == "" {
http.Error(w, "Group name shouldn't be empty", http.StatusUnprocessableEntity) util.WriteError(status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w)
return return
} }
@ -249,55 +243,57 @@ func (h *Groups) CreateGroupHandler(w http.ResponseWriter, r *http.Request) {
Peers: peerIPsToKeys(account, req.Peers), Peers: peerIPsToKeys(account, req.Peers),
} }
if err := h.accountManager.SaveGroup(account.Id, &group); err != nil { err = h.accountManager.SaveGroup(account.Id, &group)
log.Errorf("failed creating group \"%s\" under account %s %v", req.Name, account.Id, err) if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
writeJSONObject(w, toGroupResponse(account, &group)) util.WriteJSONObject(w, toGroupResponse(account, &group))
} }
// DeleteGroupHandler handles group deletion request // DeleteGroupHandler handles group deletion request
func (h *Groups) DeleteGroupHandler(w http.ResponseWriter, r *http.Request) { func (h *Groups) DeleteGroupHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
aID := account.Id aID := account.Id
groupID := mux.Vars(r)["id"] groupID := mux.Vars(r)["id"]
if len(groupID) == 0 { if len(groupID) == 0 {
http.Error(w, "invalid group ID", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w)
return return
} }
allGroup, err := account.GetGroupAll() allGroup, err := account.GetGroupAll()
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
if allGroup.ID == groupID { if allGroup.ID == groupID {
http.Error(w, "deleting group ALL is not allowed", http.StatusMethodNotAllowed) util.WriteError(status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed"), w)
return return
} }
if err := h.accountManager.DeleteGroup(aID, groupID); err != nil { err = h.accountManager.DeleteGroup(aID, groupID)
log.Errorf("failed delete group %s under account %s %v", groupID, aID, err) if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
writeJSONObject(w, "") util.WriteJSONObject(w, "")
} }
// GetGroupHandler returns a group // GetGroupHandler returns a group
func (h *Groups) GetGroupHandler(w http.ResponseWriter, r *http.Request) { func (h *Groups) GetGroupHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
@ -305,19 +301,22 @@ func (h *Groups) GetGroupHandler(w http.ResponseWriter, r *http.Request) {
case http.MethodGet: case http.MethodGet:
groupID := mux.Vars(r)["id"] groupID := mux.Vars(r)["id"]
if len(groupID) == 0 { if len(groupID) == 0 {
http.Error(w, "invalid group ID", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w)
return return
} }
group, err := h.accountManager.GetGroup(account.Id, groupID) group, err := h.accountManager.GetGroup(account.Id, groupID)
if err != nil { if err != nil {
http.Error(w, "group not found", http.StatusNotFound) util.WriteError(err, w)
return return
} }
writeJSONObject(w, toGroupResponse(account, group)) util.WriteJSONObject(w, toGroupResponse(account, group))
default: default:
http.Error(w, "", http.StatusNotFound) if err != nil {
util.WriteError(status.Errorf(status.NotFound, "HTTP method not found"), w)
return
}
} }
} }

View File

@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/status"
"io" "io"
"net" "net"
"net/http" "net/http"
@ -36,7 +37,7 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *Groups {
}, },
GetGroupFunc: func(_, groupID string) (*server.Group, error) { GetGroupFunc: func(_, groupID string) (*server.Group, error) {
if groupID != "idofthegroup" { if groupID != "idofthegroup" {
return nil, fmt.Errorf("not found") return nil, status.Errorf(status.NotFound, "not found")
} }
return &server.Group{ return &server.Group{
ID: "idofthegroup", ID: "idofthegroup",
@ -67,7 +68,7 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *Groups {
} }
return nil, fmt.Errorf("peer not found") return nil, fmt.Errorf("peer not found")
}, },
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return &server.Account{ return &server.Account{
Id: claims.AccountId, Id: claims.AccountId,
Domain: "hotmail.com", Domain: "hotmail.com",
@ -78,7 +79,7 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *Groups {
Groups: map[string]*server.Group{ Groups: map[string]*server.Group{
"id-existed": {ID: "id-existed", Peers: []string{"A", "B"}}, "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}},
"id-all": {ID: "id-all", Name: "All"}}, "id-all": {ID: "id-all", Name: "All"}},
}, nil }, user, nil
}, },
}, },
authAudience: "", authAudience: "",
@ -223,7 +224,7 @@ func TestWriteGroup(t *testing.T) {
requestPath: "/api/groups/id-all", requestPath: "/api/groups/id-all",
requestBody: bytes.NewBuffer( requestBody: bytes.NewBuffer(
[]byte(`{"Name":"super"}`)), []byte(`{"Name":"super"}`)),
expectedStatus: http.StatusMethodNotAllowed, expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false, expectedBody: false,
}, },
{ {
@ -244,7 +245,7 @@ func TestWriteGroup(t *testing.T) {
requestPath: "/api/groups/id-existed", requestPath: "/api/groups/id-existed",
requestBody: bytes.NewBuffer( requestBody: bytes.NewBuffer(
[]byte(`[{"op":"insert","path":"name","value":[""]}]`)), []byte(`[{"op":"insert","path":"name","value":[""]}]`)),
expectedStatus: http.StatusBadRequest, expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false, expectedBody: false,
}, },
{ {

View File

@ -1,7 +1,8 @@
package middleware package middleware
import ( import (
"fmt" "github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/status"
"net/http" "net/http"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
@ -33,14 +34,15 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
ok, err := a.isUserAdmin(jwtClaims) ok, err := a.isUserAdmin(jwtClaims)
if err != nil { if err != nil {
http.Error(w, fmt.Sprintf("error get user from JWT: %v", err), http.StatusUnauthorized) util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w)
return return
} }
if !ok { if !ok {
switch r.Method { switch r.Method {
case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut: case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut:
http.Error(w, "user is not admin", http.StatusForbidden) util.WriteError(status.Errorf(status.PermissionDenied, "only admin can perform this operation"), w)
return return
} }
} }

View File

@ -7,12 +7,12 @@ import (
"net/http" "net/http"
) )
//Jwks is a collection of JSONWebKeys obtained from Config.HttpServerConfig.AuthKeysLocation // Jwks is a collection of JSONWebKeys obtained from Config.HttpServerConfig.AuthKeysLocation
type Jwks struct { type Jwks struct {
Keys []JSONWebKeys `json:"keys"` Keys []JSONWebKeys `json:"keys"`
} }
//JSONWebKeys is a representation of a Jason Web Key // JSONWebKeys is a representation of a Jason Web Key
type JSONWebKeys struct { type JSONWebKeys struct {
Kty string `json:"kty"` Kty string `json:"kty"`
Kid string `json:"kid"` Kid string `json:"kid"`
@ -22,7 +22,7 @@ type JSONWebKeys struct {
X5c []string `json:"x5c"` X5c []string `json:"x5c"`
} }
//NewJwtMiddleware creates new middleware to verify the JWT token sent via Authorization header // NewJwtMiddleware creates new middleware to verify the JWT token sent via Authorization header
func NewJwtMiddleware(issuer string, audience string, keysLocation string) (*JWTMiddleware, error) { func NewJwtMiddleware(issuer string, audience string, keysLocation string) (*JWTMiddleware, error) {
keys, err := getPemKeys(keysLocation) keys, err := getPemKeys(keysLocation)
@ -66,7 +66,6 @@ func getPemKeys(keysLocation string) (*Jwks, error) {
var jwks = &Jwks{} var jwks = &Jwks{}
err = json.NewDecoder(resp.Body).Decode(jwks) err = json.NewDecoder(resp.Body).Decode(jwks)
if err != nil { if err != nil {
return jwks, err return jwks, err
} }

View File

@ -5,6 +5,8 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/golang-jwt/jwt" "github.com/golang-jwt/jwt"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/status"
"log" "log"
"net/http" "net/http"
"strings" "strings"
@ -57,7 +59,7 @@ type JWTMiddleware struct {
} }
func OnError(w http.ResponseWriter, r *http.Request, err string) { func OnError(w http.ResponseWriter, r *http.Request, err string) {
http.Error(w, err, http.StatusUnauthorized) util.WriteError(status.Errorf(status.Unauthorized, ""), w)
} }
// New constructs a new Secure instance with supplied options. // New constructs a new Secure instance with supplied options.

View File

@ -7,7 +7,9 @@ import (
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"net/http" "net/http"
) )
@ -30,7 +32,8 @@ func NewNameservers(accountManager server.AccountManager, authAudience string) *
// GetAllNameserversHandler returns the list of nameserver groups for the account // GetAllNameserversHandler returns the list of nameserver groups for the account
func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Request) { func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
@ -39,7 +42,7 @@ func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Re
nsGroups, err := h.accountManager.ListNameServerGroups(account.Id) nsGroups, err := h.accountManager.ListNameServerGroups(account.Id)
if err != nil { if err != nil {
toHTTPError(err, w) util.WriteError(err, w)
return return
} }
@ -48,64 +51,67 @@ func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Re
apiNameservers = append(apiNameservers, toNameserverGroupResponse(r)) apiNameservers = append(apiNameservers, toNameserverGroupResponse(r))
} }
writeJSONObject(w, apiNameservers) util.WriteJSONObject(w, apiNameservers)
} }
// CreateNameserverGroupHandler handles nameserver group creation request // CreateNameserverGroupHandler handles nameserver group creation request
func (h *Nameservers) CreateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) { func (h *Nameservers) CreateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
var req api.PostApiDnsNameserversJSONRequestBody var req api.PostApiDnsNameserversJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { err = json.NewDecoder(r.Body).Decode(&req)
http.Error(w, err.Error(), http.StatusBadRequest) if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
} }
nsList, err := toServerNSList(req.Nameservers) nsList, err := toServerNSList(req.Nameservers)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid NS servers format"), w)
return return
} }
nsGroup, err := h.accountManager.CreateNameServerGroup(account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled) nsGroup, err := h.accountManager.CreateNameServerGroup(account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled)
if err != nil { if err != nil {
toHTTPError(err, w) util.WriteError(err, w)
return return
} }
resp := toNameserverGroupResponse(nsGroup) resp := toNameserverGroupResponse(nsGroup)
writeJSONObject(w, &resp) util.WriteJSONObject(w, &resp)
} }
// UpdateNameserverGroupHandler handles update to a nameserver group identified by a given ID // UpdateNameserverGroupHandler handles update to a nameserver group identified by a given ID
func (h *Nameservers) UpdateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) { func (h *Nameservers) UpdateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
nsGroupID := mux.Vars(r)["id"] nsGroupID := mux.Vars(r)["id"]
if len(nsGroupID) == 0 { if len(nsGroupID) == 0 {
http.Error(w, "invalid nameserver group ID", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
return return
} }
var req api.PutApiDnsNameserversIdJSONRequestBody var req api.PutApiDnsNameserversIdJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { err = json.NewDecoder(r.Body).Decode(&req)
http.Error(w, err.Error(), http.StatusBadRequest) if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
} }
nsList, err := toServerNSList(req.Nameservers) nsList, err := toServerNSList(req.Nameservers)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid NS servers format"), w)
return return
} }
@ -122,41 +128,42 @@ func (h *Nameservers) UpdateNameserverGroupHandler(w http.ResponseWriter, r *htt
err = h.accountManager.SaveNameServerGroup(account.Id, updatedNSGroup) err = h.accountManager.SaveNameServerGroup(account.Id, updatedNSGroup)
if err != nil { if err != nil {
toHTTPError(err, w) util.WriteError(err, w)
return return
} }
resp := toNameserverGroupResponse(updatedNSGroup) resp := toNameserverGroupResponse(updatedNSGroup)
writeJSONObject(w, &resp) util.WriteJSONObject(w, &resp)
} }
// PatchNameserverGroupHandler handles patch updates to a nameserver group identified by a given ID // PatchNameserverGroupHandler handles patch updates to a nameserver group identified by a given ID
func (h *Nameservers) PatchNameserverGroupHandler(w http.ResponseWriter, r *http.Request) { func (h *Nameservers) PatchNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
nsGroupID := mux.Vars(r)["id"] nsGroupID := mux.Vars(r)["id"]
if len(nsGroupID) == 0 { if len(nsGroupID) == 0 {
http.Error(w, "invalid nameserver group ID", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
return return
} }
var req api.PatchApiDnsNameserversIdJSONRequestBody var req api.PatchApiDnsNameserversIdJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { err = json.NewDecoder(r.Body).Decode(&req)
http.Error(w, err.Error(), http.StatusBadRequest) if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
} }
var operations []server.NameServerGroupUpdateOperation var operations []server.NameServerGroupUpdateOperation
for _, patch := range req { for _, patch := range req {
if patch.Op != api.NameserverGroupPatchOperationOpReplace { if patch.Op != api.NameserverGroupPatchOperationOpReplace {
http.Error(w, fmt.Sprintf("nameserver groups only accepts replace operations, got %s", patch.Op), util.WriteError(status.Errorf(status.InvalidArgument,
http.StatusBadRequest) "nameserver groups only accepts replace operations, got %s", patch.Op), w)
return return
} }
switch patch.Path { switch patch.Path {
@ -196,49 +203,50 @@ func (h *Nameservers) PatchNameserverGroupHandler(w http.ResponseWriter, r *http
Values: patch.Value, Values: patch.Value,
}) })
default: default:
http.Error(w, "invalid patch path", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid patch path"), w)
return return
} }
} }
updatedNSGroup, err := h.accountManager.UpdateNameServerGroup(account.Id, nsGroupID, operations) updatedNSGroup, err := h.accountManager.UpdateNameServerGroup(account.Id, nsGroupID, operations)
if err != nil { if err != nil {
toHTTPError(err, w) util.WriteError(err, w)
return return
} }
resp := toNameserverGroupResponse(updatedNSGroup) resp := toNameserverGroupResponse(updatedNSGroup)
writeJSONObject(w, &resp) util.WriteJSONObject(w, &resp)
} }
// DeleteNameserverGroupHandler handles nameserver group deletion request // DeleteNameserverGroupHandler handles nameserver group deletion request
func (h *Nameservers) DeleteNameserverGroupHandler(w http.ResponseWriter, r *http.Request) { func (h *Nameservers) DeleteNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
nsGroupID := mux.Vars(r)["id"] nsGroupID := mux.Vars(r)["id"]
if len(nsGroupID) == 0 { if len(nsGroupID) == 0 {
http.Error(w, "invalid nameserver group ID", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
return return
} }
err = h.accountManager.DeleteNameServerGroup(account.Id, nsGroupID) err = h.accountManager.DeleteNameServerGroup(account.Id, nsGroupID)
if err != nil { if err != nil {
toHTTPError(err, w) util.WriteError(err, w)
return return
} }
writeJSONObject(w, "") util.WriteJSONObject(w, "")
} }
// GetNameserverGroupHandler handles a nameserver group Get request identified by ID // GetNameserverGroupHandler handles a nameserver group Get request identified by ID
func (h *Nameservers) GetNameserverGroupHandler(w http.ResponseWriter, r *http.Request) { func (h *Nameservers) GetNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
@ -247,19 +255,19 @@ func (h *Nameservers) GetNameserverGroupHandler(w http.ResponseWriter, r *http.R
nsGroupID := mux.Vars(r)["id"] nsGroupID := mux.Vars(r)["id"]
if len(nsGroupID) == 0 { if len(nsGroupID) == 0 {
http.Error(w, "invalid nameserver group ID", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
return return
} }
nsGroup, err := h.accountManager.GetNameServerGroup(account.Id, nsGroupID) nsGroup, err := h.accountManager.GetNameServerGroup(account.Id, nsGroupID)
if err != nil { if err != nil {
toHTTPError(err, w) util.WriteError(err, w)
return return
} }
resp := toNameserverGroupResponse(nsGroup) resp := toNameserverGroupResponse(nsGroup)
writeJSONObject(w, &resp) util.WriteJSONObject(w, &resp)
} }

View File

@ -5,9 +5,8 @@ import (
"encoding/json" "encoding/json"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/status"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -62,7 +61,7 @@ func initNameserversTestData() *Nameservers {
if nsGroupID == existingNSGroupID { if nsGroupID == existingNSGroupID {
return baseExistingNSGroup.Copy(), nil return baseExistingNSGroup.Copy(), nil
} }
return nil, status.Errorf(codes.NotFound, "nameserver group with ID %s not found", nsGroupID) return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID)
}, },
CreateNameServerGroupFunc: func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error) { CreateNameServerGroupFunc: func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error) {
return &nbdns.NameServerGroup{ return &nbdns.NameServerGroup{
@ -83,12 +82,12 @@ func initNameserversTestData() *Nameservers {
if nsGroupToSave.ID == existingNSGroupID { if nsGroupToSave.ID == existingNSGroupID {
return nil return nil
} }
return status.Errorf(codes.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID) return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID)
}, },
UpdateNameServerGroupFunc: func(accountID, nsGroupID string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) { UpdateNameServerGroupFunc: func(accountID, nsGroupID string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) {
nsGroupToUpdate := baseExistingNSGroup.Copy() nsGroupToUpdate := baseExistingNSGroup.Copy()
if nsGroupID != nsGroupToUpdate.ID { if nsGroupID != nsGroupToUpdate.ID {
return nil, status.Errorf(codes.NotFound, "nameserver group ID %s no longer exists", nsGroupID) return nil, status.Errorf(status.NotFound, "nameserver group ID %s no longer exists", nsGroupID)
} }
for _, operation := range operations { for _, operation := range operations {
switch operation.Type { switch operation.Type {
@ -110,8 +109,8 @@ func initNameserversTestData() *Nameservers {
} }
return nsGroupToUpdate, nil return nsGroupToUpdate, nil
}, },
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, error) { GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return testingNSAccount, nil return testingNSAccount, nil, nil
}, },
}, },
authAudience: "", authAudience: "",
@ -181,7 +180,7 @@ func TestNameserversHandlers(t *testing.T) {
requestPath: "/api/dns/nameservers", requestPath: "/api/dns/nameservers",
requestBody: bytes.NewBuffer( requestBody: bytes.NewBuffer(
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1000\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")), []byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1000\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")),
expectedStatus: http.StatusBadRequest, expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false, expectedBody: false,
}, },
{ {
@ -223,7 +222,7 @@ func TestNameserversHandlers(t *testing.T) {
requestPath: "/api/dns/nameservers/" + notFoundNSGroupID, requestPath: "/api/dns/nameservers/" + notFoundNSGroupID,
requestBody: bytes.NewBuffer( requestBody: bytes.NewBuffer(
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"100\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")), []byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"100\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")),
expectedStatus: http.StatusBadRequest, expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false, expectedBody: false,
}, },
{ {

View File

@ -6,8 +6,9 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/status"
"net/http" "net/http"
) )
@ -28,50 +29,47 @@ func NewPeers(accountManager server.AccountManager, authAudience string) *Peers
func (h *Peers) updatePeer(account *server.Account, peer *server.Peer, w http.ResponseWriter, r *http.Request) { func (h *Peers) updatePeer(account *server.Account, peer *server.Peer, w http.ResponseWriter, r *http.Request) {
req := &api.PutApiPeersIdJSONBody{} req := &api.PutApiPeersIdJSONBody{}
peerIp := peer.IP
err := json.NewDecoder(r.Body).Decode(&req) err := json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
} }
update := &server.Peer{Key: peer.Key, SSHEnabled: req.SshEnabled, Name: req.Name} update := &server.Peer{Key: peer.Key, SSHEnabled: req.SshEnabled, Name: req.Name}
peer, err = h.accountManager.UpdatePeer(account.Id, update) peer, err = h.accountManager.UpdatePeer(account.Id, update)
if err != nil { if err != nil {
log.Errorf("failed updating peer %s under account %s %v", peerIp, account.Id, err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
writeJSONObject(w, toPeerResponse(peer, account)) util.WriteJSONObject(w, toPeerResponse(peer, account))
} }
func (h *Peers) deletePeer(accountId string, peer *server.Peer, w http.ResponseWriter, r *http.Request) { func (h *Peers) deletePeer(accountId string, peer *server.Peer, w http.ResponseWriter, r *http.Request) {
_, err := h.accountManager.DeletePeer(accountId, peer.Key) _, err := h.accountManager.DeletePeer(accountId, peer.Key)
if err != nil { if err != nil {
log.Errorf("failed deleteing peer %s, %v", peer.IP, err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
writeJSONObject(w, "") util.WriteJSONObject(w, "")
} }
func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) { func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
peerId := vars["id"] //effectively peer IP address peerId := vars["id"] //effectively peer IP address
if len(peerId) == 0 { if len(peerId) == 0 {
http.Error(w, "invalid peer Id", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
return return
} }
peer, err := h.accountManager.GetPeerByIP(account.Id, peerId) peer, err := h.accountManager.GetPeerByIP(account.Id, peerId)
if err != nil { if err != nil {
http.Error(w, "peer not found", http.StatusNotFound) util.WriteError(err, w)
return return
} }
@ -83,11 +81,11 @@ func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
h.updatePeer(account, peer, w, r) h.updatePeer(account, peer, w, r)
return return
case http.MethodGet: case http.MethodGet:
writeJSONObject(w, toPeerResponse(peer, account)) util.WriteJSONObject(w, toPeerResponse(peer, account))
return return
default: default:
http.Error(w, "", http.StatusNotFound) util.WriteError(status.Errorf(status.NotFound, "unknown METHOD"), w)
} }
} }
@ -95,15 +93,16 @@ func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
func (h *Peers) GetPeers(w http.ResponseWriter, r *http.Request) { func (h *Peers) GetPeers(w http.ResponseWriter, r *http.Request) {
switch r.Method { switch r.Method {
case http.MethodGet: case http.MethodGet:
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
peers, err := h.accountManager.GetPeers(account.Id, user.Id) peers, err := h.accountManager.GetPeers(account.Id, user.Id)
if err != nil { if err != nil {
util.WriteError(err, w)
return return
} }
@ -111,10 +110,10 @@ func (h *Peers) GetPeers(w http.ResponseWriter, r *http.Request) {
for _, peer := range peers { for _, peer := range peers {
respBody = append(respBody, toPeerResponse(peer, account)) respBody = append(respBody, toPeerResponse(peer, account))
} }
writeJSONObject(w, respBody) util.WriteJSONObject(w, respBody)
return return
default: default:
http.Error(w, "", http.StatusNotFound) util.WriteError(status.Errorf(status.NotFound, "unknown METHOD"), w)
} }
} }

View File

@ -22,7 +22,8 @@ func initTestMetaData(peers ...*server.Peer) *Peers {
GetPeersFunc: func(accountID, userID string) ([]*server.Peer, error) { GetPeersFunc: func(accountID, userID string) ([]*server.Peer, error) {
return peers, nil return peers, nil
}, },
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
return &server.Account{ return &server.Account{
Id: claims.AccountId, Id: claims.AccountId,
Domain: "hotmail.com", Domain: "hotmail.com",
@ -30,9 +31,9 @@ func initTestMetaData(peers ...*server.Peer) *Peers {
"test_peer": peers[0], "test_peer": peers[0],
}, },
Users: map[string]*server.User{ Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"), "test_user": user,
}, },
}, nil }, user, nil
}, },
}, },
authAudience: "", authAudience: "",

View File

@ -2,15 +2,13 @@ package http
import ( import (
"encoding/json" "encoding/json"
"fmt"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"net/http" "net/http"
"unicode/utf8" "unicode/utf8"
) )
@ -33,25 +31,16 @@ func NewRoutes(accountManager server.AccountManager, authAudience string) *Route
// GetAllRoutesHandler returns the list of routes for the account // GetAllRoutesHandler returns the list of routes for the account
func (h *Routes) GetAllRoutesHandler(w http.ResponseWriter, r *http.Request) { func (h *Routes) GetAllRoutesHandler(w http.ResponseWriter, r *http.Request) {
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
routes, err := h.accountManager.ListRoutes(account.Id, user.Id) routes, err := h.accountManager.ListRoutes(account.Id, user.Id)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
if e, ok := server.FromError(err); ok {
switch e.Type() {
case server.PermissionDenied:
http.Error(w, e.Error(), http.StatusForbidden)
return
default:
}
}
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
apiRoutes := make([]*api.Route, 0) apiRoutes := make([]*api.Route, 0)
@ -59,20 +48,22 @@ func (h *Routes) GetAllRoutesHandler(w http.ResponseWriter, r *http.Request) {
apiRoutes = append(apiRoutes, toRouteResponse(account, r)) apiRoutes = append(apiRoutes, toRouteResponse(account, r))
} }
writeJSONObject(w, apiRoutes) util.WriteJSONObject(w, apiRoutes)
} }
// CreateRouteHandler handles route creation request // CreateRouteHandler handles route creation request
func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) { func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
var req api.PostApiRoutesJSONRequestBody var req api.PostApiRoutesJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { err = json.NewDecoder(r.Body).Decode(&req)
http.Error(w, err.Error(), http.StatusBadRequest) if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
} }
@ -80,8 +71,7 @@ func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) {
if req.Peer != "" { if req.Peer != "" {
peer, err := h.accountManager.GetPeerByIP(account.Id, req.Peer) peer, err := h.accountManager.GetPeerByIP(account.Id, req.Peer)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusUnprocessableEntity)
return return
} }
peerKey = peer.Key peerKey = peer.Key
@ -89,57 +79,60 @@ func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) {
_, newPrefix, err := route.ParseNetwork(req.Network) _, newPrefix, err := route.ParseNetwork(req.Network)
if err != nil { if err != nil {
http.Error(w, fmt.Sprintf("couldn't parse update prefix %s", req.Network), http.StatusBadRequest) util.WriteError(err, w)
return return
} }
if utf8.RuneCountInString(req.NetworkId) > route.MaxNetIDChar || req.NetworkId == "" { if utf8.RuneCountInString(req.NetworkId) > route.MaxNetIDChar || req.NetworkId == "" {
http.Error(w, fmt.Sprintf("identifier should be between 1 and %d", route.MaxNetIDChar), http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d",
route.MaxNetIDChar), w)
return return
} }
newRoute, err := h.accountManager.CreateRoute(account.Id, newPrefix.String(), peerKey, req.Description, req.NetworkId, req.Masquerade, req.Metric, req.Enabled) newRoute, err := h.accountManager.CreateRoute(account.Id, newPrefix.String(), peerKey, req.Description, req.NetworkId, req.Masquerade, req.Metric, req.Enabled)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
resp := toRouteResponse(account, newRoute) resp := toRouteResponse(account, newRoute)
writeJSONObject(w, &resp) util.WriteJSONObject(w, &resp)
} }
// UpdateRouteHandler handles update to a route identified by a given ID // UpdateRouteHandler handles update to a route identified by a given ID
func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) { func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
routeID := vars["id"] routeID := vars["id"]
if len(routeID) == 0 { if len(routeID) == 0 {
http.Error(w, "invalid route Id", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
return return
} }
_, err = h.accountManager.GetRoute(account.Id, routeID, user.Id) _, err = h.accountManager.GetRoute(account.Id, routeID, user.Id)
if err != nil { if err != nil {
http.Error(w, fmt.Sprintf("couldn't find route for ID %s", routeID), http.StatusNotFound) util.WriteError(err, w)
return return
} }
var req api.PutApiRoutesIdJSONRequestBody var req api.PutApiRoutesIdJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { err = json.NewDecoder(r.Body).Decode(&req)
http.Error(w, err.Error(), http.StatusBadRequest) if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
} }
prefixType, newPrefix, err := route.ParseNetwork(req.Network) prefixType, newPrefix, err := route.ParseNetwork(req.Network)
if err != nil { if err != nil {
http.Error(w, fmt.Sprintf("couldn't parse update prefix %s for route ID %s", req.Network, routeID), http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "couldn't parse update prefix %s for route ID %s",
req.Network, routeID), w)
return return
} }
@ -147,15 +140,15 @@ func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
if req.Peer != "" { if req.Peer != "" {
peer, err := h.accountManager.GetPeerByIP(account.Id, req.Peer) peer, err := h.accountManager.GetPeerByIP(account.Id, req.Peer)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusUnprocessableEntity)
return return
} }
peerKey = peer.Key peerKey = peer.Key
} }
if utf8.RuneCountInString(req.NetworkId) > route.MaxNetIDChar || req.NetworkId == "" { if utf8.RuneCountInString(req.NetworkId) > route.MaxNetIDChar || req.NetworkId == "" {
http.Error(w, fmt.Sprintf("identifier should be between 1 and %d", route.MaxNetIDChar), http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument,
"identifier should be between 1 and %d", route.MaxNetIDChar), w)
return return
} }
@ -173,46 +166,46 @@ func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
err = h.accountManager.SaveRoute(account.Id, newRoute) err = h.accountManager.SaveRoute(account.Id, newRoute)
if err != nil { if err != nil {
log.Errorf("failed updating route \"%s\" under account %s %v", routeID, account.Id, err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
resp := toRouteResponse(account, newRoute) resp := toRouteResponse(account, newRoute)
writeJSONObject(w, &resp) util.WriteJSONObject(w, &resp)
} }
// PatchRouteHandler handles patch updates to a route identified by a given ID // PatchRouteHandler handles patch updates to a route identified by a given ID
func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) { func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
routeID := vars["id"] routeID := vars["id"]
if len(routeID) == 0 { if len(routeID) == 0 {
http.Error(w, "invalid route ID", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
return return
} }
_, err = h.accountManager.GetRoute(account.Id, routeID, user.Id) _, err = h.accountManager.GetRoute(account.Id, routeID, user.Id)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
http.Error(w, fmt.Sprintf("couldn't find route ID %s", routeID), http.StatusNotFound)
return return
} }
var req api.PatchApiRoutesIdJSONRequestBody var req api.PatchApiRoutesIdJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { err = json.NewDecoder(r.Body).Decode(&req)
http.Error(w, err.Error(), http.StatusBadRequest) if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
} }
if len(req) == 0 { if len(req) == 0 {
http.Error(w, "no patch instruction received", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "no patch instruction received"), w)
return return
} }
@ -222,8 +215,8 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
switch patch.Path { switch patch.Path {
case api.RoutePatchOperationPathNetwork: case api.RoutePatchOperationPathNetwork:
if patch.Op != api.RoutePatchOperationOpReplace { if patch.Op != api.RoutePatchOperationOpReplace {
http.Error(w, fmt.Sprintf("Network field only accepts replace operation, got %s", patch.Op), util.WriteError(status.Errorf(status.InvalidArgument,
http.StatusBadRequest) "network field only accepts replace operation, got %s", patch.Op), w)
return return
} }
operations = append(operations, server.RouteUpdateOperation{ operations = append(operations, server.RouteUpdateOperation{
@ -232,8 +225,8 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
}) })
case api.RoutePatchOperationPathDescription: case api.RoutePatchOperationPathDescription:
if patch.Op != api.RoutePatchOperationOpReplace { if patch.Op != api.RoutePatchOperationOpReplace {
http.Error(w, fmt.Sprintf("Description field only accepts replace operation, got %s", patch.Op), util.WriteError(status.Errorf(status.InvalidArgument,
http.StatusBadRequest) "description field only accepts replace operation, got %s", patch.Op), w)
return return
} }
operations = append(operations, server.RouteUpdateOperation{ operations = append(operations, server.RouteUpdateOperation{
@ -242,8 +235,8 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
}) })
case api.RoutePatchOperationPathNetworkId: case api.RoutePatchOperationPathNetworkId:
if patch.Op != api.RoutePatchOperationOpReplace { if patch.Op != api.RoutePatchOperationOpReplace {
http.Error(w, fmt.Sprintf("Network Identifier field only accepts replace operation, got %s", patch.Op), util.WriteError(status.Errorf(status.InvalidArgument,
http.StatusBadRequest) "network Identifier field only accepts replace operation, got %s", patch.Op), w)
return return
} }
operations = append(operations, server.RouteUpdateOperation{ operations = append(operations, server.RouteUpdateOperation{
@ -252,21 +245,20 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
}) })
case api.RoutePatchOperationPathPeer: case api.RoutePatchOperationPathPeer:
if patch.Op != api.RoutePatchOperationOpReplace { if patch.Op != api.RoutePatchOperationOpReplace {
http.Error(w, fmt.Sprintf("Peer field only accepts replace operation, got %s", patch.Op), util.WriteError(status.Errorf(status.InvalidArgument,
http.StatusBadRequest) "peer field only accepts replace operation, got %s", patch.Op), w)
return return
} }
if len(patch.Value) > 1 { if len(patch.Value) > 1 {
http.Error(w, fmt.Sprintf("Value field only accepts 1 value, got %d", len(patch.Value)), util.WriteError(status.Errorf(status.InvalidArgument,
http.StatusBadRequest) "value field only accepts 1 value, got %d", len(patch.Value)), w)
return return
} }
peerValue := patch.Value peerValue := patch.Value
if patch.Value[0] != "" { if patch.Value[0] != "" {
peer, err := h.accountManager.GetPeerByIP(account.Id, patch.Value[0]) peer, err := h.accountManager.GetPeerByIP(account.Id, patch.Value[0])
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusUnprocessableEntity)
return return
} }
peerValue = []string{peer.Key} peerValue = []string{peer.Key}
@ -277,8 +269,9 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
}) })
case api.RoutePatchOperationPathMetric: case api.RoutePatchOperationPathMetric:
if patch.Op != api.RoutePatchOperationOpReplace { if patch.Op != api.RoutePatchOperationOpReplace {
http.Error(w, fmt.Sprintf("Metric field only accepts replace operation, got %s", patch.Op), util.WriteError(status.Errorf(status.InvalidArgument,
http.StatusBadRequest) "metric field only accepts replace operation, got %s", patch.Op), w)
return return
} }
operations = append(operations, server.RouteUpdateOperation{ operations = append(operations, server.RouteUpdateOperation{
@ -287,8 +280,8 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
}) })
case api.RoutePatchOperationPathMasquerade: case api.RoutePatchOperationPathMasquerade:
if patch.Op != api.RoutePatchOperationOpReplace { if patch.Op != api.RoutePatchOperationOpReplace {
http.Error(w, fmt.Sprintf("Masquerade field only accepts replace operation, got %s", patch.Op), util.WriteError(status.Errorf(status.InvalidArgument,
http.StatusBadRequest) "masquerade field only accepts replace operation, got %s", patch.Op), w)
return return
} }
operations = append(operations, server.RouteUpdateOperation{ operations = append(operations, server.RouteUpdateOperation{
@ -297,8 +290,8 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
}) })
case api.RoutePatchOperationPathEnabled: case api.RoutePatchOperationPathEnabled:
if patch.Op != api.RoutePatchOperationOpReplace { if patch.Op != api.RoutePatchOperationOpReplace {
http.Error(w, fmt.Sprintf("Enabled field only accepts replace operation, got %s", patch.Op), util.WriteError(status.Errorf(status.InvalidArgument,
http.StatusBadRequest) "enabled field only accepts replace operation, got %s", patch.Op), w)
return return
} }
operations = append(operations, server.RouteUpdateOperation{ operations = append(operations, server.RouteUpdateOperation{
@ -306,90 +299,68 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
Values: patch.Value, Values: patch.Value,
}) })
default: default:
http.Error(w, "invalid patch path", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid patch path"), w)
return return
} }
} }
route, err := h.accountManager.UpdateRoute(account.Id, routeID, operations) route, err := h.accountManager.UpdateRoute(account.Id, routeID, operations)
if err != nil { if err != nil {
errStatus, ok := status.FromError(err) util.WriteError(err, w)
if ok && errStatus.Code() == codes.Internal {
http.Error(w, errStatus.String(), http.StatusInternalServerError)
return
}
if ok && errStatus.Code() == codes.NotFound {
http.Error(w, errStatus.String(), http.StatusNotFound)
return
}
if ok && errStatus.Code() == codes.InvalidArgument {
http.Error(w, errStatus.String(), http.StatusBadRequest)
return
}
log.Errorf("failed updating route %s under account %s %v", routeID, account.Id, err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
resp := toRouteResponse(account, route) resp := toRouteResponse(account, route)
writeJSONObject(w, &resp) util.WriteJSONObject(w, &resp)
} }
// DeleteRouteHandler handles route deletion request // DeleteRouteHandler handles route deletion request
func (h *Routes) DeleteRouteHandler(w http.ResponseWriter, r *http.Request) { func (h *Routes) DeleteRouteHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
routeID := mux.Vars(r)["id"] routeID := mux.Vars(r)["id"]
if len(routeID) == 0 { if len(routeID) == 0 {
http.Error(w, "invalid route ID", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
return return
} }
err = h.accountManager.DeleteRoute(account.Id, routeID) err = h.accountManager.DeleteRoute(account.Id, routeID)
if err != nil { if err != nil {
errStatus, ok := status.FromError(err) util.WriteError(err, w)
if ok && errStatus.Code() == codes.NotFound {
http.Error(w, fmt.Sprintf("route %s not found under account %s", routeID, account.Id), http.StatusNotFound)
return
}
log.Errorf("failed delete route %s under account %s %v", routeID, account.Id, err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
writeJSONObject(w, "") util.WriteJSONObject(w, "")
} }
// GetRouteHandler handles a route Get request identified by ID // GetRouteHandler handles a route Get request identified by ID
func (h *Routes) GetRouteHandler(w http.ResponseWriter, r *http.Request) { func (h *Routes) GetRouteHandler(w http.ResponseWriter, r *http.Request) {
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
routeID := mux.Vars(r)["id"] routeID := mux.Vars(r)["id"]
if len(routeID) == 0 { if len(routeID) == 0 {
http.Error(w, "invalid route ID", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
return return
} }
foundRoute, err := h.accountManager.GetRoute(account.Id, routeID, user.Id) foundRoute, err := h.accountManager.GetRoute(account.Id, routeID, user.Id)
if err != nil { if err != nil {
http.Error(w, "route not found", http.StatusNotFound) util.WriteError(status.Errorf(status.NotFound, "route not found"), w)
return return
} }
writeJSONObject(w, toRouteResponse(account, foundRoute)) util.WriteJSONObject(w, toRouteResponse(account, foundRoute))
} }
func toRouteResponse(account *server.Account, serverRoute *route.Route) *api.Route { func toRouteResponse(account *server.Account, serverRoute *route.Route) *api.Route {

View File

@ -5,9 +5,8 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -63,7 +62,7 @@ func initRoutesTestData() *Routes {
if routeID == existingRouteID { if routeID == existingRouteID {
return baseExistingRoute, nil return baseExistingRoute, nil
} }
return nil, status.Errorf(codes.NotFound, "route with ID %s not found", routeID) return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
}, },
CreateRouteFunc: func(accountID string, network, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error) { CreateRouteFunc: func(accountID string, network, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error) {
networkType, p, _ := route.ParseNetwork(network) networkType, p, _ := route.ParseNetwork(network)
@ -83,13 +82,13 @@ func initRoutesTestData() *Routes {
}, },
DeleteRouteFunc: func(_ string, peerIP string) error { DeleteRouteFunc: func(_ string, peerIP string) error {
if peerIP != existingRouteID { if peerIP != existingRouteID {
return status.Errorf(codes.NotFound, "Peer with ID %s not found", peerIP) return status.Errorf(status.NotFound, "Peer with ID %s not found", peerIP)
} }
return nil return nil
}, },
GetPeerByIPFunc: func(_ string, peerIP string) (*server.Peer, error) { GetPeerByIPFunc: func(_ string, peerIP string) (*server.Peer, error) {
if peerIP != existingPeerID { if peerIP != existingPeerID {
return nil, status.Errorf(codes.NotFound, "Peer with ID %s not found", peerIP) return nil, status.Errorf(status.NotFound, "Peer with ID %s not found", peerIP)
} }
return &server.Peer{ return &server.Peer{
Key: existingPeerKey, Key: existingPeerKey,
@ -99,7 +98,7 @@ func initRoutesTestData() *Routes {
UpdateRouteFunc: func(_ string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error) { UpdateRouteFunc: func(_ string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error) {
routeToUpdate := baseExistingRoute routeToUpdate := baseExistingRoute
if routeID != routeToUpdate.ID { if routeID != routeToUpdate.ID {
return nil, status.Errorf(codes.NotFound, "route %s no longer exists", routeID) return nil, status.Errorf(status.NotFound, "route %s no longer exists", routeID)
} }
for _, operation := range operations { for _, operation := range operations {
switch operation.Type { switch operation.Type {
@ -123,8 +122,8 @@ func initRoutesTestData() *Routes {
} }
return routeToUpdate, nil return routeToUpdate, nil
}, },
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, error) { GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return testingAccount, nil return testingAccount, testingAccount.Users["test_user"], nil
}, },
}, },
authAudience: "", authAudience: "",
@ -201,15 +200,15 @@ func TestRoutesHandlers(t *testing.T) {
requestType: http.MethodPost, requestType: http.MethodPost,
requestPath: "/api/routes", requestPath: "/api/routes",
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\"}", notFoundPeerID)), requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\"}", notFoundPeerID)),
expectedStatus: http.StatusUnprocessableEntity, expectedStatus: http.StatusNotFound,
expectedBody: false, expectedBody: false,
}, },
{ {
name: "POST Not Invalid Network Identifier", name: "POST Invalid Network Identifier",
requestType: http.MethodPost, requestType: http.MethodPost,
requestPath: "/api/routes", requestPath: "/api/routes",
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"12345678901234567890qwertyuiopqwertyuiop1\",\"Peer\":\"%s\"}", existingPeerID)), requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"12345678901234567890qwertyuiopqwertyuiop1\",\"Peer\":\"%s\"}", existingPeerID)),
expectedStatus: http.StatusBadRequest, expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false, expectedBody: false,
}, },
{ {
@ -217,7 +216,7 @@ func TestRoutesHandlers(t *testing.T) {
requestType: http.MethodPost, requestType: http.MethodPost,
requestPath: "/api/routes", requestPath: "/api/routes",
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/34\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\"}", existingPeerID)), requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/34\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\"}", existingPeerID)),
expectedStatus: http.StatusBadRequest, expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false, expectedBody: false,
}, },
{ {
@ -251,7 +250,7 @@ func TestRoutesHandlers(t *testing.T) {
requestType: http.MethodPut, requestType: http.MethodPut,
requestPath: "/api/routes/" + existingRouteID, requestPath: "/api/routes/" + existingRouteID,
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\"}", notFoundPeerID)), requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\"}", notFoundPeerID)),
expectedStatus: http.StatusUnprocessableEntity, expectedStatus: http.StatusNotFound,
expectedBody: false, expectedBody: false,
}, },
{ {
@ -259,7 +258,7 @@ func TestRoutesHandlers(t *testing.T) {
requestType: http.MethodPut, requestType: http.MethodPut,
requestPath: "/api/routes/" + existingRouteID, requestPath: "/api/routes/" + existingRouteID,
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"12345678901234567890qwertyuiopqwertyuiop1\",\"Peer\":\"%s\"}", existingPeerID)), requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"12345678901234567890qwertyuiopqwertyuiop1\",\"Peer\":\"%s\"}", existingPeerID)),
expectedStatus: http.StatusBadRequest, expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false, expectedBody: false,
}, },
{ {
@ -267,7 +266,7 @@ func TestRoutesHandlers(t *testing.T) {
requestType: http.MethodPut, requestType: http.MethodPut,
requestPath: "/api/routes/" + existingRouteID, requestPath: "/api/routes/" + existingRouteID,
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/34\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\"}", existingPeerID)), requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/34\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\"}", existingPeerID)),
expectedStatus: http.StatusBadRequest, expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false, expectedBody: false,
}, },
{ {
@ -312,7 +311,7 @@ func TestRoutesHandlers(t *testing.T) {
requestType: http.MethodPatch, requestType: http.MethodPatch,
requestPath: "/api/routes/" + existingRouteID, requestPath: "/api/routes/" + existingRouteID,
requestBody: bytes.NewBufferString(fmt.Sprintf("[{\"op\":\"replace\",\"path\":\"peer\",\"value\":[\"%s\"]}]", notFoundPeerID)), requestBody: bytes.NewBufferString(fmt.Sprintf("[{\"op\":\"replace\",\"path\":\"peer\",\"value\":[\"%s\"]}]", notFoundPeerID)),
expectedStatus: http.StatusUnprocessableEntity, expectedStatus: http.StatusNotFound,
expectedBody: false, expectedBody: false,
}, },
{ {

View File

@ -2,15 +2,13 @@ package http
import ( import (
"encoding/json" "encoding/json"
"fmt"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
"github.com/rs/xid" "github.com/rs/xid"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"net/http" "net/http"
) )
@ -31,25 +29,16 @@ func NewRules(accountManager server.AccountManager, authAudience string) *Rules
// GetAllRulesHandler list for the account // GetAllRulesHandler list for the account
func (h *Rules) GetAllRulesHandler(w http.ResponseWriter, r *http.Request) { func (h *Rules) GetAllRulesHandler(w http.ResponseWriter, r *http.Request) {
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
accountRules, err := h.accountManager.ListRules(account.Id, user.Id) accountRules, err := h.accountManager.ListRules(account.Id, user.Id)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
if e, ok := server.FromError(err); ok {
switch e.Type() {
case server.PermissionDenied:
http.Error(w, e.Error(), http.StatusForbidden)
return
default:
}
}
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
rules := []*api.Rule{} rules := []*api.Rule{}
@ -57,38 +46,39 @@ func (h *Rules) GetAllRulesHandler(w http.ResponseWriter, r *http.Request) {
rules = append(rules, toRuleResponse(account, r)) rules = append(rules, toRuleResponse(account, r))
} }
writeJSONObject(w, rules) util.WriteJSONObject(w, rules)
} }
// UpdateRuleHandler handles update to a rule identified by a given ID // UpdateRuleHandler handles update to a rule identified by a given ID
func (h *Rules) UpdateRuleHandler(w http.ResponseWriter, r *http.Request) { func (h *Rules) UpdateRuleHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
ruleID := vars["id"] ruleID := vars["id"]
if len(ruleID) == 0 { if len(ruleID) == 0 {
http.Error(w, "invalid rule Id", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid rule ID"), w)
return return
} }
_, ok := account.Rules[ruleID] _, ok := account.Rules[ruleID]
if !ok { if !ok {
http.Error(w, fmt.Sprintf("couldn't find rule id %s", ruleID), http.StatusNotFound) util.WriteError(status.Errorf(status.NotFound, "couldn't find rule id %s", ruleID), w)
return return
} }
var req api.PutApiRulesIdJSONRequestBody var req api.PutApiRulesIdJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { err = json.NewDecoder(r.Body).Decode(&req)
http.Error(w, err.Error(), http.StatusBadRequest) if err != nil {
return util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
} }
if req.Name == "" { if req.Name == "" {
http.Error(w, "Rule name shouldn't be empty", http.StatusUnprocessableEntity) util.WriteError(status.Errorf(status.InvalidArgument, "rule name shouldn't be empty"), w)
return return
} }
@ -115,50 +105,52 @@ func (h *Rules) UpdateRuleHandler(w http.ResponseWriter, r *http.Request) {
case server.TrafficFlowBidirectString: case server.TrafficFlowBidirectString:
rule.Flow = server.TrafficFlowBidirect rule.Flow = server.TrafficFlowBidirect
default: default:
http.Error(w, "unknown flow type", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "unknown flow type"), w)
return return
} }
if err := h.accountManager.SaveRule(account.Id, &rule); err != nil { err = h.accountManager.SaveRule(account.Id, &rule)
log.Errorf("failed updating rule \"%s\" under account %s %v", ruleID, account.Id, err) if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
resp := toRuleResponse(account, &rule) resp := toRuleResponse(account, &rule)
writeJSONObject(w, &resp) util.WriteJSONObject(w, &resp)
} }
// PatchRuleHandler handles patch updates to a rule identified by a given ID // PatchRuleHandler handles patch updates to a rule identified by a given ID
func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) { func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
ruleID := vars["id"] ruleID := vars["id"]
if len(ruleID) == 0 { if len(ruleID) == 0 {
http.Error(w, "invalid rule Id", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid rule ID"), w)
return return
} }
_, ok := account.Rules[ruleID] _, ok := account.Rules[ruleID]
if !ok { if !ok {
http.Error(w, fmt.Sprintf("couldn't find rule id %s", ruleID), http.StatusNotFound) util.WriteError(status.Errorf(status.NotFound, "couldn't find rule ID %s", ruleID), w)
return return
} }
var req api.PatchApiRulesIdJSONRequestBody var req api.PatchApiRulesIdJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { err = json.NewDecoder(r.Body).Decode(&req)
http.Error(w, err.Error(), http.StatusBadRequest) if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
} }
if len(req) == 0 { if len(req) == 0 {
http.Error(w, "no patch instruction received", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "no patch instruction received"), w)
return return
} }
@ -168,12 +160,12 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
switch patch.Path { switch patch.Path {
case api.RulePatchOperationPathName: case api.RulePatchOperationPathName:
if patch.Op != api.RulePatchOperationOpReplace { if patch.Op != api.RulePatchOperationOpReplace {
http.Error(w, fmt.Sprintf("Name field only accepts replace operation, got %s", patch.Op), util.WriteError(status.Errorf(status.InvalidArgument,
http.StatusBadRequest) "name field only accepts replace operation, got %s", patch.Op), w)
return return
} }
if len(patch.Value) == 0 || patch.Value[0] == "" { if len(patch.Value) == 0 || patch.Value[0] == "" {
http.Error(w, "Rule name shouldn't be empty", http.StatusUnprocessableEntity) util.WriteError(status.Errorf(status.InvalidArgument, "rule name shouldn't be empty"), w)
return return
} }
operations = append(operations, server.RuleUpdateOperation{ operations = append(operations, server.RuleUpdateOperation{
@ -182,8 +174,8 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
}) })
case api.RulePatchOperationPathDescription: case api.RulePatchOperationPathDescription:
if patch.Op != api.RulePatchOperationOpReplace { if patch.Op != api.RulePatchOperationOpReplace {
http.Error(w, fmt.Sprintf("Description field only accepts replace operation, got %s", patch.Op), util.WriteError(status.Errorf(status.InvalidArgument,
http.StatusBadRequest) "description field only accepts replace operation, got %s", patch.Op), w)
return return
} }
operations = append(operations, server.RuleUpdateOperation{ operations = append(operations, server.RuleUpdateOperation{
@ -192,8 +184,8 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
}) })
case api.RulePatchOperationPathFlow: case api.RulePatchOperationPathFlow:
if patch.Op != api.RulePatchOperationOpReplace { if patch.Op != api.RulePatchOperationOpReplace {
http.Error(w, fmt.Sprintf("Flow field only accepts replace operation, got %s", patch.Op), util.WriteError(status.Errorf(status.InvalidArgument,
http.StatusBadRequest) "flow field only accepts replace operation, got %s", patch.Op), w)
return return
} }
operations = append(operations, server.RuleUpdateOperation{ operations = append(operations, server.RuleUpdateOperation{
@ -202,8 +194,8 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
}) })
case api.RulePatchOperationPathDisabled: case api.RulePatchOperationPathDisabled:
if patch.Op != api.RulePatchOperationOpReplace { if patch.Op != api.RulePatchOperationOpReplace {
http.Error(w, fmt.Sprintf("Disabled field only accepts replace operation, got %s", patch.Op), util.WriteError(status.Errorf(status.InvalidArgument,
http.StatusBadRequest) "disabled field only accepts replace operation, got %s", patch.Op), w)
return return
} }
operations = append(operations, server.RuleUpdateOperation{ operations = append(operations, server.RuleUpdateOperation{
@ -228,7 +220,8 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
Values: patch.Value, Values: patch.Value,
}) })
default: default:
http.Error(w, "invalid operation, \"%s\", for Source field", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument,
"invalid operation \"%s\" on Source field", patch.Op), w)
return return
} }
case api.RulePatchOperationPathDestinations: case api.RulePatchOperationPathDestinations:
@ -249,11 +242,12 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
Values: patch.Value, Values: patch.Value,
}) })
default: default:
http.Error(w, "invalid operation, \"%s\", for Destination field", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument,
"invalid operation \"%s\" on Destination field", patch.Op), w)
return return
} }
default: default:
http.Error(w, "invalid patch path", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid patch path"), w)
return return
} }
} }
@ -261,48 +255,33 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
rule, err := h.accountManager.UpdateRule(account.Id, ruleID, operations) rule, err := h.accountManager.UpdateRule(account.Id, ruleID, operations)
if err != nil { if err != nil {
errStatus, ok := status.FromError(err) util.WriteError(err, w)
if ok && errStatus.Code() == codes.Internal {
http.Error(w, errStatus.String(), http.StatusInternalServerError)
return
}
if ok && errStatus.Code() == codes.NotFound {
http.Error(w, errStatus.String(), http.StatusNotFound)
return
}
if ok && errStatus.Code() == codes.InvalidArgument {
http.Error(w, errStatus.String(), http.StatusBadRequest)
return
}
log.Errorf("failed updating rule %s under account %s %v", ruleID, account.Id, err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
resp := toRuleResponse(account, rule) resp := toRuleResponse(account, rule)
writeJSONObject(w, &resp) util.WriteJSONObject(w, &resp)
} }
// CreateRuleHandler handles rule creation request // CreateRuleHandler handles rule creation request
func (h *Rules) CreateRuleHandler(w http.ResponseWriter, r *http.Request) { func (h *Rules) CreateRuleHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
var req api.PostApiRulesJSONRequestBody var req api.PostApiRulesJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { err = json.NewDecoder(r.Body).Decode(&req)
http.Error(w, err.Error(), http.StatusBadRequest) if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
} }
if req.Name == "" { if req.Name == "" {
http.Error(w, "Rule name shouldn't be empty", http.StatusUnprocessableEntity) util.WriteError(status.Errorf(status.InvalidArgument, "rule name shouldn't be empty"), w)
return return
} }
@ -329,50 +308,52 @@ func (h *Rules) CreateRuleHandler(w http.ResponseWriter, r *http.Request) {
case server.TrafficFlowBidirectString: case server.TrafficFlowBidirectString:
rule.Flow = server.TrafficFlowBidirect rule.Flow = server.TrafficFlowBidirect
default: default:
http.Error(w, "unknown flow type", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "unknown flow type"), w)
return return
} }
if err := h.accountManager.SaveRule(account.Id, &rule); err != nil { err = h.accountManager.SaveRule(account.Id, &rule)
log.Errorf("failed creating rule \"%s\" under account %s %v", req.Name, account.Id, err) if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
resp := toRuleResponse(account, &rule) resp := toRuleResponse(account, &rule)
writeJSONObject(w, &resp) util.WriteJSONObject(w, &resp)
} }
// DeleteRuleHandler handles rule deletion request // DeleteRuleHandler handles rule deletion request
func (h *Rules) DeleteRuleHandler(w http.ResponseWriter, r *http.Request) { func (h *Rules) DeleteRuleHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
aID := account.Id aID := account.Id
rID := mux.Vars(r)["id"] rID := mux.Vars(r)["id"]
if len(rID) == 0 { if len(rID) == 0 {
http.Error(w, "invalid rule ID", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid rule ID"), w)
return return
} }
if err := h.accountManager.DeleteRule(aID, rID); err != nil { err = h.accountManager.DeleteRule(aID, rID)
log.Errorf("failed delete rule %s under account %s %v", rID, aID, err) if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
writeJSONObject(w, "") util.WriteJSONObject(w, "")
} }
// GetRuleHandler handles a group Get request identified by ID // GetRuleHandler handles a group Get request identified by ID
func (h *Rules) GetRuleHandler(w http.ResponseWriter, r *http.Request) { func (h *Rules) GetRuleHandler(w http.ResponseWriter, r *http.Request) {
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
@ -380,19 +361,19 @@ func (h *Rules) GetRuleHandler(w http.ResponseWriter, r *http.Request) {
case http.MethodGet: case http.MethodGet:
ruleID := mux.Vars(r)["id"] ruleID := mux.Vars(r)["id"]
if len(ruleID) == 0 { if len(ruleID) == 0 {
http.Error(w, "invalid rule ID", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid rule ID"), w)
return return
} }
rule, err := h.accountManager.GetRule(account.Id, ruleID, user.Id) rule, err := h.accountManager.GetRule(account.Id, ruleID, user.Id)
if err != nil { if err != nil {
http.Error(w, "rule not found", http.StatusNotFound) util.WriteError(status.Errorf(status.NotFound, "rule not found"), w)
return return
} }
writeJSONObject(w, toRuleResponse(account, rule)) util.WriteJSONObject(w, toRuleResponse(account, rule))
default: default:
http.Error(w, "", http.StatusNotFound) util.WriteError(status.Errorf(status.NotFound, "method not found"), w)
} }
} }

View File

@ -66,7 +66,8 @@ func initRulesTestData(rules ...*server.Rule) *Rules {
} }
return &rule, nil return &rule, nil
}, },
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
return &server.Account{ return &server.Account{
Id: claims.AccountId, Id: claims.AccountId,
Domain: "hotmail.com", Domain: "hotmail.com",
@ -76,9 +77,9 @@ func initRulesTestData(rules ...*server.Rule) *Rules {
"G": {ID: "G"}, "G": {ID: "G"},
}, },
Users: map[string]*server.User{ Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"), "test_user": user,
}, },
}, nil }, user, nil
}, },
}, },
authAudience: "", authAudience: "",
@ -238,7 +239,7 @@ func TestRulesWriteRule(t *testing.T) {
requestPath: "/api/rules/id-existed", requestPath: "/api/rules/id-existed",
requestBody: bytes.NewBuffer( requestBody: bytes.NewBuffer(
[]byte(`[{"op":"insert","path":"name","value":[""]}]`)), []byte(`[{"op":"insert","path":"name","value":[""]}]`)),
expectedStatus: http.StatusBadRequest, expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false, expectedBody: false,
}, },
{ {

View File

@ -2,14 +2,12 @@ package http
import ( import (
"encoding/json" "encoding/json"
"fmt"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/status"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"net/http" "net/http"
"time" "time"
) )
@ -31,29 +29,28 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authAudience stri
// CreateSetupKeyHandler is a POST requests that creates a new SetupKey // CreateSetupKeyHandler is a POST requests that creates a new SetupKey
func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request) { func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
req := &api.PostApiSetupKeysJSONRequestBody{} req := &api.PostApiSetupKeysJSONRequestBody{}
err = json.NewDecoder(r.Body).Decode(&req) err = json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
} }
if req.Name == "" { if req.Name == "" {
http.Error(w, "Setup key name shouldn't be empty", http.StatusUnprocessableEntity) util.WriteError(status.Errorf(status.InvalidArgument, "setup key name shouldn't be empty"), w)
return return
} }
if !(server.SetupKeyType(req.Type) == server.SetupKeyReusable || if !(server.SetupKeyType(req.Type) == server.SetupKeyReusable ||
server.SetupKeyType(req.Type) == server.SetupKeyOneOff) { server.SetupKeyType(req.Type) == server.SetupKeyOneOff) {
util.WriteError(status.Errorf(status.InvalidArgument, "unknown setup key type %s", string(req.Type)), w)
http.Error(w, "unknown setup key type "+string(req.Type), http.StatusBadRequest)
return return
} }
@ -62,17 +59,11 @@ func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request
if req.AutoGroups == nil { if req.AutoGroups == nil {
req.AutoGroups = []string{} req.AutoGroups = []string{}
} }
// newExpiresIn := time.Duration(req.ExpiresIn) * time.Second
// newKey.ExpiresAt = time.Now().Add(newExpiresIn)
setupKey, err := h.accountManager.CreateSetupKey(account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn, setupKey, err := h.accountManager.CreateSetupKey(account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn,
req.AutoGroups) req.AutoGroups)
if err != nil { if err != nil {
errStatus, ok := status.FromError(err) util.WriteError(err, w)
if ok && errStatus.Code() == codes.NotFound {
http.Error(w, "account not found", http.StatusNotFound)
return
}
http.Error(w, "failed adding setup key", http.StatusInternalServerError)
return return
} }
@ -81,29 +72,23 @@ func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request
// GetSetupKeyHandler is a GET request to get a SetupKey by ID // GetSetupKeyHandler is a GET request to get a SetupKey by ID
func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) { func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
keyID := vars["id"] keyID := vars["id"]
if len(keyID) == 0 { if len(keyID) == 0 {
http.Error(w, "invalid key Id", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid key ID"), w)
return return
} }
key, err := h.accountManager.GetSetupKey(account.Id, user.Id, keyID) key, err := h.accountManager.GetSetupKey(account.Id, user.Id, keyID)
if err != nil { if err != nil {
errStatus, ok := status.FromError(err) util.WriteError(err, w)
if ok && errStatus.Code() == codes.NotFound {
http.Error(w, fmt.Sprintf("setup key %s not found under account %s", keyID, account.Id), http.StatusNotFound)
return
}
log.Errorf("failed getting setup key %s under account %s %v", keyID, account.Id, err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
@ -112,34 +97,34 @@ func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
// UpdateSetupKeyHandler is a PUT request to update server.SetupKey // UpdateSetupKeyHandler is a PUT request to update server.SetupKey
func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request) { func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
keyID := vars["id"] keyID := vars["id"]
if len(keyID) == 0 { if len(keyID) == 0 {
http.Error(w, "invalid key Id", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid key ID"), w)
return return
} }
req := &api.PutApiSetupKeysIdJSONRequestBody{} req := &api.PutApiSetupKeysIdJSONRequestBody{}
err = json.NewDecoder(r.Body).Decode(&req) err = json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
} }
if req.Name == "" { if req.Name == "" {
http.Error(w, fmt.Sprintf("setup key name field is invalid: %s", req.Name), http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "setup key name field is invalid: %s", req.Name), w)
return return
} }
if req.AutoGroups == nil { if req.AutoGroups == nil {
http.Error(w, fmt.Sprintf("setup key AutoGroups field is invalid: %s", req.AutoGroups), http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "setup key AutoGroups field is invalid"), w)
return return
} }
@ -150,16 +135,8 @@ func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request
newKey.Id = keyID newKey.Id = keyID
newKey, err = h.accountManager.SaveSetupKey(account.Id, newKey) newKey, err = h.accountManager.SaveSetupKey(account.Id, newKey)
if err != nil { if err != nil {
if e, ok := status.FromError(err); ok { util.WriteError(err, w)
switch e.Code() {
case codes.NotFound:
http.Error(w, fmt.Sprintf("couldn't find setup key for ID %s", keyID), http.StatusNotFound)
default:
http.Error(w, "failed updating setup key", http.StatusInternalServerError)
}
}
return return
} }
writeSuccess(w, newKey) writeSuccess(w, newKey)
@ -168,25 +145,25 @@ func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request
// GetAllSetupKeysHandler is a GET request that returns a list of SetupKey // GetAllSetupKeysHandler is a GET request that returns a list of SetupKey
func (h *SetupKeys) GetAllSetupKeysHandler(w http.ResponseWriter, r *http.Request) { func (h *SetupKeys) GetAllSetupKeysHandler(w http.ResponseWriter, r *http.Request) {
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
setupKeys, err := h.accountManager.ListSetupKeys(account.Id, user.Id) setupKeys, err := h.accountManager.ListSetupKeys(account.Id, user.Id)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
apiSetupKeys := make([]*api.SetupKey, 0) apiSetupKeys := make([]*api.SetupKey, 0)
for _, key := range setupKeys { for _, key := range setupKeys {
apiSetupKeys = append(apiSetupKeys, toResponseBody(key)) apiSetupKeys = append(apiSetupKeys, toResponseBody(key))
} }
writeJSONObject(w, apiSetupKeys) util.WriteJSONObject(w, apiSetupKeys)
} }
func writeSuccess(w http.ResponseWriter, key *server.SetupKey) { func writeSuccess(w http.ResponseWriter, key *server.SetupKey) {
@ -194,7 +171,7 @@ func writeSuccess(w http.ResponseWriter, key *server.SetupKey) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(toResponseBody(key)) err := json.NewEncoder(w).Encode(toResponseBody(key))
if err != nil { if err != nil {
http.Error(w, "failed handling request", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
} }

View File

@ -6,9 +6,8 @@ import (
"fmt" "fmt"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/status"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -32,7 +31,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
user *server.User) *SetupKeys { user *server.User) *SetupKeys {
return &SetupKeys{ return &SetupKeys{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return &server.Account{ return &server.Account{
Id: testAccountID, Id: testAccountID,
Domain: "hotmail.com", Domain: "hotmail.com",
@ -45,7 +44,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
Groups: map[string]*server.Group{ Groups: map[string]*server.Group{
"group-1": {ID: "group-1", Peers: []string{"A", "B"}}, "group-1": {ID: "group-1", Peers: []string{"A", "B"}},
"id-all": {ID: "id-all", Name: "All"}}, "id-all": {ID: "id-all", Name: "All"}},
}, nil }, user, nil
}, },
CreateSetupKeyFunc: func(_ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string) (*server.SetupKey, error) { CreateSetupKeyFunc: func(_ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string) (*server.SetupKey, error) {
if keyName == newKey.Name || typ != newKey.Type { if keyName == newKey.Name || typ != newKey.Type {
@ -60,7 +59,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
case newKey.Id: case newKey.Id:
return newKey, nil return newKey, nil
default: default:
return nil, status.Errorf(codes.NotFound, "key %s not found", keyID) return nil, status.Errorf(status.NotFound, "key %s not found", keyID)
} }
}, },
@ -68,7 +67,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
if key.Id == updatedSetupKey.Id { if key.Id == updatedSetupKey.Id {
return updatedSetupKey, nil return updatedSetupKey, nil
} }
return nil, status.Errorf(codes.NotFound, "key %s not found", key.Id) return nil, status.Errorf(status.NotFound, "key %s not found", key.Id)
}, },
ListSetupKeysFunc: func(accountID, userID string) ([]*server.SetupKey, error) { ListSetupKeysFunc: func(accountID, userID string) ([]*server.SetupKey, error) {

View File

@ -2,12 +2,10 @@ package http
import ( import (
"encoding/json" "encoding/json"
"fmt"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/http/util"
"google.golang.org/grpc/codes" "github.com/netbirdio/netbird/management/server/status"
"google.golang.org/grpc/status"
"net/http" "net/http"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
@ -31,33 +29,34 @@ func NewUserHandler(accountManager server.AccountManager, authAudience string) *
// UpdateUser is a PUT requests to update User data // UpdateUser is a PUT requests to update User data
func (h *UserHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { func (h *UserHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPut { if r.Method != http.MethodPut {
http.Error(w, "", http.StatusBadRequest) util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
} }
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
vars := mux.Vars(r) vars := mux.Vars(r)
userID := vars["id"] userID := vars["id"]
if len(userID) == 0 { if len(userID) == 0 {
http.Error(w, "invalid user ID", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return return
} }
req := &api.PutApiUsersIdJSONRequestBody{} req := &api.PutApiUsersIdJSONRequestBody{}
err = json.NewDecoder(r.Body).Decode(&req) err = json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
} }
userRole := server.StrRoleToUserRole(req.Role) userRole := server.StrRoleToUserRole(req.Role)
if userRole == server.UserRoleUnknown { if userRole == server.UserRoleUnknown {
http.Error(w, "invalid user role", http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "invalid user role"), w)
return return
} }
@ -67,40 +66,36 @@ func (h *UserHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
AutoGroups: req.AutoGroups, AutoGroups: req.AutoGroups,
}) })
if err != nil { if err != nil {
if e, ok := status.FromError(err); ok { util.WriteError(err, w)
switch e.Code() {
case codes.NotFound:
http.Error(w, fmt.Sprintf("couldn't find a user for ID %s", userID), http.StatusNotFound)
default:
http.Error(w, "failed to update user", http.StatusInternalServerError)
}
}
return return
} }
writeJSONObject(w, toUserResponse(newUser)) util.WriteJSONObject(w, toUserResponse(newUser))
} }
// CreateUserHandler creates a User in the system with a status "invited" (effectively this is a user invite). // CreateUserHandler creates a User in the system with a status "invited" (effectively this is a user invite).
func (h *UserHandler) CreateUserHandler(w http.ResponseWriter, r *http.Request) { func (h *UserHandler) CreateUserHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
http.Error(w, "", http.StatusNotFound) util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
} }
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
return
} }
req := &api.PostApiUsersJSONRequestBody{} req := &api.PostApiUsersJSONRequestBody{}
err = json.NewDecoder(r.Body).Decode(&req) err = json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return return
} }
if server.StrRoleToUserRole(req.Role) == server.UserRoleUnknown { if server.StrRoleToUserRole(req.Role) == server.UserRoleUnknown {
http.Error(w, "unknown user role "+req.Role, http.StatusBadRequest) util.WriteError(status.Errorf(status.InvalidArgument, "unknown user role %s", req.Role), w)
return return
} }
@ -111,37 +106,30 @@ func (h *UserHandler) CreateUserHandler(w http.ResponseWriter, r *http.Request)
AutoGroups: req.AutoGroups, AutoGroups: req.AutoGroups,
}) })
if err != nil { if err != nil {
if e, ok := server.FromError(err); ok { util.WriteError(err, w)
switch e.Type() {
case server.UserAlreadyExists:
http.Error(w, "You can't invite users with an existing NetBird account.", http.StatusPreconditionFailed)
return
default:
}
}
http.Error(w, "failed to invite", http.StatusInternalServerError)
return return
} }
writeJSONObject(w, toUserResponse(newUser)) util.WriteJSONObject(w, toUserResponse(newUser))
} }
// GetUsers returns a list of users of the account this user belongs to. // GetUsers returns a list of users of the account this user belongs to.
// It also gathers additional user data (like email and name) from the IDP manager. // It also gathers additional user data (like email and name) from the IDP manager.
func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) { func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet { if r.Method != http.MethodGet {
http.Error(w, "", http.StatusBadRequest) util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
} }
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError) util.WriteError(err, w)
return return
} }
data, err := h.accountManager.GetUsersFromAccount(account.Id, user.Id) data, err := h.accountManager.GetUsersFromAccount(account.Id, user.Id)
if err != nil { if err != nil {
log.Error(err) util.WriteError(err, w)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return return
} }
@ -150,7 +138,7 @@ func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) {
users = append(users, toUserResponse(r)) users = append(users, toUserResponse(r))
} }
writeJSONObject(w, users) util.WriteJSONObject(w, users)
} }
func toUserResponse(user *server.UserInfo) *api.User { func toUserResponse(user *server.UserInfo) *api.User {

View File

@ -16,7 +16,7 @@ import (
func initUsers(user ...*server.User) *UserHandler { func initUsers(user ...*server.User) *UserHandler {
return &UserHandler{ return &UserHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
users := make(map[string]*server.User, 0) users := make(map[string]*server.User, 0)
for _, u := range user { for _, u := range user {
users[u.Id] = u users[u.Id] = u
@ -25,7 +25,7 @@ func initUsers(user ...*server.User) *UserHandler {
Id: "12345", Id: "12345",
Domain: "netbird.io", Domain: "netbird.io",
Users: users, Users: users,
}, nil }, users[claims.UserId], nil
}, },
GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) { GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) {
users := make([]*server.UserInfo, 0) users := make([]*server.UserInfo, 0)
@ -66,7 +66,6 @@ func TestGetUsers(t *testing.T) {
expectedResult []*server.User expectedResult []*server.User
}{ }{
{name: "GetAllUsers", requestType: http.MethodGet, requestPath: "/api/users/", expectedStatus: http.StatusOK, expectedResult: users}, {name: "GetAllUsers", requestType: http.MethodGet, requestPath: "/api/users/", expectedStatus: http.StatusOK, expectedResult: users},
{name: "WrongRequestMethod", requestType: http.MethodPost, requestPath: "/api/users/", expectedStatus: http.StatusBadRequest},
} }
for _, tc := range tt { for _, tc := range tt {

View File

@ -1,97 +0,0 @@
package http
import (
"encoding/json"
"errors"
"fmt"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/jwtclaims"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"net/http"
"time"
)
// writeJSONObject simply writes object to the HTTP reponse in JSON format
func writeJSONObject(w http.ResponseWriter, obj interface{}) {
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
err := json.NewEncoder(w).Encode(obj)
if err != nil {
http.Error(w, "failed handling request", http.StatusInternalServerError)
return
}
}
// Duration is used strictly for JSON requests/responses due to duration marshalling issues
type Duration struct {
time.Duration
}
func (d Duration) MarshalJSON() ([]byte, error) {
return json.Marshal(d.String())
}
func (d *Duration) UnmarshalJSON(b []byte) error {
var v interface{}
if err := json.Unmarshal(b, &v); err != nil {
return err
}
switch value := v.(type) {
case float64:
d.Duration = time.Duration(value)
return nil
case string:
var err error
d.Duration, err = time.ParseDuration(value)
if err != nil {
return err
}
return nil
default:
return errors.New("invalid duration")
}
}
func getJWTAccount(accountManager server.AccountManager,
jwtExtractor jwtclaims.ClaimsExtractor,
authAudience string, r *http.Request) (*server.Account, *server.User, error) {
claims := jwtExtractor.ExtractClaimsFromRequestContext(r, authAudience)
account, err := accountManager.GetAccountFromToken(claims)
if err != nil {
return nil, nil, fmt.Errorf("failed getting account of a user %s: %v", claims.UserId, err)
}
user := account.Users[claims.UserId]
if user == nil {
// this is not really possible because we got an account by user ID
return nil, nil, fmt.Errorf("user %s not found", claims.UserId)
}
return account, user, nil
}
func toHTTPError(err error, w http.ResponseWriter) {
errStatus, ok := status.FromError(err)
if ok && errStatus.Code() == codes.Internal {
http.Error(w, errStatus.String(), http.StatusInternalServerError)
return
}
if ok && errStatus.Code() == codes.NotFound {
http.Error(w, errStatus.String(), http.StatusNotFound)
return
}
if ok && errStatus.Code() == codes.InvalidArgument {
http.Error(w, errStatus.String(), http.StatusBadRequest)
return
}
unhandledMSG := fmt.Sprintf("got unhandled error code, error: %s", errStatus.String())
log.Error(unhandledMSG)
http.Error(w, unhandledMSG, http.StatusInternalServerError)
}

View File

@ -0,0 +1,105 @@
package util
import (
"encoding/json"
"errors"
"fmt"
"github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus"
"net/http"
"time"
)
// WriteJSONObject simply writes object to the HTTP reponse in JSON format
func WriteJSONObject(w http.ResponseWriter, obj interface{}) {
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
err := json.NewEncoder(w).Encode(obj)
if err != nil {
WriteError(err, w)
return
}
}
// Duration is used strictly for JSON requests/responses due to duration marshalling issues
type Duration struct {
time.Duration
}
// MarshalJSON marshals the duration
func (d Duration) MarshalJSON() ([]byte, error) {
return json.Marshal(d.String())
}
// UnmarshalJSON unmarshals the duration
func (d *Duration) UnmarshalJSON(b []byte) error {
var v interface{}
if err := json.Unmarshal(b, &v); err != nil {
return err
}
switch value := v.(type) {
case float64:
d.Duration = time.Duration(value)
return nil
case string:
var err error
d.Duration, err = time.ParseDuration(value)
if err != nil {
return err
}
return nil
default:
return errors.New("invalid duration")
}
}
// WriteErrorResponse prepares and writes an error response i nJSON
func WriteErrorResponse(errMsg string, httpStatus int, w http.ResponseWriter) {
type errorResponse struct {
Message string `json:"message"`
Code int `json:"code"`
}
w.WriteHeader(httpStatus)
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
err := json.NewEncoder(w).Encode(&errorResponse{
Message: errMsg,
Code: httpStatus,
})
if err != nil {
http.Error(w, "failed handling request", http.StatusInternalServerError)
}
}
// WriteError converts an error to an JSON error response.
// If it is known internal error of type server.Error then it sets the messages from the error, a generic message otherwise
func WriteError(err error, w http.ResponseWriter) {
errStatus, ok := status.FromError(err)
httpStatus := http.StatusInternalServerError
msg := "internal server error"
if ok {
switch errStatus.Type() {
case status.UserAlreadyExists:
httpStatus = http.StatusConflict
case status.AlreadyExists:
httpStatus = http.StatusConflict
case status.PreconditionFailed:
httpStatus = http.StatusPreconditionFailed
case status.PermissionDenied:
httpStatus = http.StatusForbidden
case status.NotFound:
httpStatus = http.StatusNotFound
case status.Internal:
httpStatus = http.StatusInternalServerError
case status.InvalidArgument:
httpStatus = http.StatusUnprocessableEntity
default:
}
msg = err.Error()
} else {
unhandledMSG := fmt.Sprintf("got unhandled error code, error: %s", err.Error())
log.Error(unhandledMSG)
}
WriteErrorResponse(msg, httpStatus, w)
}

View File

@ -59,7 +59,7 @@ type MockAccountManager struct {
DeleteNameServerGroupFunc func(accountID, nsGroupID string) error DeleteNameServerGroupFunc func(accountID, nsGroupID string) error
ListNameServerGroupsFunc func(accountID string) ([]*nbdns.NameServerGroup, error) ListNameServerGroupsFunc func(accountID string) ([]*nbdns.NameServerGroup, error)
CreateUserFunc func(accountID string, key *server.UserInfo) (*server.UserInfo, error) CreateUserFunc func(accountID string, key *server.UserInfo) (*server.UserInfo, error)
GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)
} }
// GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface // GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface
@ -113,7 +113,7 @@ func (am *MockAccountManager) CreateSetupKey(
return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented") return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented")
} }
// GetAccountByUserOrAccountId mock implementation of GetAccountByUserOrAccountId from server.AccountManager interface // GetAccountByUserOrAccountID mock implementation of GetAccountByUserOrAccountID from server.AccountManager interface
func (am *MockAccountManager) GetAccountByUserOrAccountID( func (am *MockAccountManager) GetAccountByUserOrAccountID(
userId, accountId, domain string, userId, accountId, domain string,
) (*server.Account, error) { ) (*server.Account, error) {
@ -462,11 +462,12 @@ func (am *MockAccountManager) CreateUser(accountID string, invite *server.UserIn
} }
// GetAccountFromToken mocks GetAccountFromToken of the AccountManager interface // GetAccountFromToken mocks GetAccountFromToken of the AccountManager interface
func (am *MockAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { func (am *MockAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User,
error) {
if am.GetAccountFromTokenFunc != nil { if am.GetAccountFromTokenFunc != nil {
return am.GetAccountFromTokenFunc(claims) return am.GetAccountFromTokenFunc(claims)
} }
return nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented") return nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented")
} }
// GetPeers mocks GetPeers of the AccountManager interface // GetPeers mocks GetPeers of the AccountManager interface

View File

@ -3,10 +3,9 @@ package server
import ( import (
"github.com/miekg/dns" "github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/status"
"github.com/rs/xid" "github.com/rs/xid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"strconv" "strconv"
"unicode/utf8" "unicode/utf8"
) )
@ -66,7 +65,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string)
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, err
} }
nsGroup, found := account.NameServerGroups[nsGroupID] nsGroup, found := account.NameServerGroups[nsGroupID]
@ -74,7 +73,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string)
return nsGroup.Copy(), nil return nsGroup.Copy(), nil
} }
return nil, status.Errorf(codes.NotFound, "nameserver group with ID %s not found", nsGroupID) return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID)
} }
// CreateNameServerGroup creates and saves a new nameserver group // CreateNameServerGroup creates and saves a new nameserver group
@ -85,7 +84,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, d
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, err
} }
newNSGroup := &nbdns.NameServerGroup{ newNSGroup := &nbdns.NameServerGroup{
@ -119,7 +118,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, d
err = am.updateAccountPeers(account) err = am.updateAccountPeers(account)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return newNSGroup.Copy(), status.Errorf(codes.Unavailable, "failed to update peers after create nameserver %s", name) return newNSGroup.Copy(), status.Errorf(status.Internal, "failed to update peers after create nameserver %s", name)
} }
return newNSGroup.Copy(), nil return newNSGroup.Copy(), nil
@ -132,12 +131,12 @@ func (am *DefaultAccountManager) SaveNameServerGroup(accountID string, nsGroupTo
defer unlock() defer unlock()
if nsGroupToSave == nil { if nsGroupToSave == nil {
return status.Errorf(codes.InvalidArgument, "nameserver group provided is nil") return status.Errorf(status.InvalidArgument, "nameserver group provided is nil")
} }
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return status.Errorf(codes.NotFound, "account not found") return err
} }
err = validateNameServerGroup(true, nsGroupToSave, account) err = validateNameServerGroup(true, nsGroupToSave, account)
@ -156,7 +155,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(accountID string, nsGroupTo
err = am.updateAccountPeers(account) err = am.updateAccountPeers(account)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return status.Errorf(codes.Unavailable, "failed to update peers after update nameserver %s", nsGroupToSave.Name) return status.Errorf(status.Internal, "failed to update peers after update nameserver %s", nsGroupToSave.Name)
} }
return nil return nil
@ -170,16 +169,16 @@ func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID stri
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, err
} }
if len(operations) == 0 { if len(operations) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "operations shouldn't be empty") return nil, status.Errorf(status.InvalidArgument, "operations shouldn't be empty")
} }
nsGroupToUpdate, ok := account.NameServerGroups[nsGroupID] nsGroupToUpdate, ok := account.NameServerGroups[nsGroupID]
if !ok { if !ok {
return nil, status.Errorf(codes.NotFound, "nameserver group ID %s no longer exists", nsGroupID) return nil, status.Errorf(status.NotFound, "nameserver group ID %s no longer exists", nsGroupID)
} }
newNSGroup := nsGroupToUpdate.Copy() newNSGroup := nsGroupToUpdate.Copy()
@ -187,12 +186,12 @@ func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID stri
for _, operation := range operations { for _, operation := range operations {
valuesCount := len(operation.Values) valuesCount := len(operation.Values)
if valuesCount < 1 { if valuesCount < 1 {
return nil, status.Errorf(codes.InvalidArgument, "operation %s contains invalid number of values, it should be at least 1", operation.Type.String()) return nil, status.Errorf(status.InvalidArgument, "operation %s contains invalid number of values, it should be at least 1", operation.Type.String())
} }
for _, value := range operation.Values { for _, value := range operation.Values {
if value == "" { if value == "" {
return nil, status.Errorf(codes.InvalidArgument, "operation %s contains invalid empty string value", operation.Type.String()) return nil, status.Errorf(status.InvalidArgument, "operation %s contains invalid empty string value", operation.Type.String())
} }
} }
switch operation.Type { switch operation.Type {
@ -200,7 +199,7 @@ func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID stri
newNSGroup.Description = operation.Values[0] newNSGroup.Description = operation.Values[0]
case UpdateNameServerGroupName: case UpdateNameServerGroupName:
if valuesCount > 1 { if valuesCount > 1 {
return nil, status.Errorf(codes.InvalidArgument, "failed to parse name values, expected 1 value got %d", valuesCount) return nil, status.Errorf(status.InvalidArgument, "failed to parse name values, expected 1 value got %d", valuesCount)
} }
err = validateNSGroupName(operation.Values[0], nsGroupID, account.NameServerGroups) err = validateNSGroupName(operation.Values[0], nsGroupID, account.NameServerGroups)
if err != nil { if err != nil {
@ -230,13 +229,13 @@ func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID stri
case UpdateNameServerGroupEnabled: case UpdateNameServerGroupEnabled:
enabled, err := strconv.ParseBool(operation.Values[0]) enabled, err := strconv.ParseBool(operation.Values[0])
if err != nil { if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0]) return nil, status.Errorf(status.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0])
} }
newNSGroup.Enabled = enabled newNSGroup.Enabled = enabled
case UpdateNameServerGroupPrimary: case UpdateNameServerGroupPrimary:
primary, err := strconv.ParseBool(operation.Values[0]) primary, err := strconv.ParseBool(operation.Values[0])
if err != nil { if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "failed to parse primary status %s, not boolean", operation.Values[0]) return nil, status.Errorf(status.InvalidArgument, "failed to parse primary status %s, not boolean", operation.Values[0])
} }
newNSGroup.Primary = primary newNSGroup.Primary = primary
case UpdateNameServerGroupDomains: case UpdateNameServerGroupDomains:
@ -259,7 +258,7 @@ func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID stri
err = am.updateAccountPeers(account) err = am.updateAccountPeers(account)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return newNSGroup.Copy(), status.Errorf(codes.Unavailable, "failed to update peers after update nameserver %s", newNSGroup.Name) return newNSGroup.Copy(), status.Errorf(status.Internal, "failed to update peers after update nameserver %s", newNSGroup.Name)
} }
return newNSGroup.Copy(), nil return newNSGroup.Copy(), nil
@ -273,7 +272,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID stri
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return status.Errorf(codes.NotFound, "account not found") return err
} }
delete(account.NameServerGroups, nsGroupID) delete(account.NameServerGroups, nsGroupID)
@ -287,7 +286,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID stri
err = am.updateAccountPeers(account) err = am.updateAccountPeers(account)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return status.Errorf(codes.Unavailable, "failed to update peers after deleting nameserver %s", nsGroupID) return status.Errorf(status.Internal, "failed to update peers after deleting nameserver %s", nsGroupID)
} }
return nil return nil
@ -301,7 +300,7 @@ func (am *DefaultAccountManager) ListNameServerGroups(accountID string) ([]*nbdn
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, err
} }
nsGroups := make([]*nbdns.NameServerGroup, 0, len(account.NameServerGroups)) nsGroups := make([]*nbdns.NameServerGroup, 0, len(account.NameServerGroups))
@ -318,7 +317,7 @@ func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServ
nsGroupID = nameserverGroup.ID nsGroupID = nameserverGroup.ID
_, found := account.NameServerGroups[nsGroupID] _, found := account.NameServerGroups[nsGroupID]
if !found { if !found {
return status.Errorf(codes.NotFound, "nameserver group with ID %s was not found", nsGroupID) return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupID)
} }
} }
@ -347,17 +346,17 @@ func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServ
func validateDomainInput(primary bool, domains []string) error { func validateDomainInput(primary bool, domains []string) error {
if !primary && len(domains) == 0 { if !primary && len(domains) == 0 {
return status.Errorf(codes.InvalidArgument, "nameserver group primary status is false and domains are empty,"+ return status.Errorf(status.InvalidArgument, "nameserver group primary status is false and domains are empty,"+
" it should be primary or have at least one domain") " it should be primary or have at least one domain")
} }
if primary && len(domains) != 0 { if primary && len(domains) != 0 {
return status.Errorf(codes.InvalidArgument, "nameserver group primary status is true and domains are not empty,"+ return status.Errorf(status.InvalidArgument, "nameserver group primary status is true and domains are not empty,"+
" you should set either primary or domain") " you should set either primary or domain")
} }
for _, domain := range domains { for _, domain := range domains {
_, valid := dns.IsDomainName(domain) _, valid := dns.IsDomainName(domain)
if !valid { if !valid {
return status.Errorf(codes.InvalidArgument, "nameserver group got an invalid domain: %s", domain) return status.Errorf(status.InvalidArgument, "nameserver group got an invalid domain: %s", domain)
} }
} }
return nil return nil
@ -365,12 +364,12 @@ func validateDomainInput(primary bool, domains []string) error {
func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.NameServerGroup) error { func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.NameServerGroup) error {
if utf8.RuneCountInString(name) > nbdns.MaxGroupNameChar || name == "" { if utf8.RuneCountInString(name) > nbdns.MaxGroupNameChar || name == "" {
return status.Errorf(codes.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar) return status.Errorf(status.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar)
} }
for _, nsGroup := range nsGroupMap { for _, nsGroup := range nsGroupMap {
if name == nsGroup.Name && nsGroup.ID != nsGroupID { if name == nsGroup.Name && nsGroup.ID != nsGroupID {
return status.Errorf(codes.InvalidArgument, "a nameserver group with name %s already exist", name) return status.Errorf(status.InvalidArgument, "a nameserver group with name %s already exist", name)
} }
} }
@ -380,19 +379,19 @@ func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.Na
func validateNSList(list []nbdns.NameServer) error { func validateNSList(list []nbdns.NameServer) error {
nsListLenght := len(list) nsListLenght := len(list)
if nsListLenght == 0 || nsListLenght > 2 { if nsListLenght == 0 || nsListLenght > 2 {
return status.Errorf(codes.InvalidArgument, "the list of nameservers should be 1 or 2, got %d", len(list)) return status.Errorf(status.InvalidArgument, "the list of nameservers should be 1 or 2, got %d", len(list))
} }
return nil return nil
} }
func validateGroups(list []string, groups map[string]*Group) error { func validateGroups(list []string, groups map[string]*Group) error {
if len(list) == 0 { if len(list) == 0 {
return status.Errorf(codes.InvalidArgument, "the list of group IDs should not be empty") return status.Errorf(status.InvalidArgument, "the list of group IDs should not be empty")
} }
for _, id := range list { for _, id := range list {
if id == "" { if id == "" {
return status.Errorf(codes.InvalidArgument, "group ID should not be empty string") return status.Errorf(status.InvalidArgument, "group ID should not be empty string")
} }
found := false found := false
for groupID := range groups { for groupID := range groups {
@ -402,7 +401,7 @@ func validateGroups(list []string, groups map[string]*Group) error {
} }
} }
if !found { if !found {
return status.Errorf(codes.InvalidArgument, "group id %s not found", id) return status.Errorf(status.InvalidArgument, "group id %s not found", id)
} }
} }

View File

@ -3,6 +3,7 @@ package server
import ( import (
"github.com/c-robinson/iplib" "github.com/c-robinson/iplib"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/rs/xid" "github.com/rs/xid"
"math/rand" "math/rand"
@ -93,7 +94,7 @@ func AllocatePeerIP(ipNet net.IPNet, takenIps []net.IP) (net.IP, error) {
ips, _ := generateIPs(&ipNet, takenIPMap) ips, _ := generateIPs(&ipNet, takenIPMap)
if len(ips) == 0 { if len(ips) == 0 {
return nil, Errorf(PreconditionFailed, "failed allocating new IP for the ipNet %s - network is out of IPs", ipNet.String()) return nil, status.Errorf(status.PreconditionFailed, "failed allocating new IP for the ipNet %s - network is out of IPs", ipNet.String())
} }
// pick a random IP // pick a random IP

View File

@ -2,6 +2,7 @@ package server
import ( import (
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/status"
"net" "net"
"strings" "strings"
"time" "time"
@ -9,8 +10,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/proto"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
) )
// PeerSystemMeta is a metadata of a Peer machine system // PeerSystemMeta is a metadata of a Peer machine system
@ -162,7 +161,7 @@ func (am *DefaultAccountManager) UpdatePeer(accountID string, update *Peer) (*Pe
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, err
} }
//TODO Peer.ID migration: we will need to replace search by ID here //TODO Peer.ID migration: we will need to replace search by ID here
@ -208,7 +207,7 @@ func (am *DefaultAccountManager) DeletePeer(accountID string, peerPubKey string)
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, err
} }
peer, err := account.FindPeerByPubKey(peerPubKey) peer, err := account.FindPeerByPubKey(peerPubKey)
@ -258,7 +257,7 @@ func (am *DefaultAccountManager) GetPeerByIP(accountID string, peerIP string) (*
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, err
} }
for _, peer := range account.Peers { for _, peer := range account.Peers {
@ -267,7 +266,7 @@ func (am *DefaultAccountManager) GetPeerByIP(accountID string, peerIP string) (*
} }
} }
return nil, status.Errorf(codes.NotFound, "peer with IP %s not found", peerIP) return nil, status.Errorf(status.NotFound, "peer with IP %s not found", peerIP)
} }
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result) // GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)
@ -275,7 +274,7 @@ func (am *DefaultAccountManager) GetNetworkMap(peerPubKey string) (*NetworkMap,
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey) account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "Invalid peer key %s", peerPubKey) return nil, err
} }
aclPeers := am.getPeersByACL(account, peerPubKey) aclPeers := am.getPeersByACL(account, peerPubKey)
@ -306,7 +305,7 @@ func (am *DefaultAccountManager) GetPeerNetwork(peerPubKey string) (*Network, er
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey) account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "invalid peer key %s", peerPubKey) return nil, err
} }
return account.Network.Copy(), err return account.Network.Copy(), err
@ -332,7 +331,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (*
account, err = am.Store.GetAccountBySetupKey(setupKey) account, err = am.Store.GetAccountBySetupKey(setupKey)
} }
if err != nil { if err != nil {
return nil, Errorf(AccountNotFound, "failed adding new peer: account not found") return nil, status.Errorf(status.NotFound, "failed adding new peer: account not found")
} }
unlock := am.Store.AcquireAccountLock(account.Id) unlock := am.Store.AcquireAccountLock(account.Id)
@ -352,7 +351,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (*
} }
if !sk.IsValid() { if !sk.IsValid() {
return nil, Errorf(PreconditionFailed, "couldn't add peer: setup key is invalid") return nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
} }
account.SetupKeys[sk.Key] = sk.IncrementUsage() account.SetupKeys[sk.Key] = sk.IncrementUsage()
@ -418,7 +417,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (*
account.Network.IncSerial() account.Network.IncSerial()
err = am.Store.SaveAccount(account) err = am.Store.SaveAccount(account)
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "failed adding peer") return nil, err
} }
return newPeer, nil return newPeer, nil
@ -563,13 +562,13 @@ func (am *DefaultAccountManager) updateAccountPeers(account *Account) error {
for _, peer := range peers { for _, peer := range peers {
remotePeerNetworkMap, err := am.GetNetworkMap(peer.Key) remotePeerNetworkMap, err := am.GetNetworkMap(peer.Key)
if err != nil { if err != nil {
return status.Errorf(codes.Internal, "unable to fetch network map for peer %s, error: %v", peer.Key, err) return err
} }
update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap) update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap)
err = am.peersUpdateManager.SendUpdate(peer.Key, &UpdateMessage{Update: update}) err = am.peersUpdateManager.SendUpdate(peer.Key, &UpdateMessage{Update: update})
if err != nil { if err != nil {
return status.Errorf(codes.Internal, "unable to send update for peer %s, error: %v", peer.Key, err) return err
} }
} }

View File

@ -2,11 +2,10 @@ package server
import ( import (
"github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/rs/xid" "github.com/rs/xid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"net/netip" "net/netip"
"strconv" "strconv"
"unicode/utf8" "unicode/utf8"
@ -66,7 +65,7 @@ func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*r
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, err
} }
user, err := account.FindUser(userID) user, err := account.FindUser(userID)
@ -75,7 +74,7 @@ func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*r
} }
if !user.IsAdmin() { if !user.IsAdmin() {
return nil, Errorf(PermissionDenied, "Only administrators can view Network Routes") return nil, status.Errorf(status.PermissionDenied, "Only administrators can view Network Routes")
} }
wantedRoute, found := account.Routes[routeID] wantedRoute, found := account.Routes[routeID]
@ -83,7 +82,7 @@ func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*r
return wantedRoute, nil return wantedRoute, nil
} }
return nil, status.Errorf(codes.NotFound, "route with ID %s not found", routeID) return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
} }
// checkPrefixPeerExists checks the combination of prefix and peer id, if it exists returns an error, otherwise returns nil // checkPrefixPeerExists checks the combination of prefix and peer id, if it exists returns an error, otherwise returns nil
@ -101,14 +100,14 @@ func (am *DefaultAccountManager) checkPrefixPeerExists(accountID, peer string, p
routesWithPrefix := account.GetRoutesByPrefix(prefix) routesWithPrefix := account.GetRoutesByPrefix(prefix)
if err != nil { if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
return nil return nil
} }
return status.Errorf(codes.InvalidArgument, "failed to parse prefix %s", prefix.String()) return status.Errorf(status.InvalidArgument, "failed to parse prefix %s", prefix.String())
} }
for _, prefixRoute := range routesWithPrefix { for _, prefixRoute := range routesWithPrefix {
if prefixRoute.Peer == peer { if prefixRoute.Peer == peer {
return status.Errorf(codes.AlreadyExists, "failed a route with prefix %s and peer already exist", prefix.String()) return status.Errorf(status.AlreadyExists, "failed a route with prefix %s and peer already exist", prefix.String())
} }
} }
return nil return nil
@ -121,13 +120,13 @@ func (am *DefaultAccountManager) CreateRoute(accountID string, network, peer, de
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, err
} }
var newRoute route.Route var newRoute route.Route
prefixType, newPrefix, err := route.ParseNetwork(network) prefixType, newPrefix, err := route.ParseNetwork(network)
if err != nil { if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "failed to parse IP %s", network) return nil, status.Errorf(status.InvalidArgument, "failed to parse IP %s", network)
} }
err = am.checkPrefixPeerExists(accountID, peer, newPrefix) err = am.checkPrefixPeerExists(accountID, peer, newPrefix)
if err != nil { if err != nil {
@ -137,16 +136,16 @@ func (am *DefaultAccountManager) CreateRoute(accountID string, network, peer, de
if peer != "" { if peer != "" {
_, peerExist := account.Peers[peer] _, peerExist := account.Peers[peer]
if !peerExist { if !peerExist {
return nil, status.Errorf(codes.InvalidArgument, "failed to find Peer %s", peer) return nil, status.Errorf(status.InvalidArgument, "failed to find Peer %s", peer)
} }
} }
if metric < route.MinMetric || metric > route.MaxMetric { if metric < route.MinMetric || metric > route.MaxMetric {
return nil, status.Errorf(codes.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric) return nil, status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
} }
if utf8.RuneCountInString(netID) > route.MaxNetIDChar || netID == "" { if utf8.RuneCountInString(netID) > route.MaxNetIDChar || netID == "" {
return nil, status.Errorf(codes.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
} }
newRoute.Peer = peer newRoute.Peer = peer
@ -173,7 +172,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID string, network, peer, de
err = am.updateAccountPeers(account) err = am.updateAccountPeers(account)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return &newRoute, status.Errorf(codes.Unavailable, "failed to update peers after create route %s", newPrefix) return &newRoute, status.Errorf(status.Internal, "failed to update peers after create route %s", newPrefix)
} }
return &newRoute, nil return &newRoute, nil
} }
@ -184,30 +183,30 @@ func (am *DefaultAccountManager) SaveRoute(accountID string, routeToSave *route.
defer unlock() defer unlock()
if routeToSave == nil { if routeToSave == nil {
return status.Errorf(codes.InvalidArgument, "route provided is nil") return status.Errorf(status.InvalidArgument, "route provided is nil")
} }
if !routeToSave.Network.IsValid() { if !routeToSave.Network.IsValid() {
return status.Errorf(codes.InvalidArgument, "invalid Prefix %s", routeToSave.Network.String()) return status.Errorf(status.InvalidArgument, "invalid Prefix %s", routeToSave.Network.String())
} }
if routeToSave.Metric < route.MinMetric || routeToSave.Metric > route.MaxMetric { if routeToSave.Metric < route.MinMetric || routeToSave.Metric > route.MaxMetric {
return status.Errorf(codes.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric) return status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
} }
if utf8.RuneCountInString(routeToSave.NetID) > route.MaxNetIDChar || routeToSave.NetID == "" { if utf8.RuneCountInString(routeToSave.NetID) > route.MaxNetIDChar || routeToSave.NetID == "" {
return status.Errorf(codes.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
} }
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return status.Errorf(codes.NotFound, "account not found") return err
} }
if routeToSave.Peer != "" { if routeToSave.Peer != "" {
_, peerExist := account.Peers[routeToSave.Peer] _, peerExist := account.Peers[routeToSave.Peer]
if !peerExist { if !peerExist {
return status.Errorf(codes.InvalidArgument, "failed to find Peer %s", routeToSave.Peer) return status.Errorf(status.InvalidArgument, "failed to find Peer %s", routeToSave.Peer)
} }
} }
@ -228,12 +227,12 @@ func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operatio
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, err
} }
routeToUpdate, ok := account.Routes[routeID] routeToUpdate, ok := account.Routes[routeID]
if !ok { if !ok {
return nil, status.Errorf(codes.NotFound, "route %s no longer exists", routeID) return nil, status.Errorf(status.NotFound, "route %s no longer exists", routeID)
} }
newRoute := routeToUpdate.Copy() newRoute := routeToUpdate.Copy()
@ -241,7 +240,7 @@ func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operatio
for _, operation := range operations { for _, operation := range operations {
if len(operation.Values) != 1 { if len(operation.Values) != 1 {
return nil, status.Errorf(codes.InvalidArgument, "operation %s contains invalid number of values, it should be 1", operation.Type.String()) return nil, status.Errorf(status.InvalidArgument, "operation %s contains invalid number of values, it should be 1", operation.Type.String())
} }
switch operation.Type { switch operation.Type {
@ -249,13 +248,13 @@ func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operatio
newRoute.Description = operation.Values[0] newRoute.Description = operation.Values[0]
case UpdateRouteNetworkIdentifier: case UpdateRouteNetworkIdentifier:
if utf8.RuneCountInString(operation.Values[0]) > route.MaxNetIDChar || operation.Values[0] == "" { if utf8.RuneCountInString(operation.Values[0]) > route.MaxNetIDChar || operation.Values[0] == "" {
return nil, status.Errorf(codes.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
} }
newRoute.NetID = operation.Values[0] newRoute.NetID = operation.Values[0]
case UpdateRouteNetwork: case UpdateRouteNetwork:
prefixType, prefix, err := route.ParseNetwork(operation.Values[0]) prefixType, prefix, err := route.ParseNetwork(operation.Values[0])
if err != nil { if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "failed to parse IP %s", operation.Values[0]) return nil, status.Errorf(status.InvalidArgument, "failed to parse IP %s", operation.Values[0])
} }
err = am.checkPrefixPeerExists(accountID, routeToUpdate.Peer, prefix) err = am.checkPrefixPeerExists(accountID, routeToUpdate.Peer, prefix)
if err != nil { if err != nil {
@ -267,7 +266,7 @@ func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operatio
if operation.Values[0] != "" { if operation.Values[0] != "" {
_, peerExist := account.Peers[operation.Values[0]] _, peerExist := account.Peers[operation.Values[0]]
if !peerExist { if !peerExist {
return nil, status.Errorf(codes.InvalidArgument, "failed to find Peer %s", operation.Values[0]) return nil, status.Errorf(status.InvalidArgument, "failed to find Peer %s", operation.Values[0])
} }
} }
@ -279,10 +278,10 @@ func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operatio
case UpdateRouteMetric: case UpdateRouteMetric:
metric, err := strconv.Atoi(operation.Values[0]) metric, err := strconv.Atoi(operation.Values[0])
if err != nil { if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "failed to parse metric %s, not int", operation.Values[0]) return nil, status.Errorf(status.InvalidArgument, "failed to parse metric %s, not int", operation.Values[0])
} }
if metric < route.MinMetric || metric > route.MaxMetric { if metric < route.MinMetric || metric > route.MaxMetric {
return nil, status.Errorf(codes.InvalidArgument, "failed to parse metric %s, value should be %d > N < %d", return nil, status.Errorf(status.InvalidArgument, "failed to parse metric %s, value should be %d > N < %d",
operation.Values[0], operation.Values[0],
route.MinMetric, route.MinMetric,
route.MaxMetric, route.MaxMetric,
@ -292,13 +291,13 @@ func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operatio
case UpdateRouteMasquerade: case UpdateRouteMasquerade:
masquerade, err := strconv.ParseBool(operation.Values[0]) masquerade, err := strconv.ParseBool(operation.Values[0])
if err != nil { if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "failed to parse masquerade %s, not boolean", operation.Values[0]) return nil, status.Errorf(status.InvalidArgument, "failed to parse masquerade %s, not boolean", operation.Values[0])
} }
newRoute.Masquerade = masquerade newRoute.Masquerade = masquerade
case UpdateRouteEnabled: case UpdateRouteEnabled:
enabled, err := strconv.ParseBool(operation.Values[0]) enabled, err := strconv.ParseBool(operation.Values[0])
if err != nil { if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0]) return nil, status.Errorf(status.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0])
} }
newRoute.Enabled = enabled newRoute.Enabled = enabled
} }
@ -313,7 +312,7 @@ func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operatio
err = am.updateAccountPeers(account) err = am.updateAccountPeers(account)
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "failed to update account peers") return nil, status.Errorf(status.Internal, "failed to update account peers")
} }
return newRoute, nil return newRoute, nil
} }
@ -325,7 +324,7 @@ func (am *DefaultAccountManager) DeleteRoute(accountID, routeID string) error {
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return status.Errorf(codes.NotFound, "account not found") return err
} }
delete(account.Routes, routeID) delete(account.Routes, routeID)
@ -345,7 +344,7 @@ func (am *DefaultAccountManager) ListRoutes(accountID, userID string) ([]*route.
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, err
} }
user, err := account.FindUser(userID) user, err := account.FindUser(userID)
@ -354,7 +353,7 @@ func (am *DefaultAccountManager) ListRoutes(accountID, userID string) ([]*route.
} }
if !user.IsAdmin() { if !user.IsAdmin() {
return nil, Errorf(PermissionDenied, "Only administrators can view Network Routes") return nil, status.Errorf(status.PermissionDenied, "Only administrators can view Network Routes")
} }
routes := make([]*route.Route, 0, len(account.Routes)) routes := make([]*route.Route, 0, len(account.Routes))

View File

@ -1,8 +1,7 @@
package server package server
import ( import (
"google.golang.org/grpc/codes" "github.com/netbirdio/netbird/management/server/status"
"google.golang.org/grpc/status"
"strings" "strings"
) )
@ -95,7 +94,7 @@ func (am *DefaultAccountManager) GetRule(accountID, ruleID, userID string) (*Rul
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, err
} }
user, err := account.FindUser(userID) user, err := account.FindUser(userID)
@ -104,7 +103,7 @@ func (am *DefaultAccountManager) GetRule(accountID, ruleID, userID string) (*Rul
} }
if !user.IsAdmin() { if !user.IsAdmin() {
return nil, Errorf(PermissionDenied, "only admins are allowed to view rules") return nil, status.Errorf(status.PermissionDenied, "only admins are allowed to view rules")
} }
rule, ok := account.Rules[ruleID] rule, ok := account.Rules[ruleID]
@ -112,7 +111,7 @@ func (am *DefaultAccountManager) GetRule(accountID, ruleID, userID string) (*Rul
return rule, nil return rule, nil
} }
return nil, status.Errorf(codes.NotFound, "rule with ID %s not found", ruleID) return nil, status.Errorf(status.NotFound, "rule with ID %s not found", ruleID)
} }
// SaveRule of ACL in the store // SaveRule of ACL in the store
@ -122,7 +121,7 @@ func (am *DefaultAccountManager) SaveRule(accountID string, rule *Rule) error {
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return status.Errorf(codes.NotFound, "account not found") return err
} }
account.Rules[rule.ID] = rule account.Rules[rule.ID] = rule
@ -143,12 +142,12 @@ func (am *DefaultAccountManager) UpdateRule(accountID string, ruleID string,
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, err
} }
ruleToUpdate, ok := account.Rules[ruleID] ruleToUpdate, ok := account.Rules[ruleID]
if !ok { if !ok {
return nil, status.Errorf(codes.NotFound, "rule %s no longer exists", ruleID) return nil, status.Errorf(status.NotFound, "rule %s no longer exists", ruleID)
} }
rule := ruleToUpdate.Copy() rule := ruleToUpdate.Copy()
@ -161,7 +160,7 @@ func (am *DefaultAccountManager) UpdateRule(accountID string, ruleID string,
rule.Description = operation.Values[0] rule.Description = operation.Values[0]
case UpdateRuleFlow: case UpdateRuleFlow:
if operation.Values[0] != TrafficFlowBidirectString { if operation.Values[0] != TrafficFlowBidirectString {
return nil, status.Errorf(codes.InvalidArgument, "failed to parse flow") return nil, status.Errorf(status.InvalidArgument, "failed to parse flow")
} }
rule.Flow = TrafficFlowBidirect rule.Flow = TrafficFlowBidirect
case UpdateRuleStatus: case UpdateRuleStatus:
@ -170,7 +169,7 @@ func (am *DefaultAccountManager) UpdateRule(accountID string, ruleID string,
} else if strings.ToLower(operation.Values[0]) == "false" { } else if strings.ToLower(operation.Values[0]) == "false" {
rule.Disabled = false rule.Disabled = false
} else { } else {
return nil, status.Errorf(codes.InvalidArgument, "failed to parse status") return nil, status.Errorf(status.InvalidArgument, "failed to parse status")
} }
case UpdateSourceGroups: case UpdateSourceGroups:
rule.Source = operation.Values rule.Source = operation.Values
@ -204,7 +203,7 @@ func (am *DefaultAccountManager) UpdateRule(accountID string, ruleID string,
err = am.updateAccountPeers(account) err = am.updateAccountPeers(account)
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "failed to update account peers") return nil, status.Errorf(status.Internal, "failed to update account peers")
} }
return rule, nil return rule, nil
@ -217,7 +216,7 @@ func (am *DefaultAccountManager) DeleteRule(accountID, ruleID string) error {
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return status.Errorf(codes.NotFound, "account not found") return err
} }
delete(account.Rules, ruleID) delete(account.Rules, ruleID)
@ -237,7 +236,7 @@ func (am *DefaultAccountManager) ListRules(accountID, userID string) ([]*Rule, e
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, err
} }
user, err := account.FindUser(userID) user, err := account.FindUser(userID)
@ -246,7 +245,7 @@ func (am *DefaultAccountManager) ListRules(accountID, userID string) ([]*Rule, e
} }
if !user.IsAdmin() { if !user.IsAdmin() {
return nil, Errorf(PermissionDenied, "Only Administrators can view Access Rules") return nil, status.Errorf(status.PermissionDenied, "Only Administrators can view Access Rules")
} }
rules := make([]*Rule, 0, len(account.Rules)) rules := make([]*Rule, 0, len(account.Rules))

View File

@ -1,10 +1,8 @@
package server package server
import ( import (
"fmt"
"github.com/google/uuid" "github.com/google/uuid"
"google.golang.org/grpc/codes" "github.com/netbirdio/netbird/management/server/status"
"google.golang.org/grpc/status"
"hash/fnv" "hash/fnv"
"strconv" "strconv"
"strings" "strings"
@ -183,12 +181,12 @@ func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, err
} }
for _, group := range autoGroups { for _, group := range autoGroups {
if _, ok := account.Groups[group]; !ok { if _, ok := account.Groups[group]; !ok {
return nil, fmt.Errorf("group %s doesn't exist", group) return nil, status.Errorf(status.NotFound, "group %s doesn't exist", group)
} }
} }
@ -197,7 +195,7 @@ func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string
err = am.Store.SaveAccount(account) err = am.Store.SaveAccount(account)
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "failed adding account key") return nil, status.Errorf(status.Internal, "failed adding account key")
} }
return setupKey, nil return setupKey, nil
@ -212,12 +210,12 @@ func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *Setup
defer unlock() defer unlock()
if keyToSave == nil { if keyToSave == nil {
return nil, status.Errorf(codes.InvalidArgument, "provided setup key to update is nil") return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil")
} }
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, err
} }
var oldKey *SetupKey var oldKey *SetupKey
@ -228,7 +226,7 @@ func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *Setup
} }
} }
if oldKey == nil { if oldKey == nil {
return nil, status.Errorf(codes.NotFound, "setup key not found") return nil, status.Errorf(status.NotFound, "setup key not found")
} }
// only auto groups, revoked status, and name can be updated for now // only auto groups, revoked status, and name can be updated for now
@ -253,7 +251,7 @@ func (am *DefaultAccountManager) ListSetupKeys(accountID, userID string) ([]*Set
defer unlock() defer unlock()
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, err
} }
user, err := account.FindUser(userID) user, err := account.FindUser(userID)
@ -282,7 +280,7 @@ func (am *DefaultAccountManager) GetSetupKey(accountID, userID, keyID string) (*
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, err
} }
user, err := account.FindUser(userID) user, err := account.FindUser(userID)
@ -298,7 +296,7 @@ func (am *DefaultAccountManager) GetSetupKey(accountID, userID, keyID string) (*
} }
} }
if foundKey == nil { if foundKey == nil {
return nil, status.Errorf(codes.NotFound, "setup key not found") return nil, status.Errorf(status.NotFound, "setup key not found")
} }
// the UpdatedAt field was introduced later, so there might be that some keys have a Zero value (e.g, null in the store file) // the UpdatedAt field was introduced later, so there might be that some keys have a Zero value (e.g, null in the store file)

View File

@ -0,0 +1,72 @@
package status
import (
"fmt"
)
const (
// UserAlreadyExists indicates that user already exists
UserAlreadyExists Type = 1
// PreconditionFailed indicates that some pre-condition for the operation hasn't been fulfilled
PreconditionFailed Type = 2
// PermissionDenied indicates that user has no permissions to view data
PermissionDenied Type = 3
// NotFound indicates that the object wasn't found in the system (or under a given Account)
NotFound Type = 4
// Internal indicates some generic internal error
Internal Type = 5
// InvalidArgument indicates some generic invalid argument error
InvalidArgument Type = 6
// AlreadyExists indicates a generic error when an object already exists in the system
AlreadyExists Type = 7
// Unauthorized indicates that user is not authorized
Unauthorized Type = 8
// BadRequest indicates that user is not authorized
BadRequest Type = 9
)
// Type is a type of the Error
type Type int32
// Error is an internal error
type Error struct {
ErrorType Type
Message string
}
// Type returns the Type of the error
func (e *Error) Type() Type {
return e.ErrorType
}
// Error is an error string
func (e *Error) Error() string {
return e.Message
}
// Errorf returns Error(ErrorType, fmt.Sprintf(format, a...)).
func Errorf(errorType Type, format string, a ...interface{}) error {
return &Error{
ErrorType: errorType,
Message: fmt.Sprintf(format, a...),
}
}
// FromError returns Error, true if the provided error is of type of Error. nil, false otherwise
func FromError(err error) (s *Error, ok bool) {
if err == nil {
return nil, true
}
if e, ok := err.(*Error); ok {
return e, true
}
return nil, false
}

View File

@ -3,8 +3,7 @@ package server
import ( import (
"fmt" "fmt"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
"google.golang.org/grpc/codes" "github.com/netbirdio/netbird/management/server/status"
"google.golang.org/grpc/status"
"strings" "strings"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
@ -123,7 +122,7 @@ func (am *DefaultAccountManager) CreateUser(accountID string, invite *UserInfo)
defer unlock() defer unlock()
if am.idpManager == nil { if am.idpManager == nil {
return nil, Errorf(PreconditionFailed, "IdP manager must be enabled to send user invites") return nil, status.Errorf(status.PreconditionFailed, "IdP manager must be enabled to send user invites")
} }
if invite == nil { if invite == nil {
@ -132,7 +131,7 @@ func (am *DefaultAccountManager) CreateUser(accountID string, invite *UserInfo)
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, Errorf(AccountNotFound, "account %s doesn't exist", accountID) return nil, status.Errorf(status.NotFound, "account %s doesn't exist", accountID)
} }
// check if the user is already registered with this email => reject // check if the user is already registered with this email => reject
@ -142,7 +141,7 @@ func (am *DefaultAccountManager) CreateUser(accountID string, invite *UserInfo)
} }
if user != nil { if user != nil {
return nil, Errorf(UserAlreadyExists, "user has an existing account") return nil, status.Errorf(status.UserAlreadyExists, "user has an existing account")
} }
users, err := am.idpManager.GetUserByEmail(invite.Email) users, err := am.idpManager.GetUserByEmail(invite.Email)
@ -151,7 +150,7 @@ func (am *DefaultAccountManager) CreateUser(accountID string, invite *UserInfo)
} }
if len(users) > 0 { if len(users) > 0 {
return nil, Errorf(UserAlreadyExists, "user has an existing account") return nil, status.Errorf(status.UserAlreadyExists, "user has an existing account")
} }
idpUser, err := am.idpManager.CreateUser(invite.Email, invite.Name, accountID) idpUser, err := am.idpManager.CreateUser(invite.Email, invite.Name, accountID)
@ -188,25 +187,24 @@ func (am *DefaultAccountManager) SaveUser(accountID string, update *User) (*User
defer unlock() defer unlock()
if update == nil { if update == nil {
return nil, status.Errorf(codes.InvalidArgument, "provided user update is nil") return nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
} }
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found") return nil, err
} }
for _, newGroupID := range update.AutoGroups { for _, newGroupID := range update.AutoGroups {
if _, ok := account.Groups[newGroupID]; !ok { if _, ok := account.Groups[newGroupID]; !ok {
return nil, return nil, status.Errorf(status.InvalidArgument, "provided group ID %s in the user %s update doesn't exist",
status.Errorf(codes.InvalidArgument, "provided group ID %s in the user %s update doesn't exist", newGroupID, update.Id)
newGroupID, update.Id)
} }
} }
oldUser := account.Users[update.Id] oldUser := account.Users[update.Id]
if oldUser == nil { if oldUser == nil {
return nil, status.Errorf(codes.NotFound, "update not found") return nil, status.Errorf(status.NotFound, "update not found")
} }
// only auto groups, revoked status, and name can be updated for now // only auto groups, revoked status, and name can be updated for now
@ -226,7 +224,7 @@ func (am *DefaultAccountManager) SaveUser(accountID string, update *User) (*User
return nil, err return nil, err
} }
if userData == nil { if userData == nil {
return nil, status.Errorf(codes.NotFound, "user %s not found in the IdP", newUser.Id) return nil, status.Errorf(status.NotFound, "user %s not found in the IdP", newUser.Id)
} }
return newUser.toUserInfo(userData) return newUser.toUserInfo(userData)
} }
@ -242,14 +240,14 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userID, domain string)
account, err := am.Store.GetAccountByUser(userID) account, err := am.Store.GetAccountByUser(userID)
if err != nil { if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
account, err = am.newAccount(userID, lowerDomain) account, err = am.newAccount(userID, lowerDomain)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = am.Store.SaveAccount(account) err = am.Store.SaveAccount(account)
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "failed creating account") return nil, err
} }
} else { } else {
// other error // other error
@ -263,7 +261,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userID, domain string)
account.Domain = lowerDomain account.Domain = lowerDomain
err = am.Store.SaveAccount(account) err = am.Store.SaveAccount(account)
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "failed updating account with domain") return nil, status.Errorf(status.Internal, "failed updating account with domain")
} }
} }
@ -272,14 +270,14 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userID, domain string)
// IsUserAdmin flag for current user authenticated by JWT token // IsUserAdmin flag for current user authenticated by JWT token
func (am *DefaultAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) { func (am *DefaultAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) {
account, err := am.GetAccountFromToken(claims) account, _, err := am.GetAccountFromToken(claims)
if err != nil { if err != nil {
return false, fmt.Errorf("get account: %v", err) return false, fmt.Errorf("get account: %v", err)
} }
user, ok := account.Users[claims.UserId] user, ok := account.Users[claims.UserId]
if !ok { if !ok {
return false, fmt.Errorf("no such user") return false, status.Errorf(status.NotFound, "user not found")
} }
return user.Role == UserRoleAdmin, nil return user.Role == UserRoleAdmin, nil

View File

@ -1,8 +1,7 @@
package route package route
import ( import (
"google.golang.org/grpc/codes" "github.com/netbirdio/netbird/management/server/status"
"google.golang.org/grpc/status"
"net/netip" "net/netip"
) )
@ -108,13 +107,13 @@ func (r *Route) IsEqual(other *Route) bool {
func ParseNetwork(networkString string) (NetworkType, netip.Prefix, error) { func ParseNetwork(networkString string) (NetworkType, netip.Prefix, error) {
prefix, err := netip.ParsePrefix(networkString) prefix, err := netip.ParsePrefix(networkString)
if err != nil { if err != nil {
return InvalidNetwork, netip.Prefix{}, err return InvalidNetwork, netip.Prefix{}, status.Errorf(status.InvalidArgument, "invalid network %s", networkString)
} }
masked := prefix.Masked() masked := prefix.Masked()
if !masked.IsValid() { if !masked.IsValid() {
return InvalidNetwork, netip.Prefix{}, status.Errorf(codes.InvalidArgument, "invalid range %s", networkString) return InvalidNetwork, netip.Prefix{}, status.Errorf(status.InvalidArgument, "invalid range %s", networkString)
} }
if masked.Addr().Is6() { if masked.Addr().Is6() {