mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-21 23:53:14 +01:00
Replace gRPC errors in business logic with internal ones (#558)
This commit is contained in:
parent
1db4027bea
commit
509d23c7cf
@ -8,12 +8,11 @@ import (
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/netip"
|
||||
@ -52,7 +51,7 @@ type AccountManager interface {
|
||||
SaveUser(accountID string, key *User) (*UserInfo, error)
|
||||
GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
|
||||
GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error)
|
||||
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, error)
|
||||
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
|
||||
IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error)
|
||||
AccountExists(accountId string) (*bool, error)
|
||||
GetPeer(peerKey string) (*Peer, error)
|
||||
@ -265,14 +264,14 @@ func (a *Account) FindPeerByPubKey(peerPubKey string) (*Peer, error) {
|
||||
}
|
||||
}
|
||||
|
||||
return nil, status.Errorf(codes.NotFound, "peer with the public key %s not found", peerPubKey)
|
||||
return nil, status.Errorf(status.NotFound, "peer with the public key %s not found", peerPubKey)
|
||||
}
|
||||
|
||||
// FindUser looks for a given user in the Account or returns error if user wasn't found.
|
||||
func (a *Account) FindUser(userID string) (*User, error) {
|
||||
user := a.Users[userID]
|
||||
if user == nil {
|
||||
return nil, Errorf(UserNotFound, "user %s not found", userID)
|
||||
return nil, status.Errorf(status.NotFound, "user %s not found", userID)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
@ -282,7 +281,7 @@ func (a *Account) FindUser(userID string) (*User, error) {
|
||||
func (a *Account) FindSetupKey(setupKey string) (*SetupKey, error) {
|
||||
key := a.SetupKeys[setupKey]
|
||||
if key == nil {
|
||||
return nil, Errorf(SetupKeyNotFound, "setup key not found")
|
||||
return nil, status.Errorf(status.NotFound, "setup key not found")
|
||||
}
|
||||
|
||||
return key, nil
|
||||
@ -458,14 +457,14 @@ func (am *DefaultAccountManager) newAccount(userID, domain string) (*Account, er
|
||||
if err == nil {
|
||||
log.Warnf("an account with ID already exists, retrying...")
|
||||
continue
|
||||
} else if statusErr.Code() == codes.NotFound {
|
||||
} else if statusErr.Type() == status.NotFound {
|
||||
return newAccountWithId(accountId, userID, domain), nil
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return nil, status.Errorf(codes.Internal, "error while creating new account")
|
||||
return nil, status.Errorf(status.Internal, "error while creating new account")
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) warmupIDPCache() error {
|
||||
@ -492,7 +491,7 @@ func (am *DefaultAccountManager) GetAccountByUserOrAccountID(userID, accountID,
|
||||
} else if userID != "" {
|
||||
account, err := am.GetOrCreateAccountByUser(userID, domain)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "account not found using user id: %s", userID)
|
||||
return nil, status.Errorf(status.NotFound, "account not found using user id: %s", userID)
|
||||
}
|
||||
err = am.addAccountIDToIDPAppMeta(userID, account)
|
||||
if err != nil {
|
||||
@ -501,7 +500,7 @@ func (am *DefaultAccountManager) GetAccountByUserOrAccountID(userID, accountID,
|
||||
return account, nil
|
||||
}
|
||||
|
||||
return nil, status.Errorf(codes.NotFound, "no valid user or account Id provided")
|
||||
return nil, status.Errorf(status.NotFound, "no valid user or account Id provided")
|
||||
}
|
||||
|
||||
func isNil(i idp.Manager) bool {
|
||||
@ -531,11 +530,7 @@ func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(userID string, account
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return status.Errorf(
|
||||
codes.Internal,
|
||||
"updating user's app metadata failed with: %v",
|
||||
err,
|
||||
)
|
||||
return status.Errorf(status.Internal, "updating user's app metadata failed with: %v", err)
|
||||
}
|
||||
// refresh cache to reflect the update
|
||||
_, err = am.refreshCache(account.Id)
|
||||
@ -662,11 +657,8 @@ func (am *DefaultAccountManager) lookupCache(accountUsers map[string]struct{}, a
|
||||
}
|
||||
|
||||
// updateAccountDomainAttributes updates the account domain attributes and then, saves the account
|
||||
func (am *DefaultAccountManager) updateAccountDomainAttributes(
|
||||
account *Account,
|
||||
claims jwtclaims.AuthorizationClaims,
|
||||
primaryDomain bool,
|
||||
) error {
|
||||
func (am *DefaultAccountManager) updateAccountDomainAttributes(account *Account, claims jwtclaims.AuthorizationClaims,
|
||||
primaryDomain bool) error {
|
||||
account.IsDomainPrimaryAccount = primaryDomain
|
||||
|
||||
lowerDomain := strings.ToLower(claims.Domain)
|
||||
@ -681,7 +673,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributes(
|
||||
|
||||
err := am.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "failed saving updated account")
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -723,10 +715,7 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
|
||||
|
||||
// handleNewUserAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account,
|
||||
// otherwise it will create a new account and make it primary account for the domain.
|
||||
func (am *DefaultAccountManager) handleNewUserAccount(
|
||||
domainAcc *Account,
|
||||
claims jwtclaims.AuthorizationClaims,
|
||||
) (*Account, error) {
|
||||
func (am *DefaultAccountManager) handleNewUserAccount(domainAcc *Account, claims jwtclaims.AuthorizationClaims) (*Account, error) {
|
||||
var (
|
||||
account *Account
|
||||
err error
|
||||
@ -738,7 +727,7 @@ func (am *DefaultAccountManager) handleNewUserAccount(
|
||||
account.Users[claims.UserId] = NewRegularUser(claims.UserId)
|
||||
err = am.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed saving updated account")
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
account, err = am.newAccount(claims.UserId, lowerDomain)
|
||||
@ -773,7 +762,7 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
return status.Errorf(codes.NotFound, "user %s not found in the IdP", userID)
|
||||
return status.Errorf(status.NotFound, "user %s not found in the IdP", userID)
|
||||
}
|
||||
|
||||
if user.AppMetadata.WTPendingInvite != nil && *user.AppMetadata.WTPendingInvite {
|
||||
@ -794,7 +783,7 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e
|
||||
}
|
||||
|
||||
// GetAccountFromToken returns an account associated with this token
|
||||
func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, error) {
|
||||
func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error) {
|
||||
|
||||
if am.singleAccountMode && am.singleAccountModeDomain != "" {
|
||||
// This section is mostly related to self-hosted installations.
|
||||
@ -806,15 +795,21 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
|
||||
|
||||
account, err := am.getAccountWithAuthorizationClaims(claims)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
user := account.Users[claims.UserId]
|
||||
if user == nil {
|
||||
// this is not really possible because we got an account by user ID
|
||||
return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId)
|
||||
}
|
||||
|
||||
err = am.redeemInvite(account, claims.UserId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return account, nil
|
||||
return account, user, nil
|
||||
}
|
||||
|
||||
// getAccountWithAuthorizationClaims retrievs an account using JWT Claims.
|
||||
@ -857,10 +852,13 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla
|
||||
|
||||
// We checked if the domain has a primary account already
|
||||
domainAccount, err := am.Store.GetAccountByPrivateDomain(claims.Domain)
|
||||
accStatus, _ := status.FromError(err)
|
||||
if accStatus.Code() != codes.OK && accStatus.Code() != codes.NotFound {
|
||||
if err != nil {
|
||||
// if NotFound we are good to continue, otherwise return error
|
||||
e, ok := status.FromError(err)
|
||||
if !ok || e.Type() != status.NotFound {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccountByUser(claims.UserId)
|
||||
if err == nil {
|
||||
@ -869,7 +867,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla
|
||||
return nil, err
|
||||
}
|
||||
return account, nil
|
||||
} else if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||
return am.handleNewUserAccount(domainAccount, claims)
|
||||
} else {
|
||||
// other error
|
||||
@ -891,7 +889,7 @@ func (am *DefaultAccountManager) AccountExists(accountID string) (*bool, error)
|
||||
var res bool
|
||||
_, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||
res = false
|
||||
return &res, nil
|
||||
} else {
|
||||
|
@ -314,7 +314,7 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
|
||||
testCase.inputClaims.AccountId = initAccount.Id
|
||||
}
|
||||
|
||||
account, err := manager.GetAccountFromToken(testCase.inputClaims)
|
||||
account, _, err := manager.GetAccountFromToken(testCase.inputClaims)
|
||||
require.NoError(t, err, "support function failed")
|
||||
verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers)
|
||||
verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy)
|
||||
|
@ -1,61 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const (
|
||||
// UserAlreadyExists indicates that user already exists
|
||||
UserAlreadyExists ErrorType = iota
|
||||
// AccountNotFound indicates that specified account hasn't been found
|
||||
AccountNotFound
|
||||
// PreconditionFailed indicates that some pre-condition for the operation hasn't been fulfilled
|
||||
PreconditionFailed
|
||||
|
||||
// UserNotFound indicates that user wasn't found in the system (or under a given Account)
|
||||
UserNotFound
|
||||
|
||||
// PermissionDenied indicates that user has no permissions to view data
|
||||
PermissionDenied
|
||||
|
||||
// SetupKeyNotFound indicates that the setup key wasn't found in the system (or under a given Account)
|
||||
SetupKeyNotFound
|
||||
)
|
||||
|
||||
// ErrorType is a type of the Error
|
||||
type ErrorType int32
|
||||
|
||||
// Error is an internal error
|
||||
type Error struct {
|
||||
errorType ErrorType
|
||||
message string
|
||||
}
|
||||
|
||||
// Type returns the Type of the error
|
||||
func (e *Error) Type() ErrorType {
|
||||
return e.errorType
|
||||
}
|
||||
|
||||
// Error is an error string
|
||||
func (e *Error) Error() string {
|
||||
return e.message
|
||||
}
|
||||
|
||||
// Errorf returns Error(errorType, fmt.Sprintf(format, a...)).
|
||||
func Errorf(errorType ErrorType, format string, a ...interface{}) error {
|
||||
return &Error{
|
||||
errorType: errorType,
|
||||
message: fmt.Sprintf(format, a...),
|
||||
}
|
||||
}
|
||||
|
||||
// FromError returns Error, true if the provided error is of type of Error. nil, false otherwise
|
||||
func FromError(err error) (s *Error, ok bool) {
|
||||
if err == nil {
|
||||
return nil, true
|
||||
}
|
||||
if e, ok := err.(*Error); ok {
|
||||
return e, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@ -8,9 +9,6 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@ -192,10 +190,7 @@ func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) {
|
||||
|
||||
accountID, accountIDFound := s.PrivateDomain2AccountID[strings.ToLower(domain)]
|
||||
if !accountIDFound {
|
||||
return nil, status.Errorf(
|
||||
codes.NotFound,
|
||||
"account not found: provided domain is not registered or is not private",
|
||||
)
|
||||
return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
|
||||
}
|
||||
|
||||
account, err := s.getAccount(accountID)
|
||||
@ -213,7 +208,7 @@ func (s *FileStore) GetAccountBySetupKey(setupKey string) (*Account, error) {
|
||||
|
||||
accountID, accountIDFound := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
|
||||
if !accountIDFound {
|
||||
return nil, status.Errorf(codes.NotFound, "account not found: provided setup key doesn't exists")
|
||||
return nil, status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists")
|
||||
}
|
||||
|
||||
account, err := s.getAccount(accountID)
|
||||
@ -239,7 +234,7 @@ func (s *FileStore) GetAllAccounts() (all []*Account) {
|
||||
func (s *FileStore) getAccount(accountID string) (*Account, error) {
|
||||
account, accountFound := s.Accounts[accountID]
|
||||
if !accountFound {
|
||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||
return nil, status.Errorf(status.NotFound, "account not found")
|
||||
}
|
||||
|
||||
return account, nil
|
||||
@ -265,7 +260,7 @@ func (s *FileStore) GetAccountByUser(userID string) (*Account, error) {
|
||||
|
||||
accountID, accountIDFound := s.UserID2AccountID[userID]
|
||||
if !accountIDFound {
|
||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||
return nil, status.Errorf(status.NotFound, "account not found")
|
||||
}
|
||||
|
||||
account, err := s.getAccount(accountID)
|
||||
@ -283,7 +278,7 @@ func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
|
||||
|
||||
accountID, accountIDFound := s.PeerKeyID2AccountID[peerKey]
|
||||
if !accountIDFound {
|
||||
return nil, status.Errorf(codes.NotFound, "Provided peer key doesn't exists %s", peerKey)
|
||||
return nil, status.Errorf(status.NotFound, "provided peer key doesn't exists %s", peerKey)
|
||||
}
|
||||
|
||||
account, err := s.getAccount(accountID)
|
||||
@ -322,7 +317,7 @@ func (s *FileStore) SavePeerStatus(accountID, peerKey string, peerStatus PeerSta
|
||||
|
||||
peer := account.Peers[peerKey]
|
||||
if peer == nil {
|
||||
return status.Errorf(codes.NotFound, "peer %s not found", peerKey)
|
||||
return status.Errorf(status.NotFound, "peer %s not found", peerKey)
|
||||
}
|
||||
|
||||
peer.Status = &peerStatus
|
||||
|
@ -1,9 +1,6 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
import "github.com/netbirdio/netbird/management/server/status"
|
||||
|
||||
// Group of the peers for ACL
|
||||
type Group struct {
|
||||
@ -53,7 +50,7 @@ func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, er
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
group, ok := account.Groups[groupID]
|
||||
@ -61,7 +58,7 @@ func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, er
|
||||
return group, nil
|
||||
}
|
||||
|
||||
return nil, status.Errorf(codes.NotFound, "group with ID %s not found", groupID)
|
||||
return nil, status.Errorf(status.NotFound, "group with ID %s not found", groupID)
|
||||
}
|
||||
|
||||
// SaveGroup object of the peers
|
||||
@ -72,7 +69,7 @@ func (am *DefaultAccountManager) SaveGroup(accountID string, group *Group) error
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.NotFound, "account not found")
|
||||
return err
|
||||
}
|
||||
|
||||
account.Groups[group.ID] = group
|
||||
@ -94,12 +91,12 @@ func (am *DefaultAccountManager) UpdateGroup(accountID string,
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
groupToUpdate, ok := account.Groups[groupID]
|
||||
if !ok {
|
||||
return nil, status.Errorf(codes.NotFound, "group %s no longer exists", groupID)
|
||||
return nil, status.Errorf(status.NotFound, "group with ID %s no longer exists", groupID)
|
||||
}
|
||||
|
||||
group := groupToUpdate.Copy()
|
||||
@ -130,7 +127,7 @@ func (am *DefaultAccountManager) UpdateGroup(accountID string,
|
||||
|
||||
err = am.updateAccountPeers(account)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to update account peers")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return group, nil
|
||||
@ -144,7 +141,7 @@ func (am *DefaultAccountManager) DeleteGroup(accountID, groupID string) error {
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.NotFound, "account not found")
|
||||
return err
|
||||
}
|
||||
|
||||
delete(account.Groups, groupID)
|
||||
@ -165,7 +162,7 @@ func (am *DefaultAccountManager) ListGroups(accountID string) ([]*Group, error)
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
groups := make([]*Group, 0, len(account.Groups))
|
||||
@ -184,12 +181,12 @@ func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerKey string
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.NotFound, "account not found")
|
||||
return err
|
||||
}
|
||||
|
||||
group, ok := account.Groups[groupID]
|
||||
if !ok {
|
||||
return status.Errorf(codes.NotFound, "group with ID %s not found", groupID)
|
||||
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
|
||||
}
|
||||
|
||||
add := true
|
||||
@ -219,12 +216,12 @@ func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerKey str
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.NotFound, "account not found")
|
||||
return err
|
||||
}
|
||||
|
||||
group, ok := account.Groups[groupID]
|
||||
if !ok {
|
||||
return status.Errorf(codes.NotFound, "group with ID %s not found", groupID)
|
||||
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
|
||||
}
|
||||
|
||||
account.Network.IncSerial()
|
||||
@ -232,7 +229,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerKey str
|
||||
if itemID == peerKey {
|
||||
group.Peers = append(group.Peers[:i], group.Peers[i+1:]...)
|
||||
if err := am.Store.SaveAccount(account); err != nil {
|
||||
return status.Errorf(codes.Internal, "can't save account")
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -248,12 +245,12 @@ func (am *DefaultAccountManager) GroupListPeers(accountID, groupID string) ([]*P
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||
return nil, status.Errorf(status.NotFound, "account not found")
|
||||
}
|
||||
|
||||
group, ok := account.Groups[groupID]
|
||||
if !ok {
|
||||
return nil, status.Errorf(codes.NotFound, "group with ID %s not found", groupID)
|
||||
return nil, status.Errorf(status.NotFound, "group with ID %s not found", groupID)
|
||||
}
|
||||
|
||||
peers := make([]*Peer, 0, len(account.Groups))
|
||||
|
@ -14,6 +14,7 @@ import (
|
||||
"github.com/golang/protobuf/ptypes/timestamp"
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
internalStatus "github.com/netbirdio/netbird/management/server/status"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
@ -200,10 +201,6 @@ func (s *GRPCServer) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest)
|
||||
return nil, status.Errorf(codes.Internal, "invalid jwt token, err: %v", err)
|
||||
}
|
||||
claims := jwtclaims.ExtractClaimsWithToken(token, s.config.HttpConfig.AuthAudience)
|
||||
_, err = s.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err)
|
||||
}
|
||||
userID = claims.UserId
|
||||
} else {
|
||||
log.Debugln("using setup key to register peer")
|
||||
@ -237,14 +234,12 @@ func (s *GRPCServer) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest)
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
if e, ok := FromError(err); ok {
|
||||
if e, ok := internalStatus.FromError(err); ok {
|
||||
switch e.Type() {
|
||||
case PreconditionFailed:
|
||||
return nil, status.Errorf(codes.FailedPrecondition, e.message)
|
||||
case AccountNotFound:
|
||||
case SetupKeyNotFound:
|
||||
case UserNotFound:
|
||||
return nil, status.Errorf(codes.NotFound, e.message)
|
||||
case internalStatus.PreconditionFailed:
|
||||
return nil, status.Errorf(codes.FailedPrecondition, e.Message)
|
||||
case internalStatus.NotFound:
|
||||
return nil, status.Errorf(codes.NotFound, e.Message)
|
||||
default:
|
||||
}
|
||||
}
|
||||
@ -301,7 +296,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
|
||||
|
||||
peer, err := s.accountManager.GetPeer(peerKey.String())
|
||||
if err != nil {
|
||||
if errStatus, ok := status.FromError(err); ok && errStatus.Code() == codes.NotFound {
|
||||
if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound {
|
||||
// peer doesn't exist -> check if setup key was provided
|
||||
if loginReq.GetJwtToken() == "" && loginReq.GetSetupKey() == "" {
|
||||
// absent setup key or jwt -> permission denied
|
||||
@ -387,7 +382,6 @@ func ToResponseProto(configProto Protocol) proto.HostConfig_Protocol {
|
||||
case TCP:
|
||||
return proto.HostConfig_TCP
|
||||
default:
|
||||
// mbragin: todo something better?
|
||||
panic(fmt.Errorf("unexpected config protocol type %v", configProto))
|
||||
}
|
||||
}
|
||||
|
@ -2,10 +2,9 @@ package http
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"net/http"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
@ -33,7 +32,8 @@ func NewGroups(accountManager server.AccountManager, authAudience string) *Group
|
||||
|
||||
// GetAllGroupsHandler list for the account
|
||||
func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) {
|
||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
@ -45,52 +45,54 @@ func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) {
|
||||
groups = append(groups, toGroupResponse(account, g))
|
||||
}
|
||||
|
||||
writeJSONObject(w, groups)
|
||||
util.WriteJSONObject(w, groups)
|
||||
}
|
||||
|
||||
// UpdateGroupHandler handles update to a group identified by a given ID
|
||||
func (h *Groups) UpdateGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
groupID, ok := vars["id"]
|
||||
if !ok {
|
||||
http.Error(w, "group ID field is missing", http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "group ID field is missing"), w)
|
||||
return
|
||||
}
|
||||
if len(groupID) == 0 {
|
||||
http.Error(w, "group ID can't be empty", http.StatusUnprocessableEntity)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "group ID can't be empty"), w)
|
||||
return
|
||||
}
|
||||
|
||||
_, ok = account.Groups[groupID]
|
||||
if !ok {
|
||||
http.Error(w, fmt.Sprintf("couldn't find group with ID %s", groupID), http.StatusNotFound)
|
||||
util.WriteError(status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w)
|
||||
return
|
||||
}
|
||||
|
||||
allGroup, err := account.GetGroupAll()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
if allGroup.ID == groupID {
|
||||
http.Error(w, "updating group ALL is not allowed", http.StatusMethodNotAllowed)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.PutApiGroupsIdJSONRequestBody
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if *req.Name == "" {
|
||||
http.Error(w, "group name shouldn't be empty", http.StatusUnprocessableEntity)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -102,53 +104,55 @@ func (h *Groups) UpdateGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if err := h.accountManager.SaveGroup(account.Id, &group); err != nil {
|
||||
log.Errorf("failed updating group %s under account %s %v", groupID, account.Id, err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSONObject(w, toGroupResponse(account, &group))
|
||||
util.WriteJSONObject(w, toGroupResponse(account, &group))
|
||||
}
|
||||
|
||||
// PatchGroupHandler handles patch updates to a group identified by a given ID
|
||||
func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
groupID := vars["id"]
|
||||
if len(groupID) == 0 {
|
||||
http.Error(w, "invalid group Id", http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
_, ok := account.Groups[groupID]
|
||||
if !ok {
|
||||
http.Error(w, fmt.Sprintf("couldn't find group id %s", groupID), http.StatusNotFound)
|
||||
util.WriteError(status.Errorf(status.NotFound, "couldn't find group ID %s", groupID), w)
|
||||
return
|
||||
}
|
||||
|
||||
allGroup, err := account.GetGroupAll()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
if allGroup.ID == groupID {
|
||||
http.Error(w, "updating group ALL is not allowed", http.StatusMethodNotAllowed)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.PatchApiGroupsIdJSONRequestBody
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if len(req) == 0 {
|
||||
http.Error(w, "no patch instruction received", http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "no patch instruction received"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -158,13 +162,13 @@ func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||
switch patch.Path {
|
||||
case api.GroupPatchOperationPathName:
|
||||
if patch.Op != api.GroupPatchOperationOpReplace {
|
||||
http.Error(w, fmt.Sprintf("Name field only accepts replace operation, got %s", patch.Op),
|
||||
http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument,
|
||||
"name field only accepts replace operation, got %s", patch.Op), w)
|
||||
return
|
||||
}
|
||||
|
||||
if len(patch.Value) == 0 || patch.Value[0] == "" {
|
||||
http.Error(w, "Group name shouldn't be empty", http.StatusUnprocessableEntity)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -193,53 +197,43 @@ func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||
Values: peerKeys,
|
||||
})
|
||||
default:
|
||||
http.Error(w, "invalid operation, \"%s\", for Peers field", http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument,
|
||||
"invalid operation, \"%v\", for Peers field", patch.Op), w)
|
||||
return
|
||||
}
|
||||
default:
|
||||
http.Error(w, "invalid patch path", http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid patch path"), w)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
group, err := h.accountManager.UpdateGroup(account.Id, groupID, operations)
|
||||
|
||||
if err != nil {
|
||||
errStatus, ok := status.FromError(err)
|
||||
if ok && errStatus.Code() == codes.Internal {
|
||||
http.Error(w, errStatus.String(), http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
if ok && errStatus.Code() == codes.NotFound {
|
||||
http.Error(w, errStatus.String(), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
log.Errorf("failed updating group %s under account %s %v", groupID, account.Id, err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSONObject(w, toGroupResponse(account, group))
|
||||
util.WriteJSONObject(w, toGroupResponse(account, group))
|
||||
}
|
||||
|
||||
// CreateGroupHandler handles group creation request
|
||||
func (h *Groups) CreateGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.PostApiGroupsJSONRequestBody
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
http.Error(w, "Group name shouldn't be empty", http.StatusUnprocessableEntity)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -249,55 +243,57 @@ func (h *Groups) CreateGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||
Peers: peerIPsToKeys(account, req.Peers),
|
||||
}
|
||||
|
||||
if err := h.accountManager.SaveGroup(account.Id, &group); err != nil {
|
||||
log.Errorf("failed creating group \"%s\" under account %s %v", req.Name, account.Id, err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
err = h.accountManager.SaveGroup(account.Id, &group)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSONObject(w, toGroupResponse(account, &group))
|
||||
util.WriteJSONObject(w, toGroupResponse(account, &group))
|
||||
}
|
||||
|
||||
// DeleteGroupHandler handles group deletion request
|
||||
func (h *Groups) DeleteGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
aID := account.Id
|
||||
|
||||
groupID := mux.Vars(r)["id"]
|
||||
if len(groupID) == 0 {
|
||||
http.Error(w, "invalid group ID", http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
allGroup, err := account.GetGroupAll()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
if allGroup.ID == groupID {
|
||||
http.Error(w, "deleting group ALL is not allowed", http.StatusMethodNotAllowed)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.accountManager.DeleteGroup(aID, groupID); err != nil {
|
||||
log.Errorf("failed delete group %s under account %s %v", groupID, aID, err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
err = h.accountManager.DeleteGroup(aID, groupID)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSONObject(w, "")
|
||||
util.WriteJSONObject(w, "")
|
||||
}
|
||||
|
||||
// GetGroupHandler returns a group
|
||||
func (h *Groups) GetGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -305,19 +301,22 @@ func (h *Groups) GetGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||
case http.MethodGet:
|
||||
groupID := mux.Vars(r)["id"]
|
||||
if len(groupID) == 0 {
|
||||
http.Error(w, "invalid group ID", http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
group, err := h.accountManager.GetGroup(account.Id, groupID)
|
||||
if err != nil {
|
||||
http.Error(w, "group not found", http.StatusNotFound)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSONObject(w, toGroupResponse(account, group))
|
||||
util.WriteJSONObject(w, toGroupResponse(account, group))
|
||||
default:
|
||||
http.Error(w, "", http.StatusNotFound)
|
||||
if err != nil {
|
||||
util.WriteError(status.Errorf(status.NotFound, "HTTP method not found"), w)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
@ -36,7 +37,7 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *Groups {
|
||||
},
|
||||
GetGroupFunc: func(_, groupID string) (*server.Group, error) {
|
||||
if groupID != "idofthegroup" {
|
||||
return nil, fmt.Errorf("not found")
|
||||
return nil, status.Errorf(status.NotFound, "not found")
|
||||
}
|
||||
return &server.Group{
|
||||
ID: "idofthegroup",
|
||||
@ -67,7 +68,7 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *Groups {
|
||||
}
|
||||
return nil, fmt.Errorf("peer not found")
|
||||
},
|
||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
|
||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return &server.Account{
|
||||
Id: claims.AccountId,
|
||||
Domain: "hotmail.com",
|
||||
@ -78,7 +79,7 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *Groups {
|
||||
Groups: map[string]*server.Group{
|
||||
"id-existed": {ID: "id-existed", Peers: []string{"A", "B"}},
|
||||
"id-all": {ID: "id-all", Name: "All"}},
|
||||
}, nil
|
||||
}, user, nil
|
||||
},
|
||||
},
|
||||
authAudience: "",
|
||||
@ -223,7 +224,7 @@ func TestWriteGroup(t *testing.T) {
|
||||
requestPath: "/api/groups/id-all",
|
||||
requestBody: bytes.NewBuffer(
|
||||
[]byte(`{"Name":"super"}`)),
|
||||
expectedStatus: http.StatusMethodNotAllowed,
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
expectedBody: false,
|
||||
},
|
||||
{
|
||||
@ -244,7 +245,7 @@ func TestWriteGroup(t *testing.T) {
|
||||
requestPath: "/api/groups/id-existed",
|
||||
requestBody: bytes.NewBuffer(
|
||||
[]byte(`[{"op":"insert","path":"name","value":[""]}]`)),
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
expectedBody: false,
|
||||
},
|
||||
{
|
||||
|
@ -1,7 +1,8 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"net/http"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
@ -33,14 +34,15 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
|
||||
|
||||
ok, err := a.isUserAdmin(jwtClaims)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("error get user from JWT: %v", err), http.StatusUnauthorized)
|
||||
util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if !ok {
|
||||
switch r.Method {
|
||||
|
||||
case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut:
|
||||
http.Error(w, "user is not admin", http.StatusForbidden)
|
||||
util.WriteError(status.Errorf(status.PermissionDenied, "only admin can perform this operation"), w)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -7,12 +7,12 @@ import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
//Jwks is a collection of JSONWebKeys obtained from Config.HttpServerConfig.AuthKeysLocation
|
||||
// Jwks is a collection of JSONWebKeys obtained from Config.HttpServerConfig.AuthKeysLocation
|
||||
type Jwks struct {
|
||||
Keys []JSONWebKeys `json:"keys"`
|
||||
}
|
||||
|
||||
//JSONWebKeys is a representation of a Jason Web Key
|
||||
// JSONWebKeys is a representation of a Jason Web Key
|
||||
type JSONWebKeys struct {
|
||||
Kty string `json:"kty"`
|
||||
Kid string `json:"kid"`
|
||||
@ -22,7 +22,7 @@ type JSONWebKeys struct {
|
||||
X5c []string `json:"x5c"`
|
||||
}
|
||||
|
||||
//NewJwtMiddleware creates new middleware to verify the JWT token sent via Authorization header
|
||||
// NewJwtMiddleware creates new middleware to verify the JWT token sent via Authorization header
|
||||
func NewJwtMiddleware(issuer string, audience string, keysLocation string) (*JWTMiddleware, error) {
|
||||
|
||||
keys, err := getPemKeys(keysLocation)
|
||||
@ -66,7 +66,6 @@ func getPemKeys(keysLocation string) (*Jwks, error) {
|
||||
|
||||
var jwks = &Jwks{}
|
||||
err = json.NewDecoder(resp.Body).Decode(jwks)
|
||||
|
||||
if err != nil {
|
||||
return jwks, err
|
||||
}
|
||||
|
@ -5,6 +5,8 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
@ -57,7 +59,7 @@ type JWTMiddleware struct {
|
||||
}
|
||||
|
||||
func OnError(w http.ResponseWriter, r *http.Request, err string) {
|
||||
http.Error(w, err, http.StatusUnauthorized)
|
||||
util.WriteError(status.Errorf(status.Unauthorized, ""), w)
|
||||
}
|
||||
|
||||
// New constructs a new Secure instance with supplied options.
|
||||
|
@ -7,7 +7,9 @@ import (
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net/http"
|
||||
)
|
||||
@ -30,7 +32,8 @@ func NewNameservers(accountManager server.AccountManager, authAudience string) *
|
||||
|
||||
// GetAllNameserversHandler returns the list of nameserver groups for the account
|
||||
func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Request) {
|
||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
@ -39,7 +42,7 @@ func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Re
|
||||
|
||||
nsGroups, err := h.accountManager.ListNameServerGroups(account.Id)
|
||||
if err != nil {
|
||||
toHTTPError(err, w)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -48,64 +51,67 @@ func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Re
|
||||
apiNameservers = append(apiNameservers, toNameserverGroupResponse(r))
|
||||
}
|
||||
|
||||
writeJSONObject(w, apiNameservers)
|
||||
util.WriteJSONObject(w, apiNameservers)
|
||||
}
|
||||
|
||||
// CreateNameserverGroupHandler handles nameserver group creation request
|
||||
func (h *Nameservers) CreateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.PostApiDnsNameserversJSONRequestBody
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
nsList, err := toServerNSList(req.Nameservers)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid NS servers format"), w)
|
||||
return
|
||||
}
|
||||
|
||||
nsGroup, err := h.accountManager.CreateNameServerGroup(account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled)
|
||||
if err != nil {
|
||||
toHTTPError(err, w)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toNameserverGroupResponse(nsGroup)
|
||||
|
||||
writeJSONObject(w, &resp)
|
||||
util.WriteJSONObject(w, &resp)
|
||||
}
|
||||
|
||||
// UpdateNameserverGroupHandler handles update to a nameserver group identified by a given ID
|
||||
func (h *Nameservers) UpdateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
nsGroupID := mux.Vars(r)["id"]
|
||||
if len(nsGroupID) == 0 {
|
||||
http.Error(w, "invalid nameserver group ID", http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.PutApiDnsNameserversIdJSONRequestBody
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
nsList, err := toServerNSList(req.Nameservers)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid NS servers format"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -122,41 +128,42 @@ func (h *Nameservers) UpdateNameserverGroupHandler(w http.ResponseWriter, r *htt
|
||||
|
||||
err = h.accountManager.SaveNameServerGroup(account.Id, updatedNSGroup)
|
||||
if err != nil {
|
||||
toHTTPError(err, w)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toNameserverGroupResponse(updatedNSGroup)
|
||||
|
||||
writeJSONObject(w, &resp)
|
||||
util.WriteJSONObject(w, &resp)
|
||||
}
|
||||
|
||||
// PatchNameserverGroupHandler handles patch updates to a nameserver group identified by a given ID
|
||||
func (h *Nameservers) PatchNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
nsGroupID := mux.Vars(r)["id"]
|
||||
if len(nsGroupID) == 0 {
|
||||
http.Error(w, "invalid nameserver group ID", http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.PatchApiDnsNameserversIdJSONRequestBody
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
var operations []server.NameServerGroupUpdateOperation
|
||||
for _, patch := range req {
|
||||
if patch.Op != api.NameserverGroupPatchOperationOpReplace {
|
||||
http.Error(w, fmt.Sprintf("nameserver groups only accepts replace operations, got %s", patch.Op),
|
||||
http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument,
|
||||
"nameserver groups only accepts replace operations, got %s", patch.Op), w)
|
||||
return
|
||||
}
|
||||
switch patch.Path {
|
||||
@ -196,49 +203,50 @@ func (h *Nameservers) PatchNameserverGroupHandler(w http.ResponseWriter, r *http
|
||||
Values: patch.Value,
|
||||
})
|
||||
default:
|
||||
http.Error(w, "invalid patch path", http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid patch path"), w)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
updatedNSGroup, err := h.accountManager.UpdateNameServerGroup(account.Id, nsGroupID, operations)
|
||||
if err != nil {
|
||||
toHTTPError(err, w)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toNameserverGroupResponse(updatedNSGroup)
|
||||
|
||||
writeJSONObject(w, &resp)
|
||||
util.WriteJSONObject(w, &resp)
|
||||
}
|
||||
|
||||
// DeleteNameserverGroupHandler handles nameserver group deletion request
|
||||
func (h *Nameservers) DeleteNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
nsGroupID := mux.Vars(r)["id"]
|
||||
if len(nsGroupID) == 0 {
|
||||
http.Error(w, "invalid nameserver group ID", http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeleteNameServerGroup(account.Id, nsGroupID)
|
||||
if err != nil {
|
||||
toHTTPError(err, w)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSONObject(w, "")
|
||||
util.WriteJSONObject(w, "")
|
||||
}
|
||||
|
||||
// GetNameserverGroupHandler handles a nameserver group Get request identified by ID
|
||||
func (h *Nameservers) GetNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
@ -247,19 +255,19 @@ func (h *Nameservers) GetNameserverGroupHandler(w http.ResponseWriter, r *http.R
|
||||
|
||||
nsGroupID := mux.Vars(r)["id"]
|
||||
if len(nsGroupID) == 0 {
|
||||
http.Error(w, "invalid nameserver group ID", http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
nsGroup, err := h.accountManager.GetNameServerGroup(account.Id, nsGroupID)
|
||||
if err != nil {
|
||||
toHTTPError(err, w)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toNameserverGroupResponse(nsGroup)
|
||||
|
||||
writeJSONObject(w, &resp)
|
||||
util.WriteJSONObject(w, &resp)
|
||||
|
||||
}
|
||||
|
||||
|
@ -5,9 +5,8 @@ import (
|
||||
"encoding/json"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@ -62,7 +61,7 @@ func initNameserversTestData() *Nameservers {
|
||||
if nsGroupID == existingNSGroupID {
|
||||
return baseExistingNSGroup.Copy(), nil
|
||||
}
|
||||
return nil, status.Errorf(codes.NotFound, "nameserver group with ID %s not found", nsGroupID)
|
||||
return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID)
|
||||
},
|
||||
CreateNameServerGroupFunc: func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error) {
|
||||
return &nbdns.NameServerGroup{
|
||||
@ -83,12 +82,12 @@ func initNameserversTestData() *Nameservers {
|
||||
if nsGroupToSave.ID == existingNSGroupID {
|
||||
return nil
|
||||
}
|
||||
return status.Errorf(codes.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID)
|
||||
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID)
|
||||
},
|
||||
UpdateNameServerGroupFunc: func(accountID, nsGroupID string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) {
|
||||
nsGroupToUpdate := baseExistingNSGroup.Copy()
|
||||
if nsGroupID != nsGroupToUpdate.ID {
|
||||
return nil, status.Errorf(codes.NotFound, "nameserver group ID %s no longer exists", nsGroupID)
|
||||
return nil, status.Errorf(status.NotFound, "nameserver group ID %s no longer exists", nsGroupID)
|
||||
}
|
||||
for _, operation := range operations {
|
||||
switch operation.Type {
|
||||
@ -110,8 +109,8 @@ func initNameserversTestData() *Nameservers {
|
||||
}
|
||||
return nsGroupToUpdate, nil
|
||||
},
|
||||
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, error) {
|
||||
return testingNSAccount, nil
|
||||
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return testingNSAccount, nil, nil
|
||||
},
|
||||
},
|
||||
authAudience: "",
|
||||
@ -181,7 +180,7 @@ func TestNameserversHandlers(t *testing.T) {
|
||||
requestPath: "/api/dns/nameservers",
|
||||
requestBody: bytes.NewBuffer(
|
||||
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1000\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")),
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
expectedBody: false,
|
||||
},
|
||||
{
|
||||
@ -223,7 +222,7 @@ func TestNameserversHandlers(t *testing.T) {
|
||||
requestPath: "/api/dns/nameservers/" + notFoundNSGroupID,
|
||||
requestBody: bytes.NewBuffer(
|
||||
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"100\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")),
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
expectedBody: false,
|
||||
},
|
||||
{
|
||||
|
@ -6,8 +6,9 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
@ -28,50 +29,47 @@ func NewPeers(accountManager server.AccountManager, authAudience string) *Peers
|
||||
|
||||
func (h *Peers) updatePeer(account *server.Account, peer *server.Peer, w http.ResponseWriter, r *http.Request) {
|
||||
req := &api.PutApiPeersIdJSONBody{}
|
||||
peerIp := peer.IP
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
update := &server.Peer{Key: peer.Key, SSHEnabled: req.SshEnabled, Name: req.Name}
|
||||
peer, err = h.accountManager.UpdatePeer(account.Id, update)
|
||||
if err != nil {
|
||||
log.Errorf("failed updating peer %s under account %s %v", peerIp, account.Id, err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
writeJSONObject(w, toPeerResponse(peer, account))
|
||||
util.WriteJSONObject(w, toPeerResponse(peer, account))
|
||||
}
|
||||
|
||||
func (h *Peers) deletePeer(accountId string, peer *server.Peer, w http.ResponseWriter, r *http.Request) {
|
||||
_, err := h.accountManager.DeletePeer(accountId, peer.Key)
|
||||
if err != nil {
|
||||
log.Errorf("failed deleteing peer %s, %v", peer.IP, err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
writeJSONObject(w, "")
|
||||
util.WriteJSONObject(w, "")
|
||||
}
|
||||
|
||||
func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
vars := mux.Vars(r)
|
||||
peerId := vars["id"] //effectively peer IP address
|
||||
if len(peerId) == 0 {
|
||||
http.Error(w, "invalid peer Id", http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
peer, err := h.accountManager.GetPeerByIP(account.Id, peerId)
|
||||
if err != nil {
|
||||
http.Error(w, "peer not found", http.StatusNotFound)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -83,11 +81,11 @@ func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
||||
h.updatePeer(account, peer, w, r)
|
||||
return
|
||||
case http.MethodGet:
|
||||
writeJSONObject(w, toPeerResponse(peer, account))
|
||||
util.WriteJSONObject(w, toPeerResponse(peer, account))
|
||||
return
|
||||
|
||||
default:
|
||||
http.Error(w, "", http.StatusNotFound)
|
||||
util.WriteError(status.Errorf(status.NotFound, "unknown METHOD"), w)
|
||||
}
|
||||
|
||||
}
|
||||
@ -95,15 +93,16 @@ func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *Peers) GetPeers(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
peers, err := h.accountManager.GetPeers(account.Id, user.Id)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -111,10 +110,10 @@ func (h *Peers) GetPeers(w http.ResponseWriter, r *http.Request) {
|
||||
for _, peer := range peers {
|
||||
respBody = append(respBody, toPeerResponse(peer, account))
|
||||
}
|
||||
writeJSONObject(w, respBody)
|
||||
util.WriteJSONObject(w, respBody)
|
||||
return
|
||||
default:
|
||||
http.Error(w, "", http.StatusNotFound)
|
||||
util.WriteError(status.Errorf(status.NotFound, "unknown METHOD"), w)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -22,7 +22,8 @@ func initTestMetaData(peers ...*server.Peer) *Peers {
|
||||
GetPeersFunc: func(accountID, userID string) ([]*server.Peer, error) {
|
||||
return peers, nil
|
||||
},
|
||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
|
||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
user := server.NewAdminUser("test_user")
|
||||
return &server.Account{
|
||||
Id: claims.AccountId,
|
||||
Domain: "hotmail.com",
|
||||
@ -30,9 +31,9 @@ func initTestMetaData(peers ...*server.Peer) *Peers {
|
||||
"test_peer": peers[0],
|
||||
},
|
||||
Users: map[string]*server.User{
|
||||
"test_user": server.NewAdminUser("test_user"),
|
||||
"test_user": user,
|
||||
},
|
||||
}, nil
|
||||
}, user, nil
|
||||
},
|
||||
},
|
||||
authAudience: "",
|
||||
|
@ -2,15 +2,13 @@ package http
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"net/http"
|
||||
"unicode/utf8"
|
||||
)
|
||||
@ -33,25 +31,16 @@ func NewRoutes(accountManager server.AccountManager, authAudience string) *Route
|
||||
|
||||
// GetAllRoutesHandler returns the list of routes for the account
|
||||
func (h *Routes) GetAllRoutesHandler(w http.ResponseWriter, r *http.Request) {
|
||||
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
routes, err := h.accountManager.ListRoutes(account.Id, user.Id)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
if e, ok := server.FromError(err); ok {
|
||||
switch e.Type() {
|
||||
case server.PermissionDenied:
|
||||
http.Error(w, e.Error(), http.StatusForbidden)
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
apiRoutes := make([]*api.Route, 0)
|
||||
@ -59,20 +48,22 @@ func (h *Routes) GetAllRoutesHandler(w http.ResponseWriter, r *http.Request) {
|
||||
apiRoutes = append(apiRoutes, toRouteResponse(account, r))
|
||||
}
|
||||
|
||||
writeJSONObject(w, apiRoutes)
|
||||
util.WriteJSONObject(w, apiRoutes)
|
||||
}
|
||||
|
||||
// CreateRouteHandler handles route creation request
|
||||
func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) {
|
||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.PostApiRoutesJSONRequestBody
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -80,8 +71,7 @@ func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if req.Peer != "" {
|
||||
peer, err := h.accountManager.GetPeerByIP(account.Id, req.Peer)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusUnprocessableEntity)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
peerKey = peer.Key
|
||||
@ -89,57 +79,60 @@ func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
_, newPrefix, err := route.ParseNetwork(req.Network)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("couldn't parse update prefix %s", req.Network), http.StatusBadRequest)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
if utf8.RuneCountInString(req.NetworkId) > route.MaxNetIDChar || req.NetworkId == "" {
|
||||
http.Error(w, fmt.Sprintf("identifier should be between 1 and %d", route.MaxNetIDChar), http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d",
|
||||
route.MaxNetIDChar), w)
|
||||
return
|
||||
}
|
||||
|
||||
newRoute, err := h.accountManager.CreateRoute(account.Id, newPrefix.String(), peerKey, req.Description, req.NetworkId, req.Masquerade, req.Metric, req.Enabled)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toRouteResponse(account, newRoute)
|
||||
|
||||
writeJSONObject(w, &resp)
|
||||
util.WriteJSONObject(w, &resp)
|
||||
}
|
||||
|
||||
// UpdateRouteHandler handles update to a route identified by a given ID
|
||||
func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
|
||||
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
routeID := vars["id"]
|
||||
if len(routeID) == 0 {
|
||||
http.Error(w, "invalid route Id", http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = h.accountManager.GetRoute(account.Id, routeID, user.Id)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("couldn't find route for ID %s", routeID), http.StatusNotFound)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.PutApiRoutesIdJSONRequestBody
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
prefixType, newPrefix, err := route.ParseNetwork(req.Network)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("couldn't parse update prefix %s for route ID %s", req.Network, routeID), http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "couldn't parse update prefix %s for route ID %s",
|
||||
req.Network, routeID), w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -147,15 +140,15 @@ func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if req.Peer != "" {
|
||||
peer, err := h.accountManager.GetPeerByIP(account.Id, req.Peer)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusUnprocessableEntity)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
peerKey = peer.Key
|
||||
}
|
||||
|
||||
if utf8.RuneCountInString(req.NetworkId) > route.MaxNetIDChar || req.NetworkId == "" {
|
||||
http.Error(w, fmt.Sprintf("identifier should be between 1 and %d", route.MaxNetIDChar), http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument,
|
||||
"identifier should be between 1 and %d", route.MaxNetIDChar), w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -173,46 +166,46 @@ func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
err = h.accountManager.SaveRoute(account.Id, newRoute)
|
||||
if err != nil {
|
||||
log.Errorf("failed updating route \"%s\" under account %s %v", routeID, account.Id, err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toRouteResponse(account, newRoute)
|
||||
|
||||
writeJSONObject(w, &resp)
|
||||
util.WriteJSONObject(w, &resp)
|
||||
}
|
||||
|
||||
// PatchRouteHandler handles patch updates to a route identified by a given ID
|
||||
func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
|
||||
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
routeID := vars["id"]
|
||||
if len(routeID) == 0 {
|
||||
http.Error(w, "invalid route ID", http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = h.accountManager.GetRoute(account.Id, routeID, user.Id)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
http.Error(w, fmt.Sprintf("couldn't find route ID %s", routeID), http.StatusNotFound)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.PatchApiRoutesIdJSONRequestBody
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if len(req) == 0 {
|
||||
http.Error(w, "no patch instruction received", http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "no patch instruction received"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -222,8 +215,8 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
|
||||
switch patch.Path {
|
||||
case api.RoutePatchOperationPathNetwork:
|
||||
if patch.Op != api.RoutePatchOperationOpReplace {
|
||||
http.Error(w, fmt.Sprintf("Network field only accepts replace operation, got %s", patch.Op),
|
||||
http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument,
|
||||
"network field only accepts replace operation, got %s", patch.Op), w)
|
||||
return
|
||||
}
|
||||
operations = append(operations, server.RouteUpdateOperation{
|
||||
@ -232,8 +225,8 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
case api.RoutePatchOperationPathDescription:
|
||||
if patch.Op != api.RoutePatchOperationOpReplace {
|
||||
http.Error(w, fmt.Sprintf("Description field only accepts replace operation, got %s", patch.Op),
|
||||
http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument,
|
||||
"description field only accepts replace operation, got %s", patch.Op), w)
|
||||
return
|
||||
}
|
||||
operations = append(operations, server.RouteUpdateOperation{
|
||||
@ -242,8 +235,8 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
case api.RoutePatchOperationPathNetworkId:
|
||||
if patch.Op != api.RoutePatchOperationOpReplace {
|
||||
http.Error(w, fmt.Sprintf("Network Identifier field only accepts replace operation, got %s", patch.Op),
|
||||
http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument,
|
||||
"network Identifier field only accepts replace operation, got %s", patch.Op), w)
|
||||
return
|
||||
}
|
||||
operations = append(operations, server.RouteUpdateOperation{
|
||||
@ -252,21 +245,20 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
case api.RoutePatchOperationPathPeer:
|
||||
if patch.Op != api.RoutePatchOperationOpReplace {
|
||||
http.Error(w, fmt.Sprintf("Peer field only accepts replace operation, got %s", patch.Op),
|
||||
http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument,
|
||||
"peer field only accepts replace operation, got %s", patch.Op), w)
|
||||
return
|
||||
}
|
||||
if len(patch.Value) > 1 {
|
||||
http.Error(w, fmt.Sprintf("Value field only accepts 1 value, got %d", len(patch.Value)),
|
||||
http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument,
|
||||
"value field only accepts 1 value, got %d", len(patch.Value)), w)
|
||||
return
|
||||
}
|
||||
peerValue := patch.Value
|
||||
if patch.Value[0] != "" {
|
||||
peer, err := h.accountManager.GetPeerByIP(account.Id, patch.Value[0])
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusUnprocessableEntity)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
peerValue = []string{peer.Key}
|
||||
@ -277,8 +269,9 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
case api.RoutePatchOperationPathMetric:
|
||||
if patch.Op != api.RoutePatchOperationOpReplace {
|
||||
http.Error(w, fmt.Sprintf("Metric field only accepts replace operation, got %s", patch.Op),
|
||||
http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument,
|
||||
"metric field only accepts replace operation, got %s", patch.Op), w)
|
||||
|
||||
return
|
||||
}
|
||||
operations = append(operations, server.RouteUpdateOperation{
|
||||
@ -287,8 +280,8 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
case api.RoutePatchOperationPathMasquerade:
|
||||
if patch.Op != api.RoutePatchOperationOpReplace {
|
||||
http.Error(w, fmt.Sprintf("Masquerade field only accepts replace operation, got %s", patch.Op),
|
||||
http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument,
|
||||
"masquerade field only accepts replace operation, got %s", patch.Op), w)
|
||||
return
|
||||
}
|
||||
operations = append(operations, server.RouteUpdateOperation{
|
||||
@ -297,8 +290,8 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
case api.RoutePatchOperationPathEnabled:
|
||||
if patch.Op != api.RoutePatchOperationOpReplace {
|
||||
http.Error(w, fmt.Sprintf("Enabled field only accepts replace operation, got %s", patch.Op),
|
||||
http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument,
|
||||
"enabled field only accepts replace operation, got %s", patch.Op), w)
|
||||
return
|
||||
}
|
||||
operations = append(operations, server.RouteUpdateOperation{
|
||||
@ -306,90 +299,68 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
|
||||
Values: patch.Value,
|
||||
})
|
||||
default:
|
||||
http.Error(w, "invalid patch path", http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid patch path"), w)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
route, err := h.accountManager.UpdateRoute(account.Id, routeID, operations)
|
||||
|
||||
if err != nil {
|
||||
errStatus, ok := status.FromError(err)
|
||||
if ok && errStatus.Code() == codes.Internal {
|
||||
http.Error(w, errStatus.String(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if ok && errStatus.Code() == codes.NotFound {
|
||||
http.Error(w, errStatus.String(), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
if ok && errStatus.Code() == codes.InvalidArgument {
|
||||
http.Error(w, errStatus.String(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
log.Errorf("failed updating route %s under account %s %v", routeID, account.Id, err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toRouteResponse(account, route)
|
||||
|
||||
writeJSONObject(w, &resp)
|
||||
util.WriteJSONObject(w, &resp)
|
||||
}
|
||||
|
||||
// DeleteRouteHandler handles route deletion request
|
||||
func (h *Routes) DeleteRouteHandler(w http.ResponseWriter, r *http.Request) {
|
||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
routeID := mux.Vars(r)["id"]
|
||||
if len(routeID) == 0 {
|
||||
http.Error(w, "invalid route ID", http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeleteRoute(account.Id, routeID)
|
||||
if err != nil {
|
||||
errStatus, ok := status.FromError(err)
|
||||
if ok && errStatus.Code() == codes.NotFound {
|
||||
http.Error(w, fmt.Sprintf("route %s not found under account %s", routeID, account.Id), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
log.Errorf("failed delete route %s under account %s %v", routeID, account.Id, err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSONObject(w, "")
|
||||
util.WriteJSONObject(w, "")
|
||||
}
|
||||
|
||||
// GetRouteHandler handles a route Get request identified by ID
|
||||
func (h *Routes) GetRouteHandler(w http.ResponseWriter, r *http.Request) {
|
||||
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
routeID := mux.Vars(r)["id"]
|
||||
if len(routeID) == 0 {
|
||||
http.Error(w, "invalid route ID", http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
foundRoute, err := h.accountManager.GetRoute(account.Id, routeID, user.Id)
|
||||
if err != nil {
|
||||
http.Error(w, "route not found", http.StatusNotFound)
|
||||
util.WriteError(status.Errorf(status.NotFound, "route not found"), w)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSONObject(w, toRouteResponse(account, foundRoute))
|
||||
util.WriteJSONObject(w, toRouteResponse(account, foundRoute))
|
||||
}
|
||||
|
||||
func toRouteResponse(account *server.Account, serverRoute *route.Route) *api.Route {
|
||||
|
@ -5,9 +5,8 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@ -63,7 +62,7 @@ func initRoutesTestData() *Routes {
|
||||
if routeID == existingRouteID {
|
||||
return baseExistingRoute, nil
|
||||
}
|
||||
return nil, status.Errorf(codes.NotFound, "route with ID %s not found", routeID)
|
||||
return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
|
||||
},
|
||||
CreateRouteFunc: func(accountID string, network, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error) {
|
||||
networkType, p, _ := route.ParseNetwork(network)
|
||||
@ -83,13 +82,13 @@ func initRoutesTestData() *Routes {
|
||||
},
|
||||
DeleteRouteFunc: func(_ string, peerIP string) error {
|
||||
if peerIP != existingRouteID {
|
||||
return status.Errorf(codes.NotFound, "Peer with ID %s not found", peerIP)
|
||||
return status.Errorf(status.NotFound, "Peer with ID %s not found", peerIP)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
GetPeerByIPFunc: func(_ string, peerIP string) (*server.Peer, error) {
|
||||
if peerIP != existingPeerID {
|
||||
return nil, status.Errorf(codes.NotFound, "Peer with ID %s not found", peerIP)
|
||||
return nil, status.Errorf(status.NotFound, "Peer with ID %s not found", peerIP)
|
||||
}
|
||||
return &server.Peer{
|
||||
Key: existingPeerKey,
|
||||
@ -99,7 +98,7 @@ func initRoutesTestData() *Routes {
|
||||
UpdateRouteFunc: func(_ string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error) {
|
||||
routeToUpdate := baseExistingRoute
|
||||
if routeID != routeToUpdate.ID {
|
||||
return nil, status.Errorf(codes.NotFound, "route %s no longer exists", routeID)
|
||||
return nil, status.Errorf(status.NotFound, "route %s no longer exists", routeID)
|
||||
}
|
||||
for _, operation := range operations {
|
||||
switch operation.Type {
|
||||
@ -123,8 +122,8 @@ func initRoutesTestData() *Routes {
|
||||
}
|
||||
return routeToUpdate, nil
|
||||
},
|
||||
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, error) {
|
||||
return testingAccount, nil
|
||||
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return testingAccount, testingAccount.Users["test_user"], nil
|
||||
},
|
||||
},
|
||||
authAudience: "",
|
||||
@ -201,15 +200,15 @@ func TestRoutesHandlers(t *testing.T) {
|
||||
requestType: http.MethodPost,
|
||||
requestPath: "/api/routes",
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\"}", notFoundPeerID)),
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
expectedStatus: http.StatusNotFound,
|
||||
expectedBody: false,
|
||||
},
|
||||
{
|
||||
name: "POST Not Invalid Network Identifier",
|
||||
name: "POST Invalid Network Identifier",
|
||||
requestType: http.MethodPost,
|
||||
requestPath: "/api/routes",
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"12345678901234567890qwertyuiopqwertyuiop1\",\"Peer\":\"%s\"}", existingPeerID)),
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
expectedBody: false,
|
||||
},
|
||||
{
|
||||
@ -217,7 +216,7 @@ func TestRoutesHandlers(t *testing.T) {
|
||||
requestType: http.MethodPost,
|
||||
requestPath: "/api/routes",
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/34\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\"}", existingPeerID)),
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
expectedBody: false,
|
||||
},
|
||||
{
|
||||
@ -251,7 +250,7 @@ func TestRoutesHandlers(t *testing.T) {
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/routes/" + existingRouteID,
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\"}", notFoundPeerID)),
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
expectedStatus: http.StatusNotFound,
|
||||
expectedBody: false,
|
||||
},
|
||||
{
|
||||
@ -259,7 +258,7 @@ func TestRoutesHandlers(t *testing.T) {
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/routes/" + existingRouteID,
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"12345678901234567890qwertyuiopqwertyuiop1\",\"Peer\":\"%s\"}", existingPeerID)),
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
expectedBody: false,
|
||||
},
|
||||
{
|
||||
@ -267,7 +266,7 @@ func TestRoutesHandlers(t *testing.T) {
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/routes/" + existingRouteID,
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/34\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\"}", existingPeerID)),
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
expectedBody: false,
|
||||
},
|
||||
{
|
||||
@ -312,7 +311,7 @@ func TestRoutesHandlers(t *testing.T) {
|
||||
requestType: http.MethodPatch,
|
||||
requestPath: "/api/routes/" + existingRouteID,
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf("[{\"op\":\"replace\",\"path\":\"peer\",\"value\":[\"%s\"]}]", notFoundPeerID)),
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
expectedStatus: http.StatusNotFound,
|
||||
expectedBody: false,
|
||||
},
|
||||
{
|
||||
|
@ -2,15 +2,13 @@ package http
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
@ -31,25 +29,16 @@ func NewRules(accountManager server.AccountManager, authAudience string) *Rules
|
||||
|
||||
// GetAllRulesHandler list for the account
|
||||
func (h *Rules) GetAllRulesHandler(w http.ResponseWriter, r *http.Request) {
|
||||
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountRules, err := h.accountManager.ListRules(account.Id, user.Id)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
if e, ok := server.FromError(err); ok {
|
||||
switch e.Type() {
|
||||
case server.PermissionDenied:
|
||||
http.Error(w, e.Error(), http.StatusForbidden)
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
rules := []*api.Rule{}
|
||||
@ -57,38 +46,39 @@ func (h *Rules) GetAllRulesHandler(w http.ResponseWriter, r *http.Request) {
|
||||
rules = append(rules, toRuleResponse(account, r))
|
||||
}
|
||||
|
||||
writeJSONObject(w, rules)
|
||||
util.WriteJSONObject(w, rules)
|
||||
}
|
||||
|
||||
// UpdateRuleHandler handles update to a rule identified by a given ID
|
||||
func (h *Rules) UpdateRuleHandler(w http.ResponseWriter, r *http.Request) {
|
||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
ruleID := vars["id"]
|
||||
if len(ruleID) == 0 {
|
||||
http.Error(w, "invalid rule Id", http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid rule ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
_, ok := account.Rules[ruleID]
|
||||
if !ok {
|
||||
http.Error(w, fmt.Sprintf("couldn't find rule id %s", ruleID), http.StatusNotFound)
|
||||
util.WriteError(status.Errorf(status.NotFound, "couldn't find rule id %s", ruleID), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.PutApiRulesIdJSONRequestBody
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
http.Error(w, "Rule name shouldn't be empty", http.StatusUnprocessableEntity)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "rule name shouldn't be empty"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -115,50 +105,52 @@ func (h *Rules) UpdateRuleHandler(w http.ResponseWriter, r *http.Request) {
|
||||
case server.TrafficFlowBidirectString:
|
||||
rule.Flow = server.TrafficFlowBidirect
|
||||
default:
|
||||
http.Error(w, "unknown flow type", http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "unknown flow type"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.accountManager.SaveRule(account.Id, &rule); err != nil {
|
||||
log.Errorf("failed updating rule \"%s\" under account %s %v", ruleID, account.Id, err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
err = h.accountManager.SaveRule(account.Id, &rule)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toRuleResponse(account, &rule)
|
||||
|
||||
writeJSONObject(w, &resp)
|
||||
util.WriteJSONObject(w, &resp)
|
||||
}
|
||||
|
||||
// PatchRuleHandler handles patch updates to a rule identified by a given ID
|
||||
func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
|
||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
ruleID := vars["id"]
|
||||
if len(ruleID) == 0 {
|
||||
http.Error(w, "invalid rule Id", http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid rule ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
_, ok := account.Rules[ruleID]
|
||||
if !ok {
|
||||
http.Error(w, fmt.Sprintf("couldn't find rule id %s", ruleID), http.StatusNotFound)
|
||||
util.WriteError(status.Errorf(status.NotFound, "couldn't find rule ID %s", ruleID), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.PatchApiRulesIdJSONRequestBody
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if len(req) == 0 {
|
||||
http.Error(w, "no patch instruction received", http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "no patch instruction received"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -168,12 +160,12 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
|
||||
switch patch.Path {
|
||||
case api.RulePatchOperationPathName:
|
||||
if patch.Op != api.RulePatchOperationOpReplace {
|
||||
http.Error(w, fmt.Sprintf("Name field only accepts replace operation, got %s", patch.Op),
|
||||
http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument,
|
||||
"name field only accepts replace operation, got %s", patch.Op), w)
|
||||
return
|
||||
}
|
||||
if len(patch.Value) == 0 || patch.Value[0] == "" {
|
||||
http.Error(w, "Rule name shouldn't be empty", http.StatusUnprocessableEntity)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "rule name shouldn't be empty"), w)
|
||||
return
|
||||
}
|
||||
operations = append(operations, server.RuleUpdateOperation{
|
||||
@ -182,8 +174,8 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
case api.RulePatchOperationPathDescription:
|
||||
if patch.Op != api.RulePatchOperationOpReplace {
|
||||
http.Error(w, fmt.Sprintf("Description field only accepts replace operation, got %s", patch.Op),
|
||||
http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument,
|
||||
"description field only accepts replace operation, got %s", patch.Op), w)
|
||||
return
|
||||
}
|
||||
operations = append(operations, server.RuleUpdateOperation{
|
||||
@ -192,8 +184,8 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
case api.RulePatchOperationPathFlow:
|
||||
if patch.Op != api.RulePatchOperationOpReplace {
|
||||
http.Error(w, fmt.Sprintf("Flow field only accepts replace operation, got %s", patch.Op),
|
||||
http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument,
|
||||
"flow field only accepts replace operation, got %s", patch.Op), w)
|
||||
return
|
||||
}
|
||||
operations = append(operations, server.RuleUpdateOperation{
|
||||
@ -202,8 +194,8 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
case api.RulePatchOperationPathDisabled:
|
||||
if patch.Op != api.RulePatchOperationOpReplace {
|
||||
http.Error(w, fmt.Sprintf("Disabled field only accepts replace operation, got %s", patch.Op),
|
||||
http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument,
|
||||
"disabled field only accepts replace operation, got %s", patch.Op), w)
|
||||
return
|
||||
}
|
||||
operations = append(operations, server.RuleUpdateOperation{
|
||||
@ -228,7 +220,8 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
|
||||
Values: patch.Value,
|
||||
})
|
||||
default:
|
||||
http.Error(w, "invalid operation, \"%s\", for Source field", http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument,
|
||||
"invalid operation \"%s\" on Source field", patch.Op), w)
|
||||
return
|
||||
}
|
||||
case api.RulePatchOperationPathDestinations:
|
||||
@ -249,11 +242,12 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
|
||||
Values: patch.Value,
|
||||
})
|
||||
default:
|
||||
http.Error(w, "invalid operation, \"%s\", for Destination field", http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument,
|
||||
"invalid operation \"%s\" on Destination field", patch.Op), w)
|
||||
return
|
||||
}
|
||||
default:
|
||||
http.Error(w, "invalid patch path", http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid patch path"), w)
|
||||
return
|
||||
}
|
||||
}
|
||||
@ -261,48 +255,33 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
|
||||
rule, err := h.accountManager.UpdateRule(account.Id, ruleID, operations)
|
||||
|
||||
if err != nil {
|
||||
errStatus, ok := status.FromError(err)
|
||||
if ok && errStatus.Code() == codes.Internal {
|
||||
http.Error(w, errStatus.String(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if ok && errStatus.Code() == codes.NotFound {
|
||||
http.Error(w, errStatus.String(), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
if ok && errStatus.Code() == codes.InvalidArgument {
|
||||
http.Error(w, errStatus.String(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
log.Errorf("failed updating rule %s under account %s %v", ruleID, account.Id, err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toRuleResponse(account, rule)
|
||||
|
||||
writeJSONObject(w, &resp)
|
||||
util.WriteJSONObject(w, &resp)
|
||||
}
|
||||
|
||||
// CreateRuleHandler handles rule creation request
|
||||
func (h *Rules) CreateRuleHandler(w http.ResponseWriter, r *http.Request) {
|
||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.PostApiRulesJSONRequestBody
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
http.Error(w, "Rule name shouldn't be empty", http.StatusUnprocessableEntity)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "rule name shouldn't be empty"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -329,50 +308,52 @@ func (h *Rules) CreateRuleHandler(w http.ResponseWriter, r *http.Request) {
|
||||
case server.TrafficFlowBidirectString:
|
||||
rule.Flow = server.TrafficFlowBidirect
|
||||
default:
|
||||
http.Error(w, "unknown flow type", http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "unknown flow type"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.accountManager.SaveRule(account.Id, &rule); err != nil {
|
||||
log.Errorf("failed creating rule \"%s\" under account %s %v", req.Name, account.Id, err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
err = h.accountManager.SaveRule(account.Id, &rule)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toRuleResponse(account, &rule)
|
||||
|
||||
writeJSONObject(w, &resp)
|
||||
util.WriteJSONObject(w, &resp)
|
||||
}
|
||||
|
||||
// DeleteRuleHandler handles rule deletion request
|
||||
func (h *Rules) DeleteRuleHandler(w http.ResponseWriter, r *http.Request) {
|
||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
aID := account.Id
|
||||
|
||||
rID := mux.Vars(r)["id"]
|
||||
if len(rID) == 0 {
|
||||
http.Error(w, "invalid rule ID", http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid rule ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.accountManager.DeleteRule(aID, rID); err != nil {
|
||||
log.Errorf("failed delete rule %s under account %s %v", rID, aID, err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
err = h.accountManager.DeleteRule(aID, rID)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSONObject(w, "")
|
||||
util.WriteJSONObject(w, "")
|
||||
}
|
||||
|
||||
// GetRuleHandler handles a group Get request identified by ID
|
||||
func (h *Rules) GetRuleHandler(w http.ResponseWriter, r *http.Request) {
|
||||
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -380,19 +361,19 @@ func (h *Rules) GetRuleHandler(w http.ResponseWriter, r *http.Request) {
|
||||
case http.MethodGet:
|
||||
ruleID := mux.Vars(r)["id"]
|
||||
if len(ruleID) == 0 {
|
||||
http.Error(w, "invalid rule ID", http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid rule ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
rule, err := h.accountManager.GetRule(account.Id, ruleID, user.Id)
|
||||
if err != nil {
|
||||
http.Error(w, "rule not found", http.StatusNotFound)
|
||||
util.WriteError(status.Errorf(status.NotFound, "rule not found"), w)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSONObject(w, toRuleResponse(account, rule))
|
||||
util.WriteJSONObject(w, toRuleResponse(account, rule))
|
||||
default:
|
||||
http.Error(w, "", http.StatusNotFound)
|
||||
util.WriteError(status.Errorf(status.NotFound, "method not found"), w)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -66,7 +66,8 @@ func initRulesTestData(rules ...*server.Rule) *Rules {
|
||||
}
|
||||
return &rule, nil
|
||||
},
|
||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
|
||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
user := server.NewAdminUser("test_user")
|
||||
return &server.Account{
|
||||
Id: claims.AccountId,
|
||||
Domain: "hotmail.com",
|
||||
@ -76,9 +77,9 @@ func initRulesTestData(rules ...*server.Rule) *Rules {
|
||||
"G": {ID: "G"},
|
||||
},
|
||||
Users: map[string]*server.User{
|
||||
"test_user": server.NewAdminUser("test_user"),
|
||||
"test_user": user,
|
||||
},
|
||||
}, nil
|
||||
}, user, nil
|
||||
},
|
||||
},
|
||||
authAudience: "",
|
||||
@ -238,7 +239,7 @@ func TestRulesWriteRule(t *testing.T) {
|
||||
requestPath: "/api/rules/id-existed",
|
||||
requestBody: bytes.NewBuffer(
|
||||
[]byte(`[{"op":"insert","path":"name","value":[""]}]`)),
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
expectedBody: false,
|
||||
},
|
||||
{
|
||||
|
@ -2,14 +2,12 @@ package http
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
@ -31,29 +29,28 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authAudience stri
|
||||
|
||||
// CreateSetupKeyHandler is a POST requests that creates a new SetupKey
|
||||
func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
|
||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
req := &api.PostApiSetupKeysJSONRequestBody{}
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
http.Error(w, "Setup key name shouldn't be empty", http.StatusUnprocessableEntity)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "setup key name shouldn't be empty"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if !(server.SetupKeyType(req.Type) == server.SetupKeyReusable ||
|
||||
server.SetupKeyType(req.Type) == server.SetupKeyOneOff) {
|
||||
|
||||
http.Error(w, "unknown setup key type "+string(req.Type), http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "unknown setup key type %s", string(req.Type)), w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -62,17 +59,11 @@ func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request
|
||||
if req.AutoGroups == nil {
|
||||
req.AutoGroups = []string{}
|
||||
}
|
||||
// newExpiresIn := time.Duration(req.ExpiresIn) * time.Second
|
||||
// newKey.ExpiresAt = time.Now().Add(newExpiresIn)
|
||||
|
||||
setupKey, err := h.accountManager.CreateSetupKey(account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn,
|
||||
req.AutoGroups)
|
||||
if err != nil {
|
||||
errStatus, ok := status.FromError(err)
|
||||
if ok && errStatus.Code() == codes.NotFound {
|
||||
http.Error(w, "account not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
http.Error(w, "failed adding setup key", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -81,29 +72,23 @@ func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request
|
||||
|
||||
// GetSetupKeyHandler is a GET request to get a SetupKey by ID
|
||||
func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
|
||||
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
keyID := vars["id"]
|
||||
if len(keyID) == 0 {
|
||||
http.Error(w, "invalid key Id", http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid key ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
key, err := h.accountManager.GetSetupKey(account.Id, user.Id, keyID)
|
||||
if err != nil {
|
||||
errStatus, ok := status.FromError(err)
|
||||
if ok && errStatus.Code() == codes.NotFound {
|
||||
http.Error(w, fmt.Sprintf("setup key %s not found under account %s", keyID, account.Id), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
log.Errorf("failed getting setup key %s under account %s %v", keyID, account.Id, err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -112,34 +97,34 @@ func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// UpdateSetupKeyHandler is a PUT request to update server.SetupKey
|
||||
func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
|
||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
keyID := vars["id"]
|
||||
if len(keyID) == 0 {
|
||||
http.Error(w, "invalid key Id", http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid key ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
req := &api.PutApiSetupKeysIdJSONRequestBody{}
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
http.Error(w, fmt.Sprintf("setup key name field is invalid: %s", req.Name), http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "setup key name field is invalid: %s", req.Name), w)
|
||||
return
|
||||
}
|
||||
|
||||
if req.AutoGroups == nil {
|
||||
http.Error(w, fmt.Sprintf("setup key AutoGroups field is invalid: %s", req.AutoGroups), http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "setup key AutoGroups field is invalid"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -150,16 +135,8 @@ func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request
|
||||
newKey.Id = keyID
|
||||
|
||||
newKey, err = h.accountManager.SaveSetupKey(account.Id, newKey)
|
||||
|
||||
if err != nil {
|
||||
if e, ok := status.FromError(err); ok {
|
||||
switch e.Code() {
|
||||
case codes.NotFound:
|
||||
http.Error(w, fmt.Sprintf("couldn't find setup key for ID %s", keyID), http.StatusNotFound)
|
||||
default:
|
||||
http.Error(w, "failed updating setup key", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
writeSuccess(w, newKey)
|
||||
@ -168,25 +145,25 @@ func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request
|
||||
// GetAllSetupKeysHandler is a GET request that returns a list of SetupKey
|
||||
func (h *SetupKeys) GetAllSetupKeysHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
setupKeys, err := h.accountManager.ListSetupKeys(account.Id, user.Id)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
apiSetupKeys := make([]*api.SetupKey, 0)
|
||||
for _, key := range setupKeys {
|
||||
apiSetupKeys = append(apiSetupKeys, toResponseBody(key))
|
||||
}
|
||||
|
||||
writeJSONObject(w, apiSetupKeys)
|
||||
util.WriteJSONObject(w, apiSetupKeys)
|
||||
}
|
||||
|
||||
func writeSuccess(w http.ResponseWriter, key *server.SetupKey) {
|
||||
@ -194,7 +171,7 @@ func writeSuccess(w http.ResponseWriter, key *server.SetupKey) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err := json.NewEncoder(w).Encode(toResponseBody(key))
|
||||
if err != nil {
|
||||
http.Error(w, "failed handling request", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -6,9 +6,8 @@ import (
|
||||
"fmt"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@ -32,7 +31,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
|
||||
user *server.User) *SetupKeys {
|
||||
return &SetupKeys{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
|
||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return &server.Account{
|
||||
Id: testAccountID,
|
||||
Domain: "hotmail.com",
|
||||
@ -45,7 +44,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
|
||||
Groups: map[string]*server.Group{
|
||||
"group-1": {ID: "group-1", Peers: []string{"A", "B"}},
|
||||
"id-all": {ID: "id-all", Name: "All"}},
|
||||
}, nil
|
||||
}, user, nil
|
||||
},
|
||||
CreateSetupKeyFunc: func(_ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string) (*server.SetupKey, error) {
|
||||
if keyName == newKey.Name || typ != newKey.Type {
|
||||
@ -60,7 +59,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
|
||||
case newKey.Id:
|
||||
return newKey, nil
|
||||
default:
|
||||
return nil, status.Errorf(codes.NotFound, "key %s not found", keyID)
|
||||
return nil, status.Errorf(status.NotFound, "key %s not found", keyID)
|
||||
}
|
||||
},
|
||||
|
||||
@ -68,7 +67,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
|
||||
if key.Id == updatedSetupKey.Id {
|
||||
return updatedSetupKey, nil
|
||||
}
|
||||
return nil, status.Errorf(codes.NotFound, "key %s not found", key.Id)
|
||||
return nil, status.Errorf(status.NotFound, "key %s not found", key.Id)
|
||||
},
|
||||
|
||||
ListSetupKeysFunc: func(accountID, userID string) ([]*server.SetupKey, error) {
|
||||
|
@ -2,12 +2,10 @@ package http
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"net/http"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
@ -31,33 +29,34 @@ func NewUserHandler(accountManager server.AccountManager, authAudience string) *
|
||||
// UpdateUser is a PUT requests to update User data
|
||||
func (h *UserHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPut {
|
||||
http.Error(w, "", http.StatusBadRequest)
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
userID := vars["id"]
|
||||
if len(userID) == 0 {
|
||||
http.Error(w, "invalid user ID", http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
req := &api.PutApiUsersIdJSONRequestBody{}
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
userRole := server.StrRoleToUserRole(req.Role)
|
||||
if userRole == server.UserRoleUnknown {
|
||||
http.Error(w, "invalid user role", http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user role"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -67,40 +66,36 @@ func (h *UserHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
|
||||
AutoGroups: req.AutoGroups,
|
||||
})
|
||||
if err != nil {
|
||||
if e, ok := status.FromError(err); ok {
|
||||
switch e.Code() {
|
||||
case codes.NotFound:
|
||||
http.Error(w, fmt.Sprintf("couldn't find a user for ID %s", userID), http.StatusNotFound)
|
||||
default:
|
||||
http.Error(w, "failed to update user", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
writeJSONObject(w, toUserResponse(newUser))
|
||||
util.WriteJSONObject(w, toUserResponse(newUser))
|
||||
|
||||
}
|
||||
|
||||
// CreateUserHandler creates a User in the system with a status "invited" (effectively this is a user invite).
|
||||
func (h *UserHandler) CreateUserHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "", http.StatusNotFound)
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
req := &api.PostApiUsersJSONRequestBody{}
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if server.StrRoleToUserRole(req.Role) == server.UserRoleUnknown {
|
||||
http.Error(w, "unknown user role "+req.Role, http.StatusBadRequest)
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "unknown user role %s", req.Role), w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -111,37 +106,30 @@ func (h *UserHandler) CreateUserHandler(w http.ResponseWriter, r *http.Request)
|
||||
AutoGroups: req.AutoGroups,
|
||||
})
|
||||
if err != nil {
|
||||
if e, ok := server.FromError(err); ok {
|
||||
switch e.Type() {
|
||||
case server.UserAlreadyExists:
|
||||
http.Error(w, "You can't invite users with an existing NetBird account.", http.StatusPreconditionFailed)
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
http.Error(w, "failed to invite", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
writeJSONObject(w, toUserResponse(newUser))
|
||||
util.WriteJSONObject(w, toUserResponse(newUser))
|
||||
}
|
||||
|
||||
// GetUsers returns a list of users of the account this user belongs to.
|
||||
// It also gathers additional user data (like email and name) from the IDP manager.
|
||||
func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "", http.StatusBadRequest)
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
data, err := h.accountManager.GetUsersFromAccount(account.Id, user.Id)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -150,7 +138,7 @@ func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) {
|
||||
users = append(users, toUserResponse(r))
|
||||
}
|
||||
|
||||
writeJSONObject(w, users)
|
||||
util.WriteJSONObject(w, users)
|
||||
}
|
||||
|
||||
func toUserResponse(user *server.UserInfo) *api.User {
|
||||
|
@ -16,7 +16,7 @@ import (
|
||||
func initUsers(user ...*server.User) *UserHandler {
|
||||
return &UserHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
|
||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
users := make(map[string]*server.User, 0)
|
||||
for _, u := range user {
|
||||
users[u.Id] = u
|
||||
@ -25,7 +25,7 @@ func initUsers(user ...*server.User) *UserHandler {
|
||||
Id: "12345",
|
||||
Domain: "netbird.io",
|
||||
Users: users,
|
||||
}, nil
|
||||
}, users[claims.UserId], nil
|
||||
},
|
||||
GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) {
|
||||
users := make([]*server.UserInfo, 0)
|
||||
@ -66,7 +66,6 @@ func TestGetUsers(t *testing.T) {
|
||||
expectedResult []*server.User
|
||||
}{
|
||||
{name: "GetAllUsers", requestType: http.MethodGet, requestPath: "/api/users/", expectedStatus: http.StatusOK, expectedResult: users},
|
||||
{name: "WrongRequestMethod", requestType: http.MethodPost, requestPath: "/api/users/", expectedStatus: http.StatusBadRequest},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
|
@ -1,97 +0,0 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// writeJSONObject simply writes object to the HTTP reponse in JSON format
|
||||
func writeJSONObject(w http.ResponseWriter, obj interface{}) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
|
||||
err := json.NewEncoder(w).Encode(obj)
|
||||
if err != nil {
|
||||
http.Error(w, "failed handling request", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Duration is used strictly for JSON requests/responses due to duration marshalling issues
|
||||
type Duration struct {
|
||||
time.Duration
|
||||
}
|
||||
|
||||
func (d Duration) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(d.String())
|
||||
}
|
||||
|
||||
func (d *Duration) UnmarshalJSON(b []byte) error {
|
||||
var v interface{}
|
||||
if err := json.Unmarshal(b, &v); err != nil {
|
||||
return err
|
||||
}
|
||||
switch value := v.(type) {
|
||||
case float64:
|
||||
d.Duration = time.Duration(value)
|
||||
return nil
|
||||
case string:
|
||||
var err error
|
||||
d.Duration, err = time.ParseDuration(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
return errors.New("invalid duration")
|
||||
}
|
||||
}
|
||||
|
||||
func getJWTAccount(accountManager server.AccountManager,
|
||||
jwtExtractor jwtclaims.ClaimsExtractor,
|
||||
authAudience string, r *http.Request) (*server.Account, *server.User, error) {
|
||||
|
||||
claims := jwtExtractor.ExtractClaimsFromRequestContext(r, authAudience)
|
||||
|
||||
account, err := accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed getting account of a user %s: %v", claims.UserId, err)
|
||||
}
|
||||
|
||||
user := account.Users[claims.UserId]
|
||||
if user == nil {
|
||||
// this is not really possible because we got an account by user ID
|
||||
return nil, nil, fmt.Errorf("user %s not found", claims.UserId)
|
||||
}
|
||||
|
||||
return account, user, nil
|
||||
}
|
||||
|
||||
func toHTTPError(err error, w http.ResponseWriter) {
|
||||
errStatus, ok := status.FromError(err)
|
||||
if ok && errStatus.Code() == codes.Internal {
|
||||
http.Error(w, errStatus.String(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if ok && errStatus.Code() == codes.NotFound {
|
||||
http.Error(w, errStatus.String(), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
if ok && errStatus.Code() == codes.InvalidArgument {
|
||||
http.Error(w, errStatus.String(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
unhandledMSG := fmt.Sprintf("got unhandled error code, error: %s", errStatus.String())
|
||||
log.Error(unhandledMSG)
|
||||
http.Error(w, unhandledMSG, http.StatusInternalServerError)
|
||||
}
|
105
management/server/http/util/util.go
Normal file
105
management/server/http/util/util.go
Normal file
@ -0,0 +1,105 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// WriteJSONObject simply writes object to the HTTP reponse in JSON format
|
||||
func WriteJSONObject(w http.ResponseWriter, obj interface{}) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
|
||||
err := json.NewEncoder(w).Encode(obj)
|
||||
if err != nil {
|
||||
WriteError(err, w)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Duration is used strictly for JSON requests/responses due to duration marshalling issues
|
||||
type Duration struct {
|
||||
time.Duration
|
||||
}
|
||||
|
||||
// MarshalJSON marshals the duration
|
||||
func (d Duration) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(d.String())
|
||||
}
|
||||
|
||||
// UnmarshalJSON unmarshals the duration
|
||||
func (d *Duration) UnmarshalJSON(b []byte) error {
|
||||
var v interface{}
|
||||
if err := json.Unmarshal(b, &v); err != nil {
|
||||
return err
|
||||
}
|
||||
switch value := v.(type) {
|
||||
case float64:
|
||||
d.Duration = time.Duration(value)
|
||||
return nil
|
||||
case string:
|
||||
var err error
|
||||
d.Duration, err = time.ParseDuration(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
return errors.New("invalid duration")
|
||||
}
|
||||
}
|
||||
|
||||
// WriteErrorResponse prepares and writes an error response i nJSON
|
||||
func WriteErrorResponse(errMsg string, httpStatus int, w http.ResponseWriter) {
|
||||
type errorResponse struct {
|
||||
Message string `json:"message"`
|
||||
Code int `json:"code"`
|
||||
}
|
||||
|
||||
w.WriteHeader(httpStatus)
|
||||
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
|
||||
err := json.NewEncoder(w).Encode(&errorResponse{
|
||||
Message: errMsg,
|
||||
Code: httpStatus,
|
||||
})
|
||||
if err != nil {
|
||||
http.Error(w, "failed handling request", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
// WriteError converts an error to an JSON error response.
|
||||
// If it is known internal error of type server.Error then it sets the messages from the error, a generic message otherwise
|
||||
func WriteError(err error, w http.ResponseWriter) {
|
||||
errStatus, ok := status.FromError(err)
|
||||
httpStatus := http.StatusInternalServerError
|
||||
msg := "internal server error"
|
||||
if ok {
|
||||
switch errStatus.Type() {
|
||||
case status.UserAlreadyExists:
|
||||
httpStatus = http.StatusConflict
|
||||
case status.AlreadyExists:
|
||||
httpStatus = http.StatusConflict
|
||||
case status.PreconditionFailed:
|
||||
httpStatus = http.StatusPreconditionFailed
|
||||
case status.PermissionDenied:
|
||||
httpStatus = http.StatusForbidden
|
||||
case status.NotFound:
|
||||
httpStatus = http.StatusNotFound
|
||||
case status.Internal:
|
||||
httpStatus = http.StatusInternalServerError
|
||||
case status.InvalidArgument:
|
||||
httpStatus = http.StatusUnprocessableEntity
|
||||
default:
|
||||
}
|
||||
msg = err.Error()
|
||||
} else {
|
||||
unhandledMSG := fmt.Sprintf("got unhandled error code, error: %s", err.Error())
|
||||
log.Error(unhandledMSG)
|
||||
}
|
||||
|
||||
WriteErrorResponse(msg, httpStatus, w)
|
||||
}
|
@ -59,7 +59,7 @@ type MockAccountManager struct {
|
||||
DeleteNameServerGroupFunc func(accountID, nsGroupID string) error
|
||||
ListNameServerGroupsFunc func(accountID string) ([]*nbdns.NameServerGroup, error)
|
||||
CreateUserFunc func(accountID string, key *server.UserInfo) (*server.UserInfo, error)
|
||||
GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, error)
|
||||
GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)
|
||||
}
|
||||
|
||||
// GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface
|
||||
@ -113,7 +113,7 @@ func (am *MockAccountManager) CreateSetupKey(
|
||||
return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented")
|
||||
}
|
||||
|
||||
// GetAccountByUserOrAccountId mock implementation of GetAccountByUserOrAccountId from server.AccountManager interface
|
||||
// GetAccountByUserOrAccountID mock implementation of GetAccountByUserOrAccountID from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetAccountByUserOrAccountID(
|
||||
userId, accountId, domain string,
|
||||
) (*server.Account, error) {
|
||||
@ -462,11 +462,12 @@ func (am *MockAccountManager) CreateUser(accountID string, invite *server.UserIn
|
||||
}
|
||||
|
||||
// GetAccountFromToken mocks GetAccountFromToken of the AccountManager interface
|
||||
func (am *MockAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
|
||||
func (am *MockAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User,
|
||||
error) {
|
||||
if am.GetAccountFromTokenFunc != nil {
|
||||
return am.GetAccountFromTokenFunc(claims)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented")
|
||||
return nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented")
|
||||
}
|
||||
|
||||
// GetPeers mocks GetPeers of the AccountManager interface
|
||||
|
@ -3,10 +3,9 @@ package server
|
||||
import (
|
||||
"github.com/miekg/dns"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"strconv"
|
||||
"unicode/utf8"
|
||||
)
|
||||
@ -66,7 +65,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string)
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nsGroup, found := account.NameServerGroups[nsGroupID]
|
||||
@ -74,7 +73,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string)
|
||||
return nsGroup.Copy(), nil
|
||||
}
|
||||
|
||||
return nil, status.Errorf(codes.NotFound, "nameserver group with ID %s not found", nsGroupID)
|
||||
return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID)
|
||||
}
|
||||
|
||||
// CreateNameServerGroup creates and saves a new nameserver group
|
||||
@ -85,7 +84,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, d
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newNSGroup := &nbdns.NameServerGroup{
|
||||
@ -119,7 +118,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, d
|
||||
err = am.updateAccountPeers(account)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return newNSGroup.Copy(), status.Errorf(codes.Unavailable, "failed to update peers after create nameserver %s", name)
|
||||
return newNSGroup.Copy(), status.Errorf(status.Internal, "failed to update peers after create nameserver %s", name)
|
||||
}
|
||||
|
||||
return newNSGroup.Copy(), nil
|
||||
@ -132,12 +131,12 @@ func (am *DefaultAccountManager) SaveNameServerGroup(accountID string, nsGroupTo
|
||||
defer unlock()
|
||||
|
||||
if nsGroupToSave == nil {
|
||||
return status.Errorf(codes.InvalidArgument, "nameserver group provided is nil")
|
||||
return status.Errorf(status.InvalidArgument, "nameserver group provided is nil")
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.NotFound, "account not found")
|
||||
return err
|
||||
}
|
||||
|
||||
err = validateNameServerGroup(true, nsGroupToSave, account)
|
||||
@ -156,7 +155,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(accountID string, nsGroupTo
|
||||
err = am.updateAccountPeers(account)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return status.Errorf(codes.Unavailable, "failed to update peers after update nameserver %s", nsGroupToSave.Name)
|
||||
return status.Errorf(status.Internal, "failed to update peers after update nameserver %s", nsGroupToSave.Name)
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -170,16 +169,16 @@ func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID stri
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(operations) == 0 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "operations shouldn't be empty")
|
||||
return nil, status.Errorf(status.InvalidArgument, "operations shouldn't be empty")
|
||||
}
|
||||
|
||||
nsGroupToUpdate, ok := account.NameServerGroups[nsGroupID]
|
||||
if !ok {
|
||||
return nil, status.Errorf(codes.NotFound, "nameserver group ID %s no longer exists", nsGroupID)
|
||||
return nil, status.Errorf(status.NotFound, "nameserver group ID %s no longer exists", nsGroupID)
|
||||
}
|
||||
|
||||
newNSGroup := nsGroupToUpdate.Copy()
|
||||
@ -187,12 +186,12 @@ func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID stri
|
||||
for _, operation := range operations {
|
||||
valuesCount := len(operation.Values)
|
||||
if valuesCount < 1 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "operation %s contains invalid number of values, it should be at least 1", operation.Type.String())
|
||||
return nil, status.Errorf(status.InvalidArgument, "operation %s contains invalid number of values, it should be at least 1", operation.Type.String())
|
||||
}
|
||||
|
||||
for _, value := range operation.Values {
|
||||
if value == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "operation %s contains invalid empty string value", operation.Type.String())
|
||||
return nil, status.Errorf(status.InvalidArgument, "operation %s contains invalid empty string value", operation.Type.String())
|
||||
}
|
||||
}
|
||||
switch operation.Type {
|
||||
@ -200,7 +199,7 @@ func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID stri
|
||||
newNSGroup.Description = operation.Values[0]
|
||||
case UpdateNameServerGroupName:
|
||||
if valuesCount > 1 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "failed to parse name values, expected 1 value got %d", valuesCount)
|
||||
return nil, status.Errorf(status.InvalidArgument, "failed to parse name values, expected 1 value got %d", valuesCount)
|
||||
}
|
||||
err = validateNSGroupName(operation.Values[0], nsGroupID, account.NameServerGroups)
|
||||
if err != nil {
|
||||
@ -230,13 +229,13 @@ func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID stri
|
||||
case UpdateNameServerGroupEnabled:
|
||||
enabled, err := strconv.ParseBool(operation.Values[0])
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0])
|
||||
return nil, status.Errorf(status.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0])
|
||||
}
|
||||
newNSGroup.Enabled = enabled
|
||||
case UpdateNameServerGroupPrimary:
|
||||
primary, err := strconv.ParseBool(operation.Values[0])
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "failed to parse primary status %s, not boolean", operation.Values[0])
|
||||
return nil, status.Errorf(status.InvalidArgument, "failed to parse primary status %s, not boolean", operation.Values[0])
|
||||
}
|
||||
newNSGroup.Primary = primary
|
||||
case UpdateNameServerGroupDomains:
|
||||
@ -259,7 +258,7 @@ func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID stri
|
||||
err = am.updateAccountPeers(account)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return newNSGroup.Copy(), status.Errorf(codes.Unavailable, "failed to update peers after update nameserver %s", newNSGroup.Name)
|
||||
return newNSGroup.Copy(), status.Errorf(status.Internal, "failed to update peers after update nameserver %s", newNSGroup.Name)
|
||||
}
|
||||
|
||||
return newNSGroup.Copy(), nil
|
||||
@ -273,7 +272,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID stri
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.NotFound, "account not found")
|
||||
return err
|
||||
}
|
||||
|
||||
delete(account.NameServerGroups, nsGroupID)
|
||||
@ -287,7 +286,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID stri
|
||||
err = am.updateAccountPeers(account)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return status.Errorf(codes.Unavailable, "failed to update peers after deleting nameserver %s", nsGroupID)
|
||||
return status.Errorf(status.Internal, "failed to update peers after deleting nameserver %s", nsGroupID)
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -301,7 +300,7 @@ func (am *DefaultAccountManager) ListNameServerGroups(accountID string) ([]*nbdn
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nsGroups := make([]*nbdns.NameServerGroup, 0, len(account.NameServerGroups))
|
||||
@ -318,7 +317,7 @@ func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServ
|
||||
nsGroupID = nameserverGroup.ID
|
||||
_, found := account.NameServerGroups[nsGroupID]
|
||||
if !found {
|
||||
return status.Errorf(codes.NotFound, "nameserver group with ID %s was not found", nsGroupID)
|
||||
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupID)
|
||||
}
|
||||
}
|
||||
|
||||
@ -347,17 +346,17 @@ func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServ
|
||||
|
||||
func validateDomainInput(primary bool, domains []string) error {
|
||||
if !primary && len(domains) == 0 {
|
||||
return status.Errorf(codes.InvalidArgument, "nameserver group primary status is false and domains are empty,"+
|
||||
return status.Errorf(status.InvalidArgument, "nameserver group primary status is false and domains are empty,"+
|
||||
" it should be primary or have at least one domain")
|
||||
}
|
||||
if primary && len(domains) != 0 {
|
||||
return status.Errorf(codes.InvalidArgument, "nameserver group primary status is true and domains are not empty,"+
|
||||
return status.Errorf(status.InvalidArgument, "nameserver group primary status is true and domains are not empty,"+
|
||||
" you should set either primary or domain")
|
||||
}
|
||||
for _, domain := range domains {
|
||||
_, valid := dns.IsDomainName(domain)
|
||||
if !valid {
|
||||
return status.Errorf(codes.InvalidArgument, "nameserver group got an invalid domain: %s", domain)
|
||||
return status.Errorf(status.InvalidArgument, "nameserver group got an invalid domain: %s", domain)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@ -365,12 +364,12 @@ func validateDomainInput(primary bool, domains []string) error {
|
||||
|
||||
func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.NameServerGroup) error {
|
||||
if utf8.RuneCountInString(name) > nbdns.MaxGroupNameChar || name == "" {
|
||||
return status.Errorf(codes.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar)
|
||||
return status.Errorf(status.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar)
|
||||
}
|
||||
|
||||
for _, nsGroup := range nsGroupMap {
|
||||
if name == nsGroup.Name && nsGroup.ID != nsGroupID {
|
||||
return status.Errorf(codes.InvalidArgument, "a nameserver group with name %s already exist", name)
|
||||
return status.Errorf(status.InvalidArgument, "a nameserver group with name %s already exist", name)
|
||||
}
|
||||
}
|
||||
|
||||
@ -380,19 +379,19 @@ func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.Na
|
||||
func validateNSList(list []nbdns.NameServer) error {
|
||||
nsListLenght := len(list)
|
||||
if nsListLenght == 0 || nsListLenght > 2 {
|
||||
return status.Errorf(codes.InvalidArgument, "the list of nameservers should be 1 or 2, got %d", len(list))
|
||||
return status.Errorf(status.InvalidArgument, "the list of nameservers should be 1 or 2, got %d", len(list))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateGroups(list []string, groups map[string]*Group) error {
|
||||
if len(list) == 0 {
|
||||
return status.Errorf(codes.InvalidArgument, "the list of group IDs should not be empty")
|
||||
return status.Errorf(status.InvalidArgument, "the list of group IDs should not be empty")
|
||||
}
|
||||
|
||||
for _, id := range list {
|
||||
if id == "" {
|
||||
return status.Errorf(codes.InvalidArgument, "group ID should not be empty string")
|
||||
return status.Errorf(status.InvalidArgument, "group ID should not be empty string")
|
||||
}
|
||||
found := false
|
||||
for groupID := range groups {
|
||||
@ -402,7 +401,7 @@ func validateGroups(list []string, groups map[string]*Group) error {
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return status.Errorf(codes.InvalidArgument, "group id %s not found", id)
|
||||
return status.Errorf(status.InvalidArgument, "group id %s not found", id)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -3,6 +3,7 @@ package server
|
||||
import (
|
||||
"github.com/c-robinson/iplib"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/rs/xid"
|
||||
"math/rand"
|
||||
@ -93,7 +94,7 @@ func AllocatePeerIP(ipNet net.IPNet, takenIps []net.IP) (net.IP, error) {
|
||||
ips, _ := generateIPs(&ipNet, takenIPMap)
|
||||
|
||||
if len(ips) == 0 {
|
||||
return nil, Errorf(PreconditionFailed, "failed allocating new IP for the ipNet %s - network is out of IPs", ipNet.String())
|
||||
return nil, status.Errorf(status.PreconditionFailed, "failed allocating new IP for the ipNet %s - network is out of IPs", ipNet.String())
|
||||
}
|
||||
|
||||
// pick a random IP
|
||||
|
@ -2,6 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
@ -9,8 +10,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// PeerSystemMeta is a metadata of a Peer machine system
|
||||
@ -162,7 +161,7 @@ func (am *DefaultAccountManager) UpdatePeer(accountID string, update *Peer) (*Pe
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
//TODO Peer.ID migration: we will need to replace search by ID here
|
||||
@ -208,7 +207,7 @@ func (am *DefaultAccountManager) DeletePeer(accountID string, peerPubKey string)
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peer, err := account.FindPeerByPubKey(peerPubKey)
|
||||
@ -258,7 +257,7 @@ func (am *DefaultAccountManager) GetPeerByIP(accountID string, peerIP string) (*
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, peer := range account.Peers {
|
||||
@ -267,7 +266,7 @@ func (am *DefaultAccountManager) GetPeerByIP(accountID string, peerIP string) (*
|
||||
}
|
||||
}
|
||||
|
||||
return nil, status.Errorf(codes.NotFound, "peer with IP %s not found", peerIP)
|
||||
return nil, status.Errorf(status.NotFound, "peer with IP %s not found", peerIP)
|
||||
}
|
||||
|
||||
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)
|
||||
@ -275,7 +274,7 @@ func (am *DefaultAccountManager) GetNetworkMap(peerPubKey string) (*NetworkMap,
|
||||
|
||||
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "Invalid peer key %s", peerPubKey)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
aclPeers := am.getPeersByACL(account, peerPubKey)
|
||||
@ -306,7 +305,7 @@ func (am *DefaultAccountManager) GetPeerNetwork(peerPubKey string) (*Network, er
|
||||
|
||||
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "invalid peer key %s", peerPubKey)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return account.Network.Copy(), err
|
||||
@ -332,7 +331,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (*
|
||||
account, err = am.Store.GetAccountBySetupKey(setupKey)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, Errorf(AccountNotFound, "failed adding new peer: account not found")
|
||||
return nil, status.Errorf(status.NotFound, "failed adding new peer: account not found")
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountLock(account.Id)
|
||||
@ -352,7 +351,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (*
|
||||
}
|
||||
|
||||
if !sk.IsValid() {
|
||||
return nil, Errorf(PreconditionFailed, "couldn't add peer: setup key is invalid")
|
||||
return nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
|
||||
}
|
||||
|
||||
account.SetupKeys[sk.Key] = sk.IncrementUsage()
|
||||
@ -418,7 +417,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (*
|
||||
account.Network.IncSerial()
|
||||
err = am.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed adding peer")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newPeer, nil
|
||||
@ -563,13 +562,13 @@ func (am *DefaultAccountManager) updateAccountPeers(account *Account) error {
|
||||
for _, peer := range peers {
|
||||
remotePeerNetworkMap, err := am.GetNetworkMap(peer.Key)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "unable to fetch network map for peer %s, error: %v", peer.Key, err)
|
||||
return err
|
||||
}
|
||||
|
||||
update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap)
|
||||
err = am.peersUpdateManager.SendUpdate(peer.Key, &UpdateMessage{Update: update})
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "unable to send update for peer %s, error: %v", peer.Key, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2,11 +2,10 @@ package server
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"unicode/utf8"
|
||||
@ -66,7 +65,7 @@ func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*r
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
@ -75,7 +74,7 @@ func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*r
|
||||
}
|
||||
|
||||
if !user.IsAdmin() {
|
||||
return nil, Errorf(PermissionDenied, "Only administrators can view Network Routes")
|
||||
return nil, status.Errorf(status.PermissionDenied, "Only administrators can view Network Routes")
|
||||
}
|
||||
|
||||
wantedRoute, found := account.Routes[routeID]
|
||||
@ -83,7 +82,7 @@ func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*r
|
||||
return wantedRoute, nil
|
||||
}
|
||||
|
||||
return nil, status.Errorf(codes.NotFound, "route with ID %s not found", routeID)
|
||||
return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
|
||||
}
|
||||
|
||||
// checkPrefixPeerExists checks the combination of prefix and peer id, if it exists returns an error, otherwise returns nil
|
||||
@ -101,14 +100,14 @@ func (am *DefaultAccountManager) checkPrefixPeerExists(accountID, peer string, p
|
||||
routesWithPrefix := account.GetRoutesByPrefix(prefix)
|
||||
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||
return nil
|
||||
}
|
||||
return status.Errorf(codes.InvalidArgument, "failed to parse prefix %s", prefix.String())
|
||||
return status.Errorf(status.InvalidArgument, "failed to parse prefix %s", prefix.String())
|
||||
}
|
||||
for _, prefixRoute := range routesWithPrefix {
|
||||
if prefixRoute.Peer == peer {
|
||||
return status.Errorf(codes.AlreadyExists, "failed a route with prefix %s and peer already exist", prefix.String())
|
||||
return status.Errorf(status.AlreadyExists, "failed a route with prefix %s and peer already exist", prefix.String())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@ -121,13 +120,13 @@ func (am *DefaultAccountManager) CreateRoute(accountID string, network, peer, de
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var newRoute route.Route
|
||||
prefixType, newPrefix, err := route.ParseNetwork(network)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "failed to parse IP %s", network)
|
||||
return nil, status.Errorf(status.InvalidArgument, "failed to parse IP %s", network)
|
||||
}
|
||||
err = am.checkPrefixPeerExists(accountID, peer, newPrefix)
|
||||
if err != nil {
|
||||
@ -137,16 +136,16 @@ func (am *DefaultAccountManager) CreateRoute(accountID string, network, peer, de
|
||||
if peer != "" {
|
||||
_, peerExist := account.Peers[peer]
|
||||
if !peerExist {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "failed to find Peer %s", peer)
|
||||
return nil, status.Errorf(status.InvalidArgument, "failed to find Peer %s", peer)
|
||||
}
|
||||
}
|
||||
|
||||
if metric < route.MinMetric || metric > route.MaxMetric {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
|
||||
return nil, status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
|
||||
}
|
||||
|
||||
if utf8.RuneCountInString(netID) > route.MaxNetIDChar || netID == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
|
||||
return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
|
||||
}
|
||||
|
||||
newRoute.Peer = peer
|
||||
@ -173,7 +172,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID string, network, peer, de
|
||||
err = am.updateAccountPeers(account)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return &newRoute, status.Errorf(codes.Unavailable, "failed to update peers after create route %s", newPrefix)
|
||||
return &newRoute, status.Errorf(status.Internal, "failed to update peers after create route %s", newPrefix)
|
||||
}
|
||||
return &newRoute, nil
|
||||
}
|
||||
@ -184,30 +183,30 @@ func (am *DefaultAccountManager) SaveRoute(accountID string, routeToSave *route.
|
||||
defer unlock()
|
||||
|
||||
if routeToSave == nil {
|
||||
return status.Errorf(codes.InvalidArgument, "route provided is nil")
|
||||
return status.Errorf(status.InvalidArgument, "route provided is nil")
|
||||
}
|
||||
|
||||
if !routeToSave.Network.IsValid() {
|
||||
return status.Errorf(codes.InvalidArgument, "invalid Prefix %s", routeToSave.Network.String())
|
||||
return status.Errorf(status.InvalidArgument, "invalid Prefix %s", routeToSave.Network.String())
|
||||
}
|
||||
|
||||
if routeToSave.Metric < route.MinMetric || routeToSave.Metric > route.MaxMetric {
|
||||
return status.Errorf(codes.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
|
||||
return status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
|
||||
}
|
||||
|
||||
if utf8.RuneCountInString(routeToSave.NetID) > route.MaxNetIDChar || routeToSave.NetID == "" {
|
||||
return status.Errorf(codes.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
|
||||
return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.NotFound, "account not found")
|
||||
return err
|
||||
}
|
||||
|
||||
if routeToSave.Peer != "" {
|
||||
_, peerExist := account.Peers[routeToSave.Peer]
|
||||
if !peerExist {
|
||||
return status.Errorf(codes.InvalidArgument, "failed to find Peer %s", routeToSave.Peer)
|
||||
return status.Errorf(status.InvalidArgument, "failed to find Peer %s", routeToSave.Peer)
|
||||
}
|
||||
}
|
||||
|
||||
@ -228,12 +227,12 @@ func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operatio
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
routeToUpdate, ok := account.Routes[routeID]
|
||||
if !ok {
|
||||
return nil, status.Errorf(codes.NotFound, "route %s no longer exists", routeID)
|
||||
return nil, status.Errorf(status.NotFound, "route %s no longer exists", routeID)
|
||||
}
|
||||
|
||||
newRoute := routeToUpdate.Copy()
|
||||
@ -241,7 +240,7 @@ func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operatio
|
||||
for _, operation := range operations {
|
||||
|
||||
if len(operation.Values) != 1 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "operation %s contains invalid number of values, it should be 1", operation.Type.String())
|
||||
return nil, status.Errorf(status.InvalidArgument, "operation %s contains invalid number of values, it should be 1", operation.Type.String())
|
||||
}
|
||||
|
||||
switch operation.Type {
|
||||
@ -249,13 +248,13 @@ func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operatio
|
||||
newRoute.Description = operation.Values[0]
|
||||
case UpdateRouteNetworkIdentifier:
|
||||
if utf8.RuneCountInString(operation.Values[0]) > route.MaxNetIDChar || operation.Values[0] == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
|
||||
return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
|
||||
}
|
||||
newRoute.NetID = operation.Values[0]
|
||||
case UpdateRouteNetwork:
|
||||
prefixType, prefix, err := route.ParseNetwork(operation.Values[0])
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "failed to parse IP %s", operation.Values[0])
|
||||
return nil, status.Errorf(status.InvalidArgument, "failed to parse IP %s", operation.Values[0])
|
||||
}
|
||||
err = am.checkPrefixPeerExists(accountID, routeToUpdate.Peer, prefix)
|
||||
if err != nil {
|
||||
@ -267,7 +266,7 @@ func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operatio
|
||||
if operation.Values[0] != "" {
|
||||
_, peerExist := account.Peers[operation.Values[0]]
|
||||
if !peerExist {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "failed to find Peer %s", operation.Values[0])
|
||||
return nil, status.Errorf(status.InvalidArgument, "failed to find Peer %s", operation.Values[0])
|
||||
}
|
||||
}
|
||||
|
||||
@ -279,10 +278,10 @@ func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operatio
|
||||
case UpdateRouteMetric:
|
||||
metric, err := strconv.Atoi(operation.Values[0])
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "failed to parse metric %s, not int", operation.Values[0])
|
||||
return nil, status.Errorf(status.InvalidArgument, "failed to parse metric %s, not int", operation.Values[0])
|
||||
}
|
||||
if metric < route.MinMetric || metric > route.MaxMetric {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "failed to parse metric %s, value should be %d > N < %d",
|
||||
return nil, status.Errorf(status.InvalidArgument, "failed to parse metric %s, value should be %d > N < %d",
|
||||
operation.Values[0],
|
||||
route.MinMetric,
|
||||
route.MaxMetric,
|
||||
@ -292,13 +291,13 @@ func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operatio
|
||||
case UpdateRouteMasquerade:
|
||||
masquerade, err := strconv.ParseBool(operation.Values[0])
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "failed to parse masquerade %s, not boolean", operation.Values[0])
|
||||
return nil, status.Errorf(status.InvalidArgument, "failed to parse masquerade %s, not boolean", operation.Values[0])
|
||||
}
|
||||
newRoute.Masquerade = masquerade
|
||||
case UpdateRouteEnabled:
|
||||
enabled, err := strconv.ParseBool(operation.Values[0])
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0])
|
||||
return nil, status.Errorf(status.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0])
|
||||
}
|
||||
newRoute.Enabled = enabled
|
||||
}
|
||||
@ -313,7 +312,7 @@ func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operatio
|
||||
|
||||
err = am.updateAccountPeers(account)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to update account peers")
|
||||
return nil, status.Errorf(status.Internal, "failed to update account peers")
|
||||
}
|
||||
return newRoute, nil
|
||||
}
|
||||
@ -325,7 +324,7 @@ func (am *DefaultAccountManager) DeleteRoute(accountID, routeID string) error {
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.NotFound, "account not found")
|
||||
return err
|
||||
}
|
||||
|
||||
delete(account.Routes, routeID)
|
||||
@ -345,7 +344,7 @@ func (am *DefaultAccountManager) ListRoutes(accountID, userID string) ([]*route.
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
@ -354,7 +353,7 @@ func (am *DefaultAccountManager) ListRoutes(accountID, userID string) ([]*route.
|
||||
}
|
||||
|
||||
if !user.IsAdmin() {
|
||||
return nil, Errorf(PermissionDenied, "Only administrators can view Network Routes")
|
||||
return nil, status.Errorf(status.PermissionDenied, "Only administrators can view Network Routes")
|
||||
}
|
||||
|
||||
routes := make([]*route.Route, 0, len(account.Routes))
|
||||
|
@ -1,8 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@ -95,7 +94,7 @@ func (am *DefaultAccountManager) GetRule(accountID, ruleID, userID string) (*Rul
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
@ -104,7 +103,7 @@ func (am *DefaultAccountManager) GetRule(accountID, ruleID, userID string) (*Rul
|
||||
}
|
||||
|
||||
if !user.IsAdmin() {
|
||||
return nil, Errorf(PermissionDenied, "only admins are allowed to view rules")
|
||||
return nil, status.Errorf(status.PermissionDenied, "only admins are allowed to view rules")
|
||||
}
|
||||
|
||||
rule, ok := account.Rules[ruleID]
|
||||
@ -112,7 +111,7 @@ func (am *DefaultAccountManager) GetRule(accountID, ruleID, userID string) (*Rul
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
return nil, status.Errorf(codes.NotFound, "rule with ID %s not found", ruleID)
|
||||
return nil, status.Errorf(status.NotFound, "rule with ID %s not found", ruleID)
|
||||
}
|
||||
|
||||
// SaveRule of ACL in the store
|
||||
@ -122,7 +121,7 @@ func (am *DefaultAccountManager) SaveRule(accountID string, rule *Rule) error {
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.NotFound, "account not found")
|
||||
return err
|
||||
}
|
||||
|
||||
account.Rules[rule.ID] = rule
|
||||
@ -143,12 +142,12 @@ func (am *DefaultAccountManager) UpdateRule(accountID string, ruleID string,
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ruleToUpdate, ok := account.Rules[ruleID]
|
||||
if !ok {
|
||||
return nil, status.Errorf(codes.NotFound, "rule %s no longer exists", ruleID)
|
||||
return nil, status.Errorf(status.NotFound, "rule %s no longer exists", ruleID)
|
||||
}
|
||||
|
||||
rule := ruleToUpdate.Copy()
|
||||
@ -161,7 +160,7 @@ func (am *DefaultAccountManager) UpdateRule(accountID string, ruleID string,
|
||||
rule.Description = operation.Values[0]
|
||||
case UpdateRuleFlow:
|
||||
if operation.Values[0] != TrafficFlowBidirectString {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "failed to parse flow")
|
||||
return nil, status.Errorf(status.InvalidArgument, "failed to parse flow")
|
||||
}
|
||||
rule.Flow = TrafficFlowBidirect
|
||||
case UpdateRuleStatus:
|
||||
@ -170,7 +169,7 @@ func (am *DefaultAccountManager) UpdateRule(accountID string, ruleID string,
|
||||
} else if strings.ToLower(operation.Values[0]) == "false" {
|
||||
rule.Disabled = false
|
||||
} else {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "failed to parse status")
|
||||
return nil, status.Errorf(status.InvalidArgument, "failed to parse status")
|
||||
}
|
||||
case UpdateSourceGroups:
|
||||
rule.Source = operation.Values
|
||||
@ -204,7 +203,7 @@ func (am *DefaultAccountManager) UpdateRule(accountID string, ruleID string,
|
||||
|
||||
err = am.updateAccountPeers(account)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to update account peers")
|
||||
return nil, status.Errorf(status.Internal, "failed to update account peers")
|
||||
}
|
||||
|
||||
return rule, nil
|
||||
@ -217,7 +216,7 @@ func (am *DefaultAccountManager) DeleteRule(accountID, ruleID string) error {
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.NotFound, "account not found")
|
||||
return err
|
||||
}
|
||||
|
||||
delete(account.Rules, ruleID)
|
||||
@ -237,7 +236,7 @@ func (am *DefaultAccountManager) ListRules(accountID, userID string) ([]*Rule, e
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
@ -246,7 +245,7 @@ func (am *DefaultAccountManager) ListRules(accountID, userID string) ([]*Rule, e
|
||||
}
|
||||
|
||||
if !user.IsAdmin() {
|
||||
return nil, Errorf(PermissionDenied, "Only Administrators can view Access Rules")
|
||||
return nil, status.Errorf(status.PermissionDenied, "Only Administrators can view Access Rules")
|
||||
}
|
||||
|
||||
rules := make([]*Rule, 0, len(account.Rules))
|
||||
|
@ -1,10 +1,8 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/google/uuid"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"hash/fnv"
|
||||
"strconv"
|
||||
"strings"
|
||||
@ -183,12 +181,12 @@ func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, group := range autoGroups {
|
||||
if _, ok := account.Groups[group]; !ok {
|
||||
return nil, fmt.Errorf("group %s doesn't exist", group)
|
||||
return nil, status.Errorf(status.NotFound, "group %s doesn't exist", group)
|
||||
}
|
||||
}
|
||||
|
||||
@ -197,7 +195,7 @@ func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string
|
||||
|
||||
err = am.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed adding account key")
|
||||
return nil, status.Errorf(status.Internal, "failed adding account key")
|
||||
}
|
||||
|
||||
return setupKey, nil
|
||||
@ -212,12 +210,12 @@ func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *Setup
|
||||
defer unlock()
|
||||
|
||||
if keyToSave == nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "provided setup key to update is nil")
|
||||
return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil")
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var oldKey *SetupKey
|
||||
@ -228,7 +226,7 @@ func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *Setup
|
||||
}
|
||||
}
|
||||
if oldKey == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "setup key not found")
|
||||
return nil, status.Errorf(status.NotFound, "setup key not found")
|
||||
}
|
||||
|
||||
// only auto groups, revoked status, and name can be updated for now
|
||||
@ -253,7 +251,7 @@ func (am *DefaultAccountManager) ListSetupKeys(accountID, userID string) ([]*Set
|
||||
defer unlock()
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
@ -282,7 +280,7 @@ func (am *DefaultAccountManager) GetSetupKey(accountID, userID, keyID string) (*
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
@ -298,7 +296,7 @@ func (am *DefaultAccountManager) GetSetupKey(accountID, userID, keyID string) (*
|
||||
}
|
||||
}
|
||||
if foundKey == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "setup key not found")
|
||||
return nil, status.Errorf(status.NotFound, "setup key not found")
|
||||
}
|
||||
|
||||
// the UpdatedAt field was introduced later, so there might be that some keys have a Zero value (e.g, null in the store file)
|
||||
|
72
management/server/status/error.go
Normal file
72
management/server/status/error.go
Normal file
@ -0,0 +1,72 @@
|
||||
package status
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const (
|
||||
// UserAlreadyExists indicates that user already exists
|
||||
UserAlreadyExists Type = 1
|
||||
|
||||
// PreconditionFailed indicates that some pre-condition for the operation hasn't been fulfilled
|
||||
PreconditionFailed Type = 2
|
||||
|
||||
// PermissionDenied indicates that user has no permissions to view data
|
||||
PermissionDenied Type = 3
|
||||
|
||||
// NotFound indicates that the object wasn't found in the system (or under a given Account)
|
||||
NotFound Type = 4
|
||||
|
||||
// Internal indicates some generic internal error
|
||||
Internal Type = 5
|
||||
|
||||
// InvalidArgument indicates some generic invalid argument error
|
||||
InvalidArgument Type = 6
|
||||
|
||||
// AlreadyExists indicates a generic error when an object already exists in the system
|
||||
AlreadyExists Type = 7
|
||||
|
||||
// Unauthorized indicates that user is not authorized
|
||||
Unauthorized Type = 8
|
||||
|
||||
// BadRequest indicates that user is not authorized
|
||||
BadRequest Type = 9
|
||||
)
|
||||
|
||||
// Type is a type of the Error
|
||||
type Type int32
|
||||
|
||||
// Error is an internal error
|
||||
type Error struct {
|
||||
ErrorType Type
|
||||
Message string
|
||||
}
|
||||
|
||||
// Type returns the Type of the error
|
||||
func (e *Error) Type() Type {
|
||||
return e.ErrorType
|
||||
}
|
||||
|
||||
// Error is an error string
|
||||
func (e *Error) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
// Errorf returns Error(ErrorType, fmt.Sprintf(format, a...)).
|
||||
func Errorf(errorType Type, format string, a ...interface{}) error {
|
||||
return &Error{
|
||||
ErrorType: errorType,
|
||||
Message: fmt.Sprintf(format, a...),
|
||||
}
|
||||
}
|
||||
|
||||
// FromError returns Error, true if the provided error is of type of Error. nil, false otherwise
|
||||
func FromError(err error) (s *Error, ok bool) {
|
||||
if err == nil {
|
||||
return nil, true
|
||||
}
|
||||
if e, ok := err.(*Error); ok {
|
||||
return e, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
@ -3,8 +3,7 @@ package server
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"strings"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
@ -123,7 +122,7 @@ func (am *DefaultAccountManager) CreateUser(accountID string, invite *UserInfo)
|
||||
defer unlock()
|
||||
|
||||
if am.idpManager == nil {
|
||||
return nil, Errorf(PreconditionFailed, "IdP manager must be enabled to send user invites")
|
||||
return nil, status.Errorf(status.PreconditionFailed, "IdP manager must be enabled to send user invites")
|
||||
}
|
||||
|
||||
if invite == nil {
|
||||
@ -132,7 +131,7 @@ func (am *DefaultAccountManager) CreateUser(accountID string, invite *UserInfo)
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, Errorf(AccountNotFound, "account %s doesn't exist", accountID)
|
||||
return nil, status.Errorf(status.NotFound, "account %s doesn't exist", accountID)
|
||||
}
|
||||
|
||||
// check if the user is already registered with this email => reject
|
||||
@ -142,7 +141,7 @@ func (am *DefaultAccountManager) CreateUser(accountID string, invite *UserInfo)
|
||||
}
|
||||
|
||||
if user != nil {
|
||||
return nil, Errorf(UserAlreadyExists, "user has an existing account")
|
||||
return nil, status.Errorf(status.UserAlreadyExists, "user has an existing account")
|
||||
}
|
||||
|
||||
users, err := am.idpManager.GetUserByEmail(invite.Email)
|
||||
@ -151,7 +150,7 @@ func (am *DefaultAccountManager) CreateUser(accountID string, invite *UserInfo)
|
||||
}
|
||||
|
||||
if len(users) > 0 {
|
||||
return nil, Errorf(UserAlreadyExists, "user has an existing account")
|
||||
return nil, status.Errorf(status.UserAlreadyExists, "user has an existing account")
|
||||
}
|
||||
|
||||
idpUser, err := am.idpManager.CreateUser(invite.Email, invite.Name, accountID)
|
||||
@ -188,25 +187,24 @@ func (am *DefaultAccountManager) SaveUser(accountID string, update *User) (*User
|
||||
defer unlock()
|
||||
|
||||
if update == nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "provided user update is nil")
|
||||
return nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, newGroupID := range update.AutoGroups {
|
||||
if _, ok := account.Groups[newGroupID]; !ok {
|
||||
return nil,
|
||||
status.Errorf(codes.InvalidArgument, "provided group ID %s in the user %s update doesn't exist",
|
||||
return nil, status.Errorf(status.InvalidArgument, "provided group ID %s in the user %s update doesn't exist",
|
||||
newGroupID, update.Id)
|
||||
}
|
||||
}
|
||||
|
||||
oldUser := account.Users[update.Id]
|
||||
if oldUser == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "update not found")
|
||||
return nil, status.Errorf(status.NotFound, "update not found")
|
||||
}
|
||||
|
||||
// only auto groups, revoked status, and name can be updated for now
|
||||
@ -226,7 +224,7 @@ func (am *DefaultAccountManager) SaveUser(accountID string, update *User) (*User
|
||||
return nil, err
|
||||
}
|
||||
if userData == nil {
|
||||
return nil, status.Errorf(codes.NotFound, "user %s not found in the IdP", newUser.Id)
|
||||
return nil, status.Errorf(status.NotFound, "user %s not found in the IdP", newUser.Id)
|
||||
}
|
||||
return newUser.toUserInfo(userData)
|
||||
}
|
||||
@ -242,14 +240,14 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userID, domain string)
|
||||
|
||||
account, err := am.Store.GetAccountByUser(userID)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||
account, err = am.newAccount(userID, lowerDomain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = am.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed creating account")
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// other error
|
||||
@ -263,7 +261,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userID, domain string)
|
||||
account.Domain = lowerDomain
|
||||
err = am.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed updating account with domain")
|
||||
return nil, status.Errorf(status.Internal, "failed updating account with domain")
|
||||
}
|
||||
}
|
||||
|
||||
@ -272,14 +270,14 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userID, domain string)
|
||||
|
||||
// IsUserAdmin flag for current user authenticated by JWT token
|
||||
func (am *DefaultAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) {
|
||||
account, err := am.GetAccountFromToken(claims)
|
||||
account, _, err := am.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get account: %v", err)
|
||||
}
|
||||
|
||||
user, ok := account.Users[claims.UserId]
|
||||
if !ok {
|
||||
return false, fmt.Errorf("no such user")
|
||||
return false, status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
return user.Role == UserRoleAdmin, nil
|
||||
|
@ -1,8 +1,7 @@
|
||||
package route
|
||||
|
||||
import (
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
@ -108,13 +107,13 @@ func (r *Route) IsEqual(other *Route) bool {
|
||||
func ParseNetwork(networkString string) (NetworkType, netip.Prefix, error) {
|
||||
prefix, err := netip.ParsePrefix(networkString)
|
||||
if err != nil {
|
||||
return InvalidNetwork, netip.Prefix{}, err
|
||||
return InvalidNetwork, netip.Prefix{}, status.Errorf(status.InvalidArgument, "invalid network %s", networkString)
|
||||
}
|
||||
|
||||
masked := prefix.Masked()
|
||||
|
||||
if !masked.IsValid() {
|
||||
return InvalidNetwork, netip.Prefix{}, status.Errorf(codes.InvalidArgument, "invalid range %s", networkString)
|
||||
return InvalidNetwork, netip.Prefix{}, status.Errorf(status.InvalidArgument, "invalid range %s", networkString)
|
||||
}
|
||||
|
||||
if masked.Addr().Is6() {
|
||||
|
Loading…
Reference in New Issue
Block a user