mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-24 17:13:30 +01:00
Replace gRPC errors in business logic with internal ones (#558)
This commit is contained in:
parent
1db4027bea
commit
509d23c7cf
@ -8,12 +8,11 @@ import (
|
|||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/server/idp"
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
gocache "github.com/patrickmn/go-cache"
|
gocache "github.com/patrickmn/go-cache"
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@ -52,7 +51,7 @@ type AccountManager interface {
|
|||||||
SaveUser(accountID string, key *User) (*UserInfo, error)
|
SaveUser(accountID string, key *User) (*UserInfo, error)
|
||||||
GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
|
GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
|
||||||
GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error)
|
GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error)
|
||||||
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, error)
|
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
|
||||||
IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error)
|
IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error)
|
||||||
AccountExists(accountId string) (*bool, error)
|
AccountExists(accountId string) (*bool, error)
|
||||||
GetPeer(peerKey string) (*Peer, error)
|
GetPeer(peerKey string) (*Peer, error)
|
||||||
@ -265,14 +264,14 @@ func (a *Account) FindPeerByPubKey(peerPubKey string) (*Peer, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, status.Errorf(codes.NotFound, "peer with the public key %s not found", peerPubKey)
|
return nil, status.Errorf(status.NotFound, "peer with the public key %s not found", peerPubKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
// FindUser looks for a given user in the Account or returns error if user wasn't found.
|
// FindUser looks for a given user in the Account or returns error if user wasn't found.
|
||||||
func (a *Account) FindUser(userID string) (*User, error) {
|
func (a *Account) FindUser(userID string) (*User, error) {
|
||||||
user := a.Users[userID]
|
user := a.Users[userID]
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return nil, Errorf(UserNotFound, "user %s not found", userID)
|
return nil, status.Errorf(status.NotFound, "user %s not found", userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
return user, nil
|
return user, nil
|
||||||
@ -282,7 +281,7 @@ func (a *Account) FindUser(userID string) (*User, error) {
|
|||||||
func (a *Account) FindSetupKey(setupKey string) (*SetupKey, error) {
|
func (a *Account) FindSetupKey(setupKey string) (*SetupKey, error) {
|
||||||
key := a.SetupKeys[setupKey]
|
key := a.SetupKeys[setupKey]
|
||||||
if key == nil {
|
if key == nil {
|
||||||
return nil, Errorf(SetupKeyNotFound, "setup key not found")
|
return nil, status.Errorf(status.NotFound, "setup key not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
return key, nil
|
return key, nil
|
||||||
@ -458,14 +457,14 @@ func (am *DefaultAccountManager) newAccount(userID, domain string) (*Account, er
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
log.Warnf("an account with ID already exists, retrying...")
|
log.Warnf("an account with ID already exists, retrying...")
|
||||||
continue
|
continue
|
||||||
} else if statusErr.Code() == codes.NotFound {
|
} else if statusErr.Type() == status.NotFound {
|
||||||
return newAccountWithId(accountId, userID, domain), nil
|
return newAccountWithId(accountId, userID, domain), nil
|
||||||
} else {
|
} else {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, status.Errorf(codes.Internal, "error while creating new account")
|
return nil, status.Errorf(status.Internal, "error while creating new account")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) warmupIDPCache() error {
|
func (am *DefaultAccountManager) warmupIDPCache() error {
|
||||||
@ -492,7 +491,7 @@ func (am *DefaultAccountManager) GetAccountByUserOrAccountID(userID, accountID,
|
|||||||
} else if userID != "" {
|
} else if userID != "" {
|
||||||
account, err := am.GetOrCreateAccountByUser(userID, domain)
|
account, err := am.GetOrCreateAccountByUser(userID, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.NotFound, "account not found using user id: %s", userID)
|
return nil, status.Errorf(status.NotFound, "account not found using user id: %s", userID)
|
||||||
}
|
}
|
||||||
err = am.addAccountIDToIDPAppMeta(userID, account)
|
err = am.addAccountIDToIDPAppMeta(userID, account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -501,7 +500,7 @@ func (am *DefaultAccountManager) GetAccountByUserOrAccountID(userID, accountID,
|
|||||||
return account, nil
|
return account, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, status.Errorf(codes.NotFound, "no valid user or account Id provided")
|
return nil, status.Errorf(status.NotFound, "no valid user or account Id provided")
|
||||||
}
|
}
|
||||||
|
|
||||||
func isNil(i idp.Manager) bool {
|
func isNil(i idp.Manager) bool {
|
||||||
@ -531,11 +530,7 @@ func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(userID string, account
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return status.Errorf(
|
return status.Errorf(status.Internal, "updating user's app metadata failed with: %v", err)
|
||||||
codes.Internal,
|
|
||||||
"updating user's app metadata failed with: %v",
|
|
||||||
err,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
// refresh cache to reflect the update
|
// refresh cache to reflect the update
|
||||||
_, err = am.refreshCache(account.Id)
|
_, err = am.refreshCache(account.Id)
|
||||||
@ -662,11 +657,8 @@ func (am *DefaultAccountManager) lookupCache(accountUsers map[string]struct{}, a
|
|||||||
}
|
}
|
||||||
|
|
||||||
// updateAccountDomainAttributes updates the account domain attributes and then, saves the account
|
// updateAccountDomainAttributes updates the account domain attributes and then, saves the account
|
||||||
func (am *DefaultAccountManager) updateAccountDomainAttributes(
|
func (am *DefaultAccountManager) updateAccountDomainAttributes(account *Account, claims jwtclaims.AuthorizationClaims,
|
||||||
account *Account,
|
primaryDomain bool) error {
|
||||||
claims jwtclaims.AuthorizationClaims,
|
|
||||||
primaryDomain bool,
|
|
||||||
) error {
|
|
||||||
account.IsDomainPrimaryAccount = primaryDomain
|
account.IsDomainPrimaryAccount = primaryDomain
|
||||||
|
|
||||||
lowerDomain := strings.ToLower(claims.Domain)
|
lowerDomain := strings.ToLower(claims.Domain)
|
||||||
@ -681,7 +673,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributes(
|
|||||||
|
|
||||||
err := am.Store.SaveAccount(account)
|
err := am.Store.SaveAccount(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return status.Errorf(codes.Internal, "failed saving updated account")
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -723,10 +715,7 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
|
|||||||
|
|
||||||
// handleNewUserAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account,
|
// handleNewUserAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account,
|
||||||
// otherwise it will create a new account and make it primary account for the domain.
|
// otherwise it will create a new account and make it primary account for the domain.
|
||||||
func (am *DefaultAccountManager) handleNewUserAccount(
|
func (am *DefaultAccountManager) handleNewUserAccount(domainAcc *Account, claims jwtclaims.AuthorizationClaims) (*Account, error) {
|
||||||
domainAcc *Account,
|
|
||||||
claims jwtclaims.AuthorizationClaims,
|
|
||||||
) (*Account, error) {
|
|
||||||
var (
|
var (
|
||||||
account *Account
|
account *Account
|
||||||
err error
|
err error
|
||||||
@ -738,7 +727,7 @@ func (am *DefaultAccountManager) handleNewUserAccount(
|
|||||||
account.Users[claims.UserId] = NewRegularUser(claims.UserId)
|
account.Users[claims.UserId] = NewRegularUser(claims.UserId)
|
||||||
err = am.Store.SaveAccount(account)
|
err = am.Store.SaveAccount(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.Internal, "failed saving updated account")
|
return nil, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
account, err = am.newAccount(claims.UserId, lowerDomain)
|
account, err = am.newAccount(claims.UserId, lowerDomain)
|
||||||
@ -773,7 +762,7 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e
|
|||||||
}
|
}
|
||||||
|
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return status.Errorf(codes.NotFound, "user %s not found in the IdP", userID)
|
return status.Errorf(status.NotFound, "user %s not found in the IdP", userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
if user.AppMetadata.WTPendingInvite != nil && *user.AppMetadata.WTPendingInvite {
|
if user.AppMetadata.WTPendingInvite != nil && *user.AppMetadata.WTPendingInvite {
|
||||||
@ -794,7 +783,7 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountFromToken returns an account associated with this token
|
// GetAccountFromToken returns an account associated with this token
|
||||||
func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, error) {
|
func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error) {
|
||||||
|
|
||||||
if am.singleAccountMode && am.singleAccountModeDomain != "" {
|
if am.singleAccountMode && am.singleAccountModeDomain != "" {
|
||||||
// This section is mostly related to self-hosted installations.
|
// This section is mostly related to self-hosted installations.
|
||||||
@ -806,15 +795,21 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
|
|||||||
|
|
||||||
account, err := am.getAccountWithAuthorizationClaims(claims)
|
account, err := am.getAccountWithAuthorizationClaims(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
user := account.Users[claims.UserId]
|
||||||
|
if user == nil {
|
||||||
|
// this is not really possible because we got an account by user ID
|
||||||
|
return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = am.redeemInvite(account, claims.UserId)
|
err = am.redeemInvite(account, claims.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return account, nil
|
return account, user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getAccountWithAuthorizationClaims retrievs an account using JWT Claims.
|
// getAccountWithAuthorizationClaims retrievs an account using JWT Claims.
|
||||||
@ -857,9 +852,12 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla
|
|||||||
|
|
||||||
// We checked if the domain has a primary account already
|
// We checked if the domain has a primary account already
|
||||||
domainAccount, err := am.Store.GetAccountByPrivateDomain(claims.Domain)
|
domainAccount, err := am.Store.GetAccountByPrivateDomain(claims.Domain)
|
||||||
accStatus, _ := status.FromError(err)
|
if err != nil {
|
||||||
if accStatus.Code() != codes.OK && accStatus.Code() != codes.NotFound {
|
// if NotFound we are good to continue, otherwise return error
|
||||||
return nil, err
|
e, ok := status.FromError(err)
|
||||||
|
if !ok || e.Type() != status.NotFound {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := am.Store.GetAccountByUser(claims.UserId)
|
account, err := am.Store.GetAccountByUser(claims.UserId)
|
||||||
@ -869,7 +867,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return account, nil
|
return account, nil
|
||||||
} else if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||||
return am.handleNewUserAccount(domainAccount, claims)
|
return am.handleNewUserAccount(domainAccount, claims)
|
||||||
} else {
|
} else {
|
||||||
// other error
|
// other error
|
||||||
@ -891,7 +889,7 @@ func (am *DefaultAccountManager) AccountExists(accountID string) (*bool, error)
|
|||||||
var res bool
|
var res bool
|
||||||
_, err := am.Store.GetAccount(accountID)
|
_, err := am.Store.GetAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||||
res = false
|
res = false
|
||||||
return &res, nil
|
return &res, nil
|
||||||
} else {
|
} else {
|
||||||
|
@ -314,7 +314,7 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
|
|||||||
testCase.inputClaims.AccountId = initAccount.Id
|
testCase.inputClaims.AccountId = initAccount.Id
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := manager.GetAccountFromToken(testCase.inputClaims)
|
account, _, err := manager.GetAccountFromToken(testCase.inputClaims)
|
||||||
require.NoError(t, err, "support function failed")
|
require.NoError(t, err, "support function failed")
|
||||||
verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers)
|
verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers)
|
||||||
verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy)
|
verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy)
|
||||||
|
@ -1,61 +0,0 @@
|
|||||||
package server
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// UserAlreadyExists indicates that user already exists
|
|
||||||
UserAlreadyExists ErrorType = iota
|
|
||||||
// AccountNotFound indicates that specified account hasn't been found
|
|
||||||
AccountNotFound
|
|
||||||
// PreconditionFailed indicates that some pre-condition for the operation hasn't been fulfilled
|
|
||||||
PreconditionFailed
|
|
||||||
|
|
||||||
// UserNotFound indicates that user wasn't found in the system (or under a given Account)
|
|
||||||
UserNotFound
|
|
||||||
|
|
||||||
// PermissionDenied indicates that user has no permissions to view data
|
|
||||||
PermissionDenied
|
|
||||||
|
|
||||||
// SetupKeyNotFound indicates that the setup key wasn't found in the system (or under a given Account)
|
|
||||||
SetupKeyNotFound
|
|
||||||
)
|
|
||||||
|
|
||||||
// ErrorType is a type of the Error
|
|
||||||
type ErrorType int32
|
|
||||||
|
|
||||||
// Error is an internal error
|
|
||||||
type Error struct {
|
|
||||||
errorType ErrorType
|
|
||||||
message string
|
|
||||||
}
|
|
||||||
|
|
||||||
// Type returns the Type of the error
|
|
||||||
func (e *Error) Type() ErrorType {
|
|
||||||
return e.errorType
|
|
||||||
}
|
|
||||||
|
|
||||||
// Error is an error string
|
|
||||||
func (e *Error) Error() string {
|
|
||||||
return e.message
|
|
||||||
}
|
|
||||||
|
|
||||||
// Errorf returns Error(errorType, fmt.Sprintf(format, a...)).
|
|
||||||
func Errorf(errorType ErrorType, format string, a ...interface{}) error {
|
|
||||||
return &Error{
|
|
||||||
errorType: errorType,
|
|
||||||
message: fmt.Sprintf(format, a...),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// FromError returns Error, true if the provided error is of type of Error. nil, false otherwise
|
|
||||||
func FromError(err error) (s *Error, ok bool) {
|
|
||||||
if err == nil {
|
|
||||||
return nil, true
|
|
||||||
}
|
|
||||||
if e, ok := err.(*Error); ok {
|
|
||||||
return e, true
|
|
||||||
}
|
|
||||||
return nil, false
|
|
||||||
}
|
|
@ -1,6 +1,7 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@ -8,9 +9,6 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -192,10 +190,7 @@ func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) {
|
|||||||
|
|
||||||
accountID, accountIDFound := s.PrivateDomain2AccountID[strings.ToLower(domain)]
|
accountID, accountIDFound := s.PrivateDomain2AccountID[strings.ToLower(domain)]
|
||||||
if !accountIDFound {
|
if !accountIDFound {
|
||||||
return nil, status.Errorf(
|
return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
|
||||||
codes.NotFound,
|
|
||||||
"account not found: provided domain is not registered or is not private",
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := s.getAccount(accountID)
|
account, err := s.getAccount(accountID)
|
||||||
@ -213,7 +208,7 @@ func (s *FileStore) GetAccountBySetupKey(setupKey string) (*Account, error) {
|
|||||||
|
|
||||||
accountID, accountIDFound := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
|
accountID, accountIDFound := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
|
||||||
if !accountIDFound {
|
if !accountIDFound {
|
||||||
return nil, status.Errorf(codes.NotFound, "account not found: provided setup key doesn't exists")
|
return nil, status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists")
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := s.getAccount(accountID)
|
account, err := s.getAccount(accountID)
|
||||||
@ -239,7 +234,7 @@ func (s *FileStore) GetAllAccounts() (all []*Account) {
|
|||||||
func (s *FileStore) getAccount(accountID string) (*Account, error) {
|
func (s *FileStore) getAccount(accountID string) (*Account, error) {
|
||||||
account, accountFound := s.Accounts[accountID]
|
account, accountFound := s.Accounts[accountID]
|
||||||
if !accountFound {
|
if !accountFound {
|
||||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
return nil, status.Errorf(status.NotFound, "account not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
return account, nil
|
return account, nil
|
||||||
@ -265,7 +260,7 @@ func (s *FileStore) GetAccountByUser(userID string) (*Account, error) {
|
|||||||
|
|
||||||
accountID, accountIDFound := s.UserID2AccountID[userID]
|
accountID, accountIDFound := s.UserID2AccountID[userID]
|
||||||
if !accountIDFound {
|
if !accountIDFound {
|
||||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
return nil, status.Errorf(status.NotFound, "account not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := s.getAccount(accountID)
|
account, err := s.getAccount(accountID)
|
||||||
@ -283,7 +278,7 @@ func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
|
|||||||
|
|
||||||
accountID, accountIDFound := s.PeerKeyID2AccountID[peerKey]
|
accountID, accountIDFound := s.PeerKeyID2AccountID[peerKey]
|
||||||
if !accountIDFound {
|
if !accountIDFound {
|
||||||
return nil, status.Errorf(codes.NotFound, "Provided peer key doesn't exists %s", peerKey)
|
return nil, status.Errorf(status.NotFound, "provided peer key doesn't exists %s", peerKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := s.getAccount(accountID)
|
account, err := s.getAccount(accountID)
|
||||||
@ -322,7 +317,7 @@ func (s *FileStore) SavePeerStatus(accountID, peerKey string, peerStatus PeerSta
|
|||||||
|
|
||||||
peer := account.Peers[peerKey]
|
peer := account.Peers[peerKey]
|
||||||
if peer == nil {
|
if peer == nil {
|
||||||
return status.Errorf(codes.NotFound, "peer %s not found", peerKey)
|
return status.Errorf(status.NotFound, "peer %s not found", peerKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
peer.Status = &peerStatus
|
peer.Status = &peerStatus
|
||||||
|
@ -1,9 +1,6 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import "github.com/netbirdio/netbird/management/server/status"
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Group of the peers for ACL
|
// Group of the peers for ACL
|
||||||
type Group struct {
|
type Group struct {
|
||||||
@ -53,7 +50,7 @@ func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, er
|
|||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
group, ok := account.Groups[groupID]
|
group, ok := account.Groups[groupID]
|
||||||
@ -61,7 +58,7 @@ func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, er
|
|||||||
return group, nil
|
return group, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, status.Errorf(codes.NotFound, "group with ID %s not found", groupID)
|
return nil, status.Errorf(status.NotFound, "group with ID %s not found", groupID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveGroup object of the peers
|
// SaveGroup object of the peers
|
||||||
@ -72,7 +69,7 @@ func (am *DefaultAccountManager) SaveGroup(accountID string, group *Group) error
|
|||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return status.Errorf(codes.NotFound, "account not found")
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
account.Groups[group.ID] = group
|
account.Groups[group.ID] = group
|
||||||
@ -94,12 +91,12 @@ func (am *DefaultAccountManager) UpdateGroup(accountID string,
|
|||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
groupToUpdate, ok := account.Groups[groupID]
|
groupToUpdate, ok := account.Groups[groupID]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, status.Errorf(codes.NotFound, "group %s no longer exists", groupID)
|
return nil, status.Errorf(status.NotFound, "group with ID %s no longer exists", groupID)
|
||||||
}
|
}
|
||||||
|
|
||||||
group := groupToUpdate.Copy()
|
group := groupToUpdate.Copy()
|
||||||
@ -130,7 +127,7 @@ func (am *DefaultAccountManager) UpdateGroup(accountID string,
|
|||||||
|
|
||||||
err = am.updateAccountPeers(account)
|
err = am.updateAccountPeers(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.Internal, "failed to update account peers")
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return group, nil
|
return group, nil
|
||||||
@ -144,7 +141,7 @@ func (am *DefaultAccountManager) DeleteGroup(accountID, groupID string) error {
|
|||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return status.Errorf(codes.NotFound, "account not found")
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(account.Groups, groupID)
|
delete(account.Groups, groupID)
|
||||||
@ -165,7 +162,7 @@ func (am *DefaultAccountManager) ListGroups(accountID string) ([]*Group, error)
|
|||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
groups := make([]*Group, 0, len(account.Groups))
|
groups := make([]*Group, 0, len(account.Groups))
|
||||||
@ -184,12 +181,12 @@ func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerKey string
|
|||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return status.Errorf(codes.NotFound, "account not found")
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
group, ok := account.Groups[groupID]
|
group, ok := account.Groups[groupID]
|
||||||
if !ok {
|
if !ok {
|
||||||
return status.Errorf(codes.NotFound, "group with ID %s not found", groupID)
|
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
|
||||||
}
|
}
|
||||||
|
|
||||||
add := true
|
add := true
|
||||||
@ -219,12 +216,12 @@ func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerKey str
|
|||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return status.Errorf(codes.NotFound, "account not found")
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
group, ok := account.Groups[groupID]
|
group, ok := account.Groups[groupID]
|
||||||
if !ok {
|
if !ok {
|
||||||
return status.Errorf(codes.NotFound, "group with ID %s not found", groupID)
|
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
|
||||||
}
|
}
|
||||||
|
|
||||||
account.Network.IncSerial()
|
account.Network.IncSerial()
|
||||||
@ -232,7 +229,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerKey str
|
|||||||
if itemID == peerKey {
|
if itemID == peerKey {
|
||||||
group.Peers = append(group.Peers[:i], group.Peers[i+1:]...)
|
group.Peers = append(group.Peers[:i], group.Peers[i+1:]...)
|
||||||
if err := am.Store.SaveAccount(account); err != nil {
|
if err := am.Store.SaveAccount(account); err != nil {
|
||||||
return status.Errorf(codes.Internal, "can't save account")
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -248,12 +245,12 @@ func (am *DefaultAccountManager) GroupListPeers(accountID, groupID string) ([]*P
|
|||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
return nil, status.Errorf(status.NotFound, "account not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
group, ok := account.Groups[groupID]
|
group, ok := account.Groups[groupID]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, status.Errorf(codes.NotFound, "group with ID %s not found", groupID)
|
return nil, status.Errorf(status.NotFound, "group with ID %s not found", groupID)
|
||||||
}
|
}
|
||||||
|
|
||||||
peers := make([]*Peer, 0, len(account.Groups))
|
peers := make([]*Peer, 0, len(account.Groups))
|
||||||
|
@ -14,6 +14,7 @@ import (
|
|||||||
"github.com/golang/protobuf/ptypes/timestamp"
|
"github.com/golang/protobuf/ptypes/timestamp"
|
||||||
"github.com/netbirdio/netbird/encryption"
|
"github.com/netbirdio/netbird/encryption"
|
||||||
"github.com/netbirdio/netbird/management/proto"
|
"github.com/netbirdio/netbird/management/proto"
|
||||||
|
internalStatus "github.com/netbirdio/netbird/management/server/status"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
@ -200,10 +201,6 @@ func (s *GRPCServer) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest)
|
|||||||
return nil, status.Errorf(codes.Internal, "invalid jwt token, err: %v", err)
|
return nil, status.Errorf(codes.Internal, "invalid jwt token, err: %v", err)
|
||||||
}
|
}
|
||||||
claims := jwtclaims.ExtractClaimsWithToken(token, s.config.HttpConfig.AuthAudience)
|
claims := jwtclaims.ExtractClaimsWithToken(token, s.config.HttpConfig.AuthAudience)
|
||||||
_, err = s.accountManager.GetAccountFromToken(claims)
|
|
||||||
if err != nil {
|
|
||||||
return nil, status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err)
|
|
||||||
}
|
|
||||||
userID = claims.UserId
|
userID = claims.UserId
|
||||||
} else {
|
} else {
|
||||||
log.Debugln("using setup key to register peer")
|
log.Debugln("using setup key to register peer")
|
||||||
@ -237,14 +234,12 @@ func (s *GRPCServer) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest)
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if e, ok := FromError(err); ok {
|
if e, ok := internalStatus.FromError(err); ok {
|
||||||
switch e.Type() {
|
switch e.Type() {
|
||||||
case PreconditionFailed:
|
case internalStatus.PreconditionFailed:
|
||||||
return nil, status.Errorf(codes.FailedPrecondition, e.message)
|
return nil, status.Errorf(codes.FailedPrecondition, e.Message)
|
||||||
case AccountNotFound:
|
case internalStatus.NotFound:
|
||||||
case SetupKeyNotFound:
|
return nil, status.Errorf(codes.NotFound, e.Message)
|
||||||
case UserNotFound:
|
|
||||||
return nil, status.Errorf(codes.NotFound, e.message)
|
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -301,7 +296,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
|
|||||||
|
|
||||||
peer, err := s.accountManager.GetPeer(peerKey.String())
|
peer, err := s.accountManager.GetPeer(peerKey.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errStatus, ok := status.FromError(err); ok && errStatus.Code() == codes.NotFound {
|
if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound {
|
||||||
// peer doesn't exist -> check if setup key was provided
|
// peer doesn't exist -> check if setup key was provided
|
||||||
if loginReq.GetJwtToken() == "" && loginReq.GetSetupKey() == "" {
|
if loginReq.GetJwtToken() == "" && loginReq.GetSetupKey() == "" {
|
||||||
// absent setup key or jwt -> permission denied
|
// absent setup key or jwt -> permission denied
|
||||||
@ -387,7 +382,6 @@ func ToResponseProto(configProto Protocol) proto.HostConfig_Protocol {
|
|||||||
case TCP:
|
case TCP:
|
||||||
return proto.HostConfig_TCP
|
return proto.HostConfig_TCP
|
||||||
default:
|
default:
|
||||||
// mbragin: todo something better?
|
|
||||||
panic(fmt.Errorf("unexpected config protocol type %v", configProto))
|
panic(fmt.Errorf("unexpected config protocol type %v", configProto))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2,10 +2,9 @@ package http
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
"google.golang.org/grpc/codes"
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
"google.golang.org/grpc/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
@ -33,7 +32,8 @@ func NewGroups(accountManager server.AccountManager, authAudience string) *Group
|
|||||||
|
|
||||||
// GetAllGroupsHandler list for the account
|
// GetAllGroupsHandler list for the account
|
||||||
func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) {
|
func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||||
|
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
@ -45,52 +45,54 @@ func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
groups = append(groups, toGroupResponse(account, g))
|
groups = append(groups, toGroupResponse(account, g))
|
||||||
}
|
}
|
||||||
|
|
||||||
writeJSONObject(w, groups)
|
util.WriteJSONObject(w, groups)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateGroupHandler handles update to a group identified by a given ID
|
// UpdateGroupHandler handles update to a group identified by a given ID
|
||||||
func (h *Groups) UpdateGroupHandler(w http.ResponseWriter, r *http.Request) {
|
func (h *Groups) UpdateGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||||
|
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
groupID, ok := vars["id"]
|
groupID, ok := vars["id"]
|
||||||
if !ok {
|
if !ok {
|
||||||
http.Error(w, "group ID field is missing", http.StatusBadRequest)
|
util.WriteError(status.Errorf(status.InvalidArgument, "group ID field is missing"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(groupID) == 0 {
|
if len(groupID) == 0 {
|
||||||
http.Error(w, "group ID can't be empty", http.StatusUnprocessableEntity)
|
util.WriteError(status.Errorf(status.InvalidArgument, "group ID can't be empty"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_, ok = account.Groups[groupID]
|
_, ok = account.Groups[groupID]
|
||||||
if !ok {
|
if !ok {
|
||||||
http.Error(w, fmt.Sprintf("couldn't find group with ID %s", groupID), http.StatusNotFound)
|
util.WriteError(status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
allGroup, err := account.GetGroupAll()
|
allGroup, err := account.GetGroupAll()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if allGroup.ID == groupID {
|
if allGroup.ID == groupID {
|
||||||
http.Error(w, "updating group ALL is not allowed", http.StatusMethodNotAllowed)
|
util.WriteError(status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var req api.PutApiGroupsIdJSONRequestBody
|
var req api.PutApiGroupsIdJSONRequestBody
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
err = json.NewDecoder(r.Body).Decode(&req)
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
if err != nil {
|
||||||
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if *req.Name == "" {
|
if *req.Name == "" {
|
||||||
http.Error(w, "group name shouldn't be empty", http.StatusUnprocessableEntity)
|
util.WriteError(status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -102,53 +104,55 @@ func (h *Groups) UpdateGroupHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
if err := h.accountManager.SaveGroup(account.Id, &group); err != nil {
|
if err := h.accountManager.SaveGroup(account.Id, &group); err != nil {
|
||||||
log.Errorf("failed updating group %s under account %s %v", groupID, account.Id, err)
|
log.Errorf("failed updating group %s under account %s %v", groupID, account.Id, err)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
writeJSONObject(w, toGroupResponse(account, &group))
|
util.WriteJSONObject(w, toGroupResponse(account, &group))
|
||||||
}
|
}
|
||||||
|
|
||||||
// PatchGroupHandler handles patch updates to a group identified by a given ID
|
// PatchGroupHandler handles patch updates to a group identified by a given ID
|
||||||
func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) {
|
func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||||
|
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
groupID := vars["id"]
|
groupID := vars["id"]
|
||||||
if len(groupID) == 0 {
|
if len(groupID) == 0 {
|
||||||
http.Error(w, "invalid group Id", http.StatusBadRequest)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_, ok := account.Groups[groupID]
|
_, ok := account.Groups[groupID]
|
||||||
if !ok {
|
if !ok {
|
||||||
http.Error(w, fmt.Sprintf("couldn't find group id %s", groupID), http.StatusNotFound)
|
util.WriteError(status.Errorf(status.NotFound, "couldn't find group ID %s", groupID), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
allGroup, err := account.GetGroupAll()
|
allGroup, err := account.GetGroupAll()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if allGroup.ID == groupID {
|
if allGroup.ID == groupID {
|
||||||
http.Error(w, "updating group ALL is not allowed", http.StatusMethodNotAllowed)
|
util.WriteError(status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var req api.PatchApiGroupsIdJSONRequestBody
|
var req api.PatchApiGroupsIdJSONRequestBody
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
err = json.NewDecoder(r.Body).Decode(&req)
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
if err != nil {
|
||||||
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(req) == 0 {
|
if len(req) == 0 {
|
||||||
http.Error(w, "no patch instruction received", http.StatusBadRequest)
|
util.WriteError(status.Errorf(status.InvalidArgument, "no patch instruction received"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -158,13 +162,13 @@ func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
switch patch.Path {
|
switch patch.Path {
|
||||||
case api.GroupPatchOperationPathName:
|
case api.GroupPatchOperationPathName:
|
||||||
if patch.Op != api.GroupPatchOperationOpReplace {
|
if patch.Op != api.GroupPatchOperationOpReplace {
|
||||||
http.Error(w, fmt.Sprintf("Name field only accepts replace operation, got %s", patch.Op),
|
util.WriteError(status.Errorf(status.InvalidArgument,
|
||||||
http.StatusBadRequest)
|
"name field only accepts replace operation, got %s", patch.Op), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(patch.Value) == 0 || patch.Value[0] == "" {
|
if len(patch.Value) == 0 || patch.Value[0] == "" {
|
||||||
http.Error(w, "Group name shouldn't be empty", http.StatusUnprocessableEntity)
|
util.WriteError(status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -193,53 +197,43 @@ func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
Values: peerKeys,
|
Values: peerKeys,
|
||||||
})
|
})
|
||||||
default:
|
default:
|
||||||
http.Error(w, "invalid operation, \"%s\", for Peers field", http.StatusBadRequest)
|
util.WriteError(status.Errorf(status.InvalidArgument,
|
||||||
|
"invalid operation, \"%v\", for Peers field", patch.Op), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
http.Error(w, "invalid patch path", http.StatusBadRequest)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid patch path"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
group, err := h.accountManager.UpdateGroup(account.Id, groupID, operations)
|
group, err := h.accountManager.UpdateGroup(account.Id, groupID, operations)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errStatus, ok := status.FromError(err)
|
util.WriteError(err, w)
|
||||||
if ok && errStatus.Code() == codes.Internal {
|
|
||||||
http.Error(w, errStatus.String(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if ok && errStatus.Code() == codes.NotFound {
|
|
||||||
http.Error(w, errStatus.String(), http.StatusNotFound)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Errorf("failed updating group %s under account %s %v", groupID, account.Id, err)
|
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
writeJSONObject(w, toGroupResponse(account, group))
|
util.WriteJSONObject(w, toGroupResponse(account, group))
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateGroupHandler handles group creation request
|
// CreateGroupHandler handles group creation request
|
||||||
func (h *Groups) CreateGroupHandler(w http.ResponseWriter, r *http.Request) {
|
func (h *Groups) CreateGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||||
|
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var req api.PostApiGroupsJSONRequestBody
|
var req api.PostApiGroupsJSONRequestBody
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
err = json.NewDecoder(r.Body).Decode(&req)
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
if err != nil {
|
||||||
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Name == "" {
|
if req.Name == "" {
|
||||||
http.Error(w, "Group name shouldn't be empty", http.StatusUnprocessableEntity)
|
util.WriteError(status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -249,55 +243,57 @@ func (h *Groups) CreateGroupHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
Peers: peerIPsToKeys(account, req.Peers),
|
Peers: peerIPsToKeys(account, req.Peers),
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.accountManager.SaveGroup(account.Id, &group); err != nil {
|
err = h.accountManager.SaveGroup(account.Id, &group)
|
||||||
log.Errorf("failed creating group \"%s\" under account %s %v", req.Name, account.Id, err)
|
if err != nil {
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
writeJSONObject(w, toGroupResponse(account, &group))
|
util.WriteJSONObject(w, toGroupResponse(account, &group))
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteGroupHandler handles group deletion request
|
// DeleteGroupHandler handles group deletion request
|
||||||
func (h *Groups) DeleteGroupHandler(w http.ResponseWriter, r *http.Request) {
|
func (h *Groups) DeleteGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||||
|
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
aID := account.Id
|
aID := account.Id
|
||||||
|
|
||||||
groupID := mux.Vars(r)["id"]
|
groupID := mux.Vars(r)["id"]
|
||||||
if len(groupID) == 0 {
|
if len(groupID) == 0 {
|
||||||
http.Error(w, "invalid group ID", http.StatusBadRequest)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
allGroup, err := account.GetGroupAll()
|
allGroup, err := account.GetGroupAll()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if allGroup.ID == groupID {
|
if allGroup.ID == groupID {
|
||||||
http.Error(w, "deleting group ALL is not allowed", http.StatusMethodNotAllowed)
|
util.WriteError(status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.accountManager.DeleteGroup(aID, groupID); err != nil {
|
err = h.accountManager.DeleteGroup(aID, groupID)
|
||||||
log.Errorf("failed delete group %s under account %s %v", groupID, aID, err)
|
if err != nil {
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
writeJSONObject(w, "")
|
util.WriteJSONObject(w, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetGroupHandler returns a group
|
// GetGroupHandler returns a group
|
||||||
func (h *Groups) GetGroupHandler(w http.ResponseWriter, r *http.Request) {
|
func (h *Groups) GetGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||||
|
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -305,19 +301,22 @@ func (h *Groups) GetGroupHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
case http.MethodGet:
|
case http.MethodGet:
|
||||||
groupID := mux.Vars(r)["id"]
|
groupID := mux.Vars(r)["id"]
|
||||||
if len(groupID) == 0 {
|
if len(groupID) == 0 {
|
||||||
http.Error(w, "invalid group ID", http.StatusBadRequest)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
group, err := h.accountManager.GetGroup(account.Id, groupID)
|
group, err := h.accountManager.GetGroup(account.Id, groupID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "group not found", http.StatusNotFound)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
writeJSONObject(w, toGroupResponse(account, group))
|
util.WriteJSONObject(w, toGroupResponse(account, group))
|
||||||
default:
|
default:
|
||||||
http.Error(w, "", http.StatusNotFound)
|
if err != nil {
|
||||||
|
util.WriteError(status.Errorf(status.NotFound, "HTTP method not found"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -36,7 +37,7 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *Groups {
|
|||||||
},
|
},
|
||||||
GetGroupFunc: func(_, groupID string) (*server.Group, error) {
|
GetGroupFunc: func(_, groupID string) (*server.Group, error) {
|
||||||
if groupID != "idofthegroup" {
|
if groupID != "idofthegroup" {
|
||||||
return nil, fmt.Errorf("not found")
|
return nil, status.Errorf(status.NotFound, "not found")
|
||||||
}
|
}
|
||||||
return &server.Group{
|
return &server.Group{
|
||||||
ID: "idofthegroup",
|
ID: "idofthegroup",
|
||||||
@ -67,7 +68,7 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *Groups {
|
|||||||
}
|
}
|
||||||
return nil, fmt.Errorf("peer not found")
|
return nil, fmt.Errorf("peer not found")
|
||||||
},
|
},
|
||||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
|
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||||
return &server.Account{
|
return &server.Account{
|
||||||
Id: claims.AccountId,
|
Id: claims.AccountId,
|
||||||
Domain: "hotmail.com",
|
Domain: "hotmail.com",
|
||||||
@ -78,7 +79,7 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *Groups {
|
|||||||
Groups: map[string]*server.Group{
|
Groups: map[string]*server.Group{
|
||||||
"id-existed": {ID: "id-existed", Peers: []string{"A", "B"}},
|
"id-existed": {ID: "id-existed", Peers: []string{"A", "B"}},
|
||||||
"id-all": {ID: "id-all", Name: "All"}},
|
"id-all": {ID: "id-all", Name: "All"}},
|
||||||
}, nil
|
}, user, nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
authAudience: "",
|
authAudience: "",
|
||||||
@ -223,7 +224,7 @@ func TestWriteGroup(t *testing.T) {
|
|||||||
requestPath: "/api/groups/id-all",
|
requestPath: "/api/groups/id-all",
|
||||||
requestBody: bytes.NewBuffer(
|
requestBody: bytes.NewBuffer(
|
||||||
[]byte(`{"Name":"super"}`)),
|
[]byte(`{"Name":"super"}`)),
|
||||||
expectedStatus: http.StatusMethodNotAllowed,
|
expectedStatus: http.StatusUnprocessableEntity,
|
||||||
expectedBody: false,
|
expectedBody: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -244,7 +245,7 @@ func TestWriteGroup(t *testing.T) {
|
|||||||
requestPath: "/api/groups/id-existed",
|
requestPath: "/api/groups/id-existed",
|
||||||
requestBody: bytes.NewBuffer(
|
requestBody: bytes.NewBuffer(
|
||||||
[]byte(`[{"op":"insert","path":"name","value":[""]}]`)),
|
[]byte(`[{"op":"insert","path":"name","value":[""]}]`)),
|
||||||
expectedStatus: http.StatusBadRequest,
|
expectedStatus: http.StatusUnprocessableEntity,
|
||||||
expectedBody: false,
|
expectedBody: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
@ -33,14 +34,15 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
|
|||||||
|
|
||||||
ok, err := a.isUserAdmin(jwtClaims)
|
ok, err := a.isUserAdmin(jwtClaims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, fmt.Sprintf("error get user from JWT: %v", err), http.StatusUnauthorized)
|
util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
switch r.Method {
|
switch r.Method {
|
||||||
|
|
||||||
case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut:
|
case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut:
|
||||||
http.Error(w, "user is not admin", http.StatusForbidden)
|
util.WriteError(status.Errorf(status.PermissionDenied, "only admin can perform this operation"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -7,12 +7,12 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
//Jwks is a collection of JSONWebKeys obtained from Config.HttpServerConfig.AuthKeysLocation
|
// Jwks is a collection of JSONWebKeys obtained from Config.HttpServerConfig.AuthKeysLocation
|
||||||
type Jwks struct {
|
type Jwks struct {
|
||||||
Keys []JSONWebKeys `json:"keys"`
|
Keys []JSONWebKeys `json:"keys"`
|
||||||
}
|
}
|
||||||
|
|
||||||
//JSONWebKeys is a representation of a Jason Web Key
|
// JSONWebKeys is a representation of a Jason Web Key
|
||||||
type JSONWebKeys struct {
|
type JSONWebKeys struct {
|
||||||
Kty string `json:"kty"`
|
Kty string `json:"kty"`
|
||||||
Kid string `json:"kid"`
|
Kid string `json:"kid"`
|
||||||
@ -22,7 +22,7 @@ type JSONWebKeys struct {
|
|||||||
X5c []string `json:"x5c"`
|
X5c []string `json:"x5c"`
|
||||||
}
|
}
|
||||||
|
|
||||||
//NewJwtMiddleware creates new middleware to verify the JWT token sent via Authorization header
|
// NewJwtMiddleware creates new middleware to verify the JWT token sent via Authorization header
|
||||||
func NewJwtMiddleware(issuer string, audience string, keysLocation string) (*JWTMiddleware, error) {
|
func NewJwtMiddleware(issuer string, audience string, keysLocation string) (*JWTMiddleware, error) {
|
||||||
|
|
||||||
keys, err := getPemKeys(keysLocation)
|
keys, err := getPemKeys(keysLocation)
|
||||||
@ -66,7 +66,6 @@ func getPemKeys(keysLocation string) (*Jwks, error) {
|
|||||||
|
|
||||||
var jwks = &Jwks{}
|
var jwks = &Jwks{}
|
||||||
err = json.NewDecoder(resp.Body).Decode(jwks)
|
err = json.NewDecoder(resp.Body).Decode(jwks)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return jwks, err
|
return jwks, err
|
||||||
}
|
}
|
||||||
|
@ -5,6 +5,8 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/golang-jwt/jwt"
|
"github.com/golang-jwt/jwt"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
@ -57,7 +59,7 @@ type JWTMiddleware struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func OnError(w http.ResponseWriter, r *http.Request, err string) {
|
func OnError(w http.ResponseWriter, r *http.Request, err string) {
|
||||||
http.Error(w, err, http.StatusUnauthorized)
|
util.WriteError(status.Errorf(status.Unauthorized, ""), w)
|
||||||
}
|
}
|
||||||
|
|
||||||
// New constructs a new Secure instance with supplied options.
|
// New constructs a new Secure instance with supplied options.
|
||||||
|
@ -7,7 +7,9 @@ import (
|
|||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
@ -30,7 +32,8 @@ func NewNameservers(accountManager server.AccountManager, authAudience string) *
|
|||||||
|
|
||||||
// GetAllNameserversHandler returns the list of nameserver groups for the account
|
// GetAllNameserversHandler returns the list of nameserver groups for the account
|
||||||
func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Request) {
|
func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||||
|
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
@ -39,7 +42,7 @@ func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Re
|
|||||||
|
|
||||||
nsGroups, err := h.accountManager.ListNameServerGroups(account.Id)
|
nsGroups, err := h.accountManager.ListNameServerGroups(account.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
toHTTPError(err, w)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -48,64 +51,67 @@ func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Re
|
|||||||
apiNameservers = append(apiNameservers, toNameserverGroupResponse(r))
|
apiNameservers = append(apiNameservers, toNameserverGroupResponse(r))
|
||||||
}
|
}
|
||||||
|
|
||||||
writeJSONObject(w, apiNameservers)
|
util.WriteJSONObject(w, apiNameservers)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateNameserverGroupHandler handles nameserver group creation request
|
// CreateNameserverGroupHandler handles nameserver group creation request
|
||||||
func (h *Nameservers) CreateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
|
func (h *Nameservers) CreateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||||
|
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
util.WriteError(err, w)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var req api.PostApiDnsNameserversJSONRequestBody
|
var req api.PostApiDnsNameserversJSONRequestBody
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
err = json.NewDecoder(r.Body).Decode(&req)
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
if err != nil {
|
||||||
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
nsList, err := toServerNSList(req.Nameservers)
|
nsList, err := toServerNSList(req.Nameservers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid NS servers format"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
nsGroup, err := h.accountManager.CreateNameServerGroup(account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled)
|
nsGroup, err := h.accountManager.CreateNameServerGroup(account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
toHTTPError(err, w)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := toNameserverGroupResponse(nsGroup)
|
resp := toNameserverGroupResponse(nsGroup)
|
||||||
|
|
||||||
writeJSONObject(w, &resp)
|
util.WriteJSONObject(w, &resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateNameserverGroupHandler handles update to a nameserver group identified by a given ID
|
// UpdateNameserverGroupHandler handles update to a nameserver group identified by a given ID
|
||||||
func (h *Nameservers) UpdateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
|
func (h *Nameservers) UpdateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||||
|
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
util.WriteError(err, w)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
nsGroupID := mux.Vars(r)["id"]
|
nsGroupID := mux.Vars(r)["id"]
|
||||||
if len(nsGroupID) == 0 {
|
if len(nsGroupID) == 0 {
|
||||||
http.Error(w, "invalid nameserver group ID", http.StatusBadRequest)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var req api.PutApiDnsNameserversIdJSONRequestBody
|
var req api.PutApiDnsNameserversIdJSONRequestBody
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
err = json.NewDecoder(r.Body).Decode(&req)
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
if err != nil {
|
||||||
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
nsList, err := toServerNSList(req.Nameservers)
|
nsList, err := toServerNSList(req.Nameservers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid NS servers format"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -122,41 +128,42 @@ func (h *Nameservers) UpdateNameserverGroupHandler(w http.ResponseWriter, r *htt
|
|||||||
|
|
||||||
err = h.accountManager.SaveNameServerGroup(account.Id, updatedNSGroup)
|
err = h.accountManager.SaveNameServerGroup(account.Id, updatedNSGroup)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
toHTTPError(err, w)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := toNameserverGroupResponse(updatedNSGroup)
|
resp := toNameserverGroupResponse(updatedNSGroup)
|
||||||
|
|
||||||
writeJSONObject(w, &resp)
|
util.WriteJSONObject(w, &resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// PatchNameserverGroupHandler handles patch updates to a nameserver group identified by a given ID
|
// PatchNameserverGroupHandler handles patch updates to a nameserver group identified by a given ID
|
||||||
func (h *Nameservers) PatchNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
|
func (h *Nameservers) PatchNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||||
|
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
util.WriteError(err, w)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
nsGroupID := mux.Vars(r)["id"]
|
nsGroupID := mux.Vars(r)["id"]
|
||||||
if len(nsGroupID) == 0 {
|
if len(nsGroupID) == 0 {
|
||||||
http.Error(w, "invalid nameserver group ID", http.StatusBadRequest)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var req api.PatchApiDnsNameserversIdJSONRequestBody
|
var req api.PatchApiDnsNameserversIdJSONRequestBody
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
err = json.NewDecoder(r.Body).Decode(&req)
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
if err != nil {
|
||||||
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var operations []server.NameServerGroupUpdateOperation
|
var operations []server.NameServerGroupUpdateOperation
|
||||||
for _, patch := range req {
|
for _, patch := range req {
|
||||||
if patch.Op != api.NameserverGroupPatchOperationOpReplace {
|
if patch.Op != api.NameserverGroupPatchOperationOpReplace {
|
||||||
http.Error(w, fmt.Sprintf("nameserver groups only accepts replace operations, got %s", patch.Op),
|
util.WriteError(status.Errorf(status.InvalidArgument,
|
||||||
http.StatusBadRequest)
|
"nameserver groups only accepts replace operations, got %s", patch.Op), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
switch patch.Path {
|
switch patch.Path {
|
||||||
@ -196,49 +203,50 @@ func (h *Nameservers) PatchNameserverGroupHandler(w http.ResponseWriter, r *http
|
|||||||
Values: patch.Value,
|
Values: patch.Value,
|
||||||
})
|
})
|
||||||
default:
|
default:
|
||||||
http.Error(w, "invalid patch path", http.StatusBadRequest)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid patch path"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
updatedNSGroup, err := h.accountManager.UpdateNameServerGroup(account.Id, nsGroupID, operations)
|
updatedNSGroup, err := h.accountManager.UpdateNameServerGroup(account.Id, nsGroupID, operations)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
toHTTPError(err, w)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := toNameserverGroupResponse(updatedNSGroup)
|
resp := toNameserverGroupResponse(updatedNSGroup)
|
||||||
|
|
||||||
writeJSONObject(w, &resp)
|
util.WriteJSONObject(w, &resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteNameserverGroupHandler handles nameserver group deletion request
|
// DeleteNameserverGroupHandler handles nameserver group deletion request
|
||||||
func (h *Nameservers) DeleteNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
|
func (h *Nameservers) DeleteNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||||
|
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
util.WriteError(err, w)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
nsGroupID := mux.Vars(r)["id"]
|
nsGroupID := mux.Vars(r)["id"]
|
||||||
if len(nsGroupID) == 0 {
|
if len(nsGroupID) == 0 {
|
||||||
http.Error(w, "invalid nameserver group ID", http.StatusBadRequest)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.accountManager.DeleteNameServerGroup(account.Id, nsGroupID)
|
err = h.accountManager.DeleteNameServerGroup(account.Id, nsGroupID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
toHTTPError(err, w)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
writeJSONObject(w, "")
|
util.WriteJSONObject(w, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetNameserverGroupHandler handles a nameserver group Get request identified by ID
|
// GetNameserverGroupHandler handles a nameserver group Get request identified by ID
|
||||||
func (h *Nameservers) GetNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
|
func (h *Nameservers) GetNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||||
|
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
@ -247,19 +255,19 @@ func (h *Nameservers) GetNameserverGroupHandler(w http.ResponseWriter, r *http.R
|
|||||||
|
|
||||||
nsGroupID := mux.Vars(r)["id"]
|
nsGroupID := mux.Vars(r)["id"]
|
||||||
if len(nsGroupID) == 0 {
|
if len(nsGroupID) == 0 {
|
||||||
http.Error(w, "invalid nameserver group ID", http.StatusBadRequest)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
nsGroup, err := h.accountManager.GetNameServerGroup(account.Id, nsGroupID)
|
nsGroup, err := h.accountManager.GetNameServerGroup(account.Id, nsGroupID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
toHTTPError(err, w)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := toNameserverGroupResponse(nsGroup)
|
resp := toNameserverGroupResponse(nsGroup)
|
||||||
|
|
||||||
writeJSONObject(w, &resp)
|
util.WriteJSONObject(w, &resp)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5,9 +5,8 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@ -62,7 +61,7 @@ func initNameserversTestData() *Nameservers {
|
|||||||
if nsGroupID == existingNSGroupID {
|
if nsGroupID == existingNSGroupID {
|
||||||
return baseExistingNSGroup.Copy(), nil
|
return baseExistingNSGroup.Copy(), nil
|
||||||
}
|
}
|
||||||
return nil, status.Errorf(codes.NotFound, "nameserver group with ID %s not found", nsGroupID)
|
return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID)
|
||||||
},
|
},
|
||||||
CreateNameServerGroupFunc: func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error) {
|
CreateNameServerGroupFunc: func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error) {
|
||||||
return &nbdns.NameServerGroup{
|
return &nbdns.NameServerGroup{
|
||||||
@ -83,12 +82,12 @@ func initNameserversTestData() *Nameservers {
|
|||||||
if nsGroupToSave.ID == existingNSGroupID {
|
if nsGroupToSave.ID == existingNSGroupID {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return status.Errorf(codes.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID)
|
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID)
|
||||||
},
|
},
|
||||||
UpdateNameServerGroupFunc: func(accountID, nsGroupID string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) {
|
UpdateNameServerGroupFunc: func(accountID, nsGroupID string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) {
|
||||||
nsGroupToUpdate := baseExistingNSGroup.Copy()
|
nsGroupToUpdate := baseExistingNSGroup.Copy()
|
||||||
if nsGroupID != nsGroupToUpdate.ID {
|
if nsGroupID != nsGroupToUpdate.ID {
|
||||||
return nil, status.Errorf(codes.NotFound, "nameserver group ID %s no longer exists", nsGroupID)
|
return nil, status.Errorf(status.NotFound, "nameserver group ID %s no longer exists", nsGroupID)
|
||||||
}
|
}
|
||||||
for _, operation := range operations {
|
for _, operation := range operations {
|
||||||
switch operation.Type {
|
switch operation.Type {
|
||||||
@ -110,8 +109,8 @@ func initNameserversTestData() *Nameservers {
|
|||||||
}
|
}
|
||||||
return nsGroupToUpdate, nil
|
return nsGroupToUpdate, nil
|
||||||
},
|
},
|
||||||
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, error) {
|
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||||
return testingNSAccount, nil
|
return testingNSAccount, nil, nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
authAudience: "",
|
authAudience: "",
|
||||||
@ -181,7 +180,7 @@ func TestNameserversHandlers(t *testing.T) {
|
|||||||
requestPath: "/api/dns/nameservers",
|
requestPath: "/api/dns/nameservers",
|
||||||
requestBody: bytes.NewBuffer(
|
requestBody: bytes.NewBuffer(
|
||||||
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1000\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")),
|
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1000\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")),
|
||||||
expectedStatus: http.StatusBadRequest,
|
expectedStatus: http.StatusUnprocessableEntity,
|
||||||
expectedBody: false,
|
expectedBody: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -223,7 +222,7 @@ func TestNameserversHandlers(t *testing.T) {
|
|||||||
requestPath: "/api/dns/nameservers/" + notFoundNSGroupID,
|
requestPath: "/api/dns/nameservers/" + notFoundNSGroupID,
|
||||||
requestBody: bytes.NewBuffer(
|
requestBody: bytes.NewBuffer(
|
||||||
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"100\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")),
|
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"100\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true,\"primary\":true}")),
|
||||||
expectedStatus: http.StatusBadRequest,
|
expectedStatus: http.StatusUnprocessableEntity,
|
||||||
expectedBody: false,
|
expectedBody: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -6,8 +6,9 @@ import (
|
|||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
log "github.com/sirupsen/logrus"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -28,50 +29,47 @@ func NewPeers(accountManager server.AccountManager, authAudience string) *Peers
|
|||||||
|
|
||||||
func (h *Peers) updatePeer(account *server.Account, peer *server.Peer, w http.ResponseWriter, r *http.Request) {
|
func (h *Peers) updatePeer(account *server.Account, peer *server.Peer, w http.ResponseWriter, r *http.Request) {
|
||||||
req := &api.PutApiPeersIdJSONBody{}
|
req := &api.PutApiPeersIdJSONBody{}
|
||||||
peerIp := peer.IP
|
|
||||||
err := json.NewDecoder(r.Body).Decode(&req)
|
err := json.NewDecoder(r.Body).Decode(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
update := &server.Peer{Key: peer.Key, SSHEnabled: req.SshEnabled, Name: req.Name}
|
update := &server.Peer{Key: peer.Key, SSHEnabled: req.SshEnabled, Name: req.Name}
|
||||||
peer, err = h.accountManager.UpdatePeer(account.Id, update)
|
peer, err = h.accountManager.UpdatePeer(account.Id, update)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed updating peer %s under account %s %v", peerIp, account.Id, err)
|
util.WriteError(err, w)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
writeJSONObject(w, toPeerResponse(peer, account))
|
util.WriteJSONObject(w, toPeerResponse(peer, account))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Peers) deletePeer(accountId string, peer *server.Peer, w http.ResponseWriter, r *http.Request) {
|
func (h *Peers) deletePeer(accountId string, peer *server.Peer, w http.ResponseWriter, r *http.Request) {
|
||||||
_, err := h.accountManager.DeletePeer(accountId, peer.Key)
|
_, err := h.accountManager.DeletePeer(accountId, peer.Key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed deleteing peer %s, %v", peer.IP, err)
|
util.WriteError(err, w)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
writeJSONObject(w, "")
|
util.WriteJSONObject(w, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
||||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||||
|
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
util.WriteError(err, w)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
peerId := vars["id"] //effectively peer IP address
|
peerId := vars["id"] //effectively peer IP address
|
||||||
if len(peerId) == 0 {
|
if len(peerId) == 0 {
|
||||||
http.Error(w, "invalid peer Id", http.StatusBadRequest)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
peer, err := h.accountManager.GetPeerByIP(account.Id, peerId)
|
peer, err := h.accountManager.GetPeerByIP(account.Id, peerId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "peer not found", http.StatusNotFound)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -83,11 +81,11 @@ func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
|||||||
h.updatePeer(account, peer, w, r)
|
h.updatePeer(account, peer, w, r)
|
||||||
return
|
return
|
||||||
case http.MethodGet:
|
case http.MethodGet:
|
||||||
writeJSONObject(w, toPeerResponse(peer, account))
|
util.WriteJSONObject(w, toPeerResponse(peer, account))
|
||||||
return
|
return
|
||||||
|
|
||||||
default:
|
default:
|
||||||
http.Error(w, "", http.StatusNotFound)
|
util.WriteError(status.Errorf(status.NotFound, "unknown METHOD"), w)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@ -95,15 +93,16 @@ func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
|||||||
func (h *Peers) GetPeers(w http.ResponseWriter, r *http.Request) {
|
func (h *Peers) GetPeers(w http.ResponseWriter, r *http.Request) {
|
||||||
switch r.Method {
|
switch r.Method {
|
||||||
case http.MethodGet:
|
case http.MethodGet:
|
||||||
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||||
|
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
util.WriteError(err, w)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
peers, err := h.accountManager.GetPeers(account.Id, user.Id)
|
peers, err := h.accountManager.GetPeers(account.Id, user.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -111,10 +110,10 @@ func (h *Peers) GetPeers(w http.ResponseWriter, r *http.Request) {
|
|||||||
for _, peer := range peers {
|
for _, peer := range peers {
|
||||||
respBody = append(respBody, toPeerResponse(peer, account))
|
respBody = append(respBody, toPeerResponse(peer, account))
|
||||||
}
|
}
|
||||||
writeJSONObject(w, respBody)
|
util.WriteJSONObject(w, respBody)
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
http.Error(w, "", http.StatusNotFound)
|
util.WriteError(status.Errorf(status.NotFound, "unknown METHOD"), w)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -22,7 +22,8 @@ func initTestMetaData(peers ...*server.Peer) *Peers {
|
|||||||
GetPeersFunc: func(accountID, userID string) ([]*server.Peer, error) {
|
GetPeersFunc: func(accountID, userID string) ([]*server.Peer, error) {
|
||||||
return peers, nil
|
return peers, nil
|
||||||
},
|
},
|
||||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
|
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||||
|
user := server.NewAdminUser("test_user")
|
||||||
return &server.Account{
|
return &server.Account{
|
||||||
Id: claims.AccountId,
|
Id: claims.AccountId,
|
||||||
Domain: "hotmail.com",
|
Domain: "hotmail.com",
|
||||||
@ -30,9 +31,9 @@ func initTestMetaData(peers ...*server.Peer) *Peers {
|
|||||||
"test_peer": peers[0],
|
"test_peer": peers[0],
|
||||||
},
|
},
|
||||||
Users: map[string]*server.User{
|
Users: map[string]*server.User{
|
||||||
"test_user": server.NewAdminUser("test_user"),
|
"test_user": user,
|
||||||
},
|
},
|
||||||
}, nil
|
}, user, nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
authAudience: "",
|
authAudience: "",
|
||||||
|
@ -2,15 +2,13 @@ package http
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
)
|
)
|
||||||
@ -33,25 +31,16 @@ func NewRoutes(accountManager server.AccountManager, authAudience string) *Route
|
|||||||
|
|
||||||
// GetAllRoutesHandler returns the list of routes for the account
|
// GetAllRoutesHandler returns the list of routes for the account
|
||||||
func (h *Routes) GetAllRoutesHandler(w http.ResponseWriter, r *http.Request) {
|
func (h *Routes) GetAllRoutesHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||||
|
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
util.WriteError(err, w)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
routes, err := h.accountManager.ListRoutes(account.Id, user.Id)
|
routes, err := h.accountManager.ListRoutes(account.Id, user.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
util.WriteError(err, w)
|
||||||
if e, ok := server.FromError(err); ok {
|
|
||||||
switch e.Type() {
|
|
||||||
case server.PermissionDenied:
|
|
||||||
http.Error(w, e.Error(), http.StatusForbidden)
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
apiRoutes := make([]*api.Route, 0)
|
apiRoutes := make([]*api.Route, 0)
|
||||||
@ -59,20 +48,22 @@ func (h *Routes) GetAllRoutesHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
apiRoutes = append(apiRoutes, toRouteResponse(account, r))
|
apiRoutes = append(apiRoutes, toRouteResponse(account, r))
|
||||||
}
|
}
|
||||||
|
|
||||||
writeJSONObject(w, apiRoutes)
|
util.WriteJSONObject(w, apiRoutes)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateRouteHandler handles route creation request
|
// CreateRouteHandler handles route creation request
|
||||||
func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) {
|
func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||||
|
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var req api.PostApiRoutesJSONRequestBody
|
var req api.PostApiRoutesJSONRequestBody
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
err = json.NewDecoder(r.Body).Decode(&req)
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
if err != nil {
|
||||||
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -80,8 +71,7 @@ func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
if req.Peer != "" {
|
if req.Peer != "" {
|
||||||
peer, err := h.accountManager.GetPeerByIP(account.Id, req.Peer)
|
peer, err := h.accountManager.GetPeerByIP(account.Id, req.Peer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
util.WriteError(err, w)
|
||||||
http.Redirect(w, r, "/", http.StatusUnprocessableEntity)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
peerKey = peer.Key
|
peerKey = peer.Key
|
||||||
@ -89,57 +79,60 @@ func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
_, newPrefix, err := route.ParseNetwork(req.Network)
|
_, newPrefix, err := route.ParseNetwork(req.Network)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, fmt.Sprintf("couldn't parse update prefix %s", req.Network), http.StatusBadRequest)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if utf8.RuneCountInString(req.NetworkId) > route.MaxNetIDChar || req.NetworkId == "" {
|
if utf8.RuneCountInString(req.NetworkId) > route.MaxNetIDChar || req.NetworkId == "" {
|
||||||
http.Error(w, fmt.Sprintf("identifier should be between 1 and %d", route.MaxNetIDChar), http.StatusBadRequest)
|
util.WriteError(status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d",
|
||||||
|
route.MaxNetIDChar), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
newRoute, err := h.accountManager.CreateRoute(account.Id, newPrefix.String(), peerKey, req.Description, req.NetworkId, req.Masquerade, req.Metric, req.Enabled)
|
newRoute, err := h.accountManager.CreateRoute(account.Id, newPrefix.String(), peerKey, req.Description, req.NetworkId, req.Masquerade, req.Metric, req.Enabled)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
util.WriteError(err, w)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := toRouteResponse(account, newRoute)
|
resp := toRouteResponse(account, newRoute)
|
||||||
|
|
||||||
writeJSONObject(w, &resp)
|
util.WriteJSONObject(w, &resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateRouteHandler handles update to a route identified by a given ID
|
// UpdateRouteHandler handles update to a route identified by a given ID
|
||||||
func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
|
func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||||
|
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
routeID := vars["id"]
|
routeID := vars["id"]
|
||||||
if len(routeID) == 0 {
|
if len(routeID) == 0 {
|
||||||
http.Error(w, "invalid route Id", http.StatusBadRequest)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = h.accountManager.GetRoute(account.Id, routeID, user.Id)
|
_, err = h.accountManager.GetRoute(account.Id, routeID, user.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, fmt.Sprintf("couldn't find route for ID %s", routeID), http.StatusNotFound)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var req api.PutApiRoutesIdJSONRequestBody
|
var req api.PutApiRoutesIdJSONRequestBody
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
err = json.NewDecoder(r.Body).Decode(&req)
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
if err != nil {
|
||||||
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
prefixType, newPrefix, err := route.ParseNetwork(req.Network)
|
prefixType, newPrefix, err := route.ParseNetwork(req.Network)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, fmt.Sprintf("couldn't parse update prefix %s for route ID %s", req.Network, routeID), http.StatusBadRequest)
|
util.WriteError(status.Errorf(status.InvalidArgument, "couldn't parse update prefix %s for route ID %s",
|
||||||
|
req.Network, routeID), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -147,15 +140,15 @@ func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
if req.Peer != "" {
|
if req.Peer != "" {
|
||||||
peer, err := h.accountManager.GetPeerByIP(account.Id, req.Peer)
|
peer, err := h.accountManager.GetPeerByIP(account.Id, req.Peer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
util.WriteError(err, w)
|
||||||
http.Redirect(w, r, "/", http.StatusUnprocessableEntity)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
peerKey = peer.Key
|
peerKey = peer.Key
|
||||||
}
|
}
|
||||||
|
|
||||||
if utf8.RuneCountInString(req.NetworkId) > route.MaxNetIDChar || req.NetworkId == "" {
|
if utf8.RuneCountInString(req.NetworkId) > route.MaxNetIDChar || req.NetworkId == "" {
|
||||||
http.Error(w, fmt.Sprintf("identifier should be between 1 and %d", route.MaxNetIDChar), http.StatusBadRequest)
|
util.WriteError(status.Errorf(status.InvalidArgument,
|
||||||
|
"identifier should be between 1 and %d", route.MaxNetIDChar), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -173,46 +166,46 @@ func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
err = h.accountManager.SaveRoute(account.Id, newRoute)
|
err = h.accountManager.SaveRoute(account.Id, newRoute)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed updating route \"%s\" under account %s %v", routeID, account.Id, err)
|
util.WriteError(err, w)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := toRouteResponse(account, newRoute)
|
resp := toRouteResponse(account, newRoute)
|
||||||
|
|
||||||
writeJSONObject(w, &resp)
|
util.WriteJSONObject(w, &resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// PatchRouteHandler handles patch updates to a route identified by a given ID
|
// PatchRouteHandler handles patch updates to a route identified by a given ID
|
||||||
func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
|
func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||||
|
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
routeID := vars["id"]
|
routeID := vars["id"]
|
||||||
if len(routeID) == 0 {
|
if len(routeID) == 0 {
|
||||||
http.Error(w, "invalid route ID", http.StatusBadRequest)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = h.accountManager.GetRoute(account.Id, routeID, user.Id)
|
_, err = h.accountManager.GetRoute(account.Id, routeID, user.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
util.WriteError(err, w)
|
||||||
http.Error(w, fmt.Sprintf("couldn't find route ID %s", routeID), http.StatusNotFound)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var req api.PatchApiRoutesIdJSONRequestBody
|
var req api.PatchApiRoutesIdJSONRequestBody
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
err = json.NewDecoder(r.Body).Decode(&req)
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
if err != nil {
|
||||||
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(req) == 0 {
|
if len(req) == 0 {
|
||||||
http.Error(w, "no patch instruction received", http.StatusBadRequest)
|
util.WriteError(status.Errorf(status.InvalidArgument, "no patch instruction received"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -222,8 +215,8 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
switch patch.Path {
|
switch patch.Path {
|
||||||
case api.RoutePatchOperationPathNetwork:
|
case api.RoutePatchOperationPathNetwork:
|
||||||
if patch.Op != api.RoutePatchOperationOpReplace {
|
if patch.Op != api.RoutePatchOperationOpReplace {
|
||||||
http.Error(w, fmt.Sprintf("Network field only accepts replace operation, got %s", patch.Op),
|
util.WriteError(status.Errorf(status.InvalidArgument,
|
||||||
http.StatusBadRequest)
|
"network field only accepts replace operation, got %s", patch.Op), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
operations = append(operations, server.RouteUpdateOperation{
|
operations = append(operations, server.RouteUpdateOperation{
|
||||||
@ -232,8 +225,8 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
})
|
})
|
||||||
case api.RoutePatchOperationPathDescription:
|
case api.RoutePatchOperationPathDescription:
|
||||||
if patch.Op != api.RoutePatchOperationOpReplace {
|
if patch.Op != api.RoutePatchOperationOpReplace {
|
||||||
http.Error(w, fmt.Sprintf("Description field only accepts replace operation, got %s", patch.Op),
|
util.WriteError(status.Errorf(status.InvalidArgument,
|
||||||
http.StatusBadRequest)
|
"description field only accepts replace operation, got %s", patch.Op), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
operations = append(operations, server.RouteUpdateOperation{
|
operations = append(operations, server.RouteUpdateOperation{
|
||||||
@ -242,8 +235,8 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
})
|
})
|
||||||
case api.RoutePatchOperationPathNetworkId:
|
case api.RoutePatchOperationPathNetworkId:
|
||||||
if patch.Op != api.RoutePatchOperationOpReplace {
|
if patch.Op != api.RoutePatchOperationOpReplace {
|
||||||
http.Error(w, fmt.Sprintf("Network Identifier field only accepts replace operation, got %s", patch.Op),
|
util.WriteError(status.Errorf(status.InvalidArgument,
|
||||||
http.StatusBadRequest)
|
"network Identifier field only accepts replace operation, got %s", patch.Op), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
operations = append(operations, server.RouteUpdateOperation{
|
operations = append(operations, server.RouteUpdateOperation{
|
||||||
@ -252,21 +245,20 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
})
|
})
|
||||||
case api.RoutePatchOperationPathPeer:
|
case api.RoutePatchOperationPathPeer:
|
||||||
if patch.Op != api.RoutePatchOperationOpReplace {
|
if patch.Op != api.RoutePatchOperationOpReplace {
|
||||||
http.Error(w, fmt.Sprintf("Peer field only accepts replace operation, got %s", patch.Op),
|
util.WriteError(status.Errorf(status.InvalidArgument,
|
||||||
http.StatusBadRequest)
|
"peer field only accepts replace operation, got %s", patch.Op), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(patch.Value) > 1 {
|
if len(patch.Value) > 1 {
|
||||||
http.Error(w, fmt.Sprintf("Value field only accepts 1 value, got %d", len(patch.Value)),
|
util.WriteError(status.Errorf(status.InvalidArgument,
|
||||||
http.StatusBadRequest)
|
"value field only accepts 1 value, got %d", len(patch.Value)), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
peerValue := patch.Value
|
peerValue := patch.Value
|
||||||
if patch.Value[0] != "" {
|
if patch.Value[0] != "" {
|
||||||
peer, err := h.accountManager.GetPeerByIP(account.Id, patch.Value[0])
|
peer, err := h.accountManager.GetPeerByIP(account.Id, patch.Value[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
util.WriteError(err, w)
|
||||||
http.Redirect(w, r, "/", http.StatusUnprocessableEntity)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
peerValue = []string{peer.Key}
|
peerValue = []string{peer.Key}
|
||||||
@ -277,8 +269,9 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
})
|
})
|
||||||
case api.RoutePatchOperationPathMetric:
|
case api.RoutePatchOperationPathMetric:
|
||||||
if patch.Op != api.RoutePatchOperationOpReplace {
|
if patch.Op != api.RoutePatchOperationOpReplace {
|
||||||
http.Error(w, fmt.Sprintf("Metric field only accepts replace operation, got %s", patch.Op),
|
util.WriteError(status.Errorf(status.InvalidArgument,
|
||||||
http.StatusBadRequest)
|
"metric field only accepts replace operation, got %s", patch.Op), w)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
operations = append(operations, server.RouteUpdateOperation{
|
operations = append(operations, server.RouteUpdateOperation{
|
||||||
@ -287,8 +280,8 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
})
|
})
|
||||||
case api.RoutePatchOperationPathMasquerade:
|
case api.RoutePatchOperationPathMasquerade:
|
||||||
if patch.Op != api.RoutePatchOperationOpReplace {
|
if patch.Op != api.RoutePatchOperationOpReplace {
|
||||||
http.Error(w, fmt.Sprintf("Masquerade field only accepts replace operation, got %s", patch.Op),
|
util.WriteError(status.Errorf(status.InvalidArgument,
|
||||||
http.StatusBadRequest)
|
"masquerade field only accepts replace operation, got %s", patch.Op), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
operations = append(operations, server.RouteUpdateOperation{
|
operations = append(operations, server.RouteUpdateOperation{
|
||||||
@ -297,8 +290,8 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
})
|
})
|
||||||
case api.RoutePatchOperationPathEnabled:
|
case api.RoutePatchOperationPathEnabled:
|
||||||
if patch.Op != api.RoutePatchOperationOpReplace {
|
if patch.Op != api.RoutePatchOperationOpReplace {
|
||||||
http.Error(w, fmt.Sprintf("Enabled field only accepts replace operation, got %s", patch.Op),
|
util.WriteError(status.Errorf(status.InvalidArgument,
|
||||||
http.StatusBadRequest)
|
"enabled field only accepts replace operation, got %s", patch.Op), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
operations = append(operations, server.RouteUpdateOperation{
|
operations = append(operations, server.RouteUpdateOperation{
|
||||||
@ -306,90 +299,68 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
Values: patch.Value,
|
Values: patch.Value,
|
||||||
})
|
})
|
||||||
default:
|
default:
|
||||||
http.Error(w, "invalid patch path", http.StatusBadRequest)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid patch path"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
route, err := h.accountManager.UpdateRoute(account.Id, routeID, operations)
|
route, err := h.accountManager.UpdateRoute(account.Id, routeID, operations)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errStatus, ok := status.FromError(err)
|
util.WriteError(err, w)
|
||||||
if ok && errStatus.Code() == codes.Internal {
|
|
||||||
http.Error(w, errStatus.String(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if ok && errStatus.Code() == codes.NotFound {
|
|
||||||
http.Error(w, errStatus.String(), http.StatusNotFound)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if ok && errStatus.Code() == codes.InvalidArgument {
|
|
||||||
http.Error(w, errStatus.String(), http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Errorf("failed updating route %s under account %s %v", routeID, account.Id, err)
|
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := toRouteResponse(account, route)
|
resp := toRouteResponse(account, route)
|
||||||
|
|
||||||
writeJSONObject(w, &resp)
|
util.WriteJSONObject(w, &resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteRouteHandler handles route deletion request
|
// DeleteRouteHandler handles route deletion request
|
||||||
func (h *Routes) DeleteRouteHandler(w http.ResponseWriter, r *http.Request) {
|
func (h *Routes) DeleteRouteHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||||
|
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
routeID := mux.Vars(r)["id"]
|
routeID := mux.Vars(r)["id"]
|
||||||
if len(routeID) == 0 {
|
if len(routeID) == 0 {
|
||||||
http.Error(w, "invalid route ID", http.StatusBadRequest)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.accountManager.DeleteRoute(account.Id, routeID)
|
err = h.accountManager.DeleteRoute(account.Id, routeID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errStatus, ok := status.FromError(err)
|
util.WriteError(err, w)
|
||||||
if ok && errStatus.Code() == codes.NotFound {
|
|
||||||
http.Error(w, fmt.Sprintf("route %s not found under account %s", routeID, account.Id), http.StatusNotFound)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Errorf("failed delete route %s under account %s %v", routeID, account.Id, err)
|
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
writeJSONObject(w, "")
|
util.WriteJSONObject(w, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRouteHandler handles a route Get request identified by ID
|
// GetRouteHandler handles a route Get request identified by ID
|
||||||
func (h *Routes) GetRouteHandler(w http.ResponseWriter, r *http.Request) {
|
func (h *Routes) GetRouteHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||||
|
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
routeID := mux.Vars(r)["id"]
|
routeID := mux.Vars(r)["id"]
|
||||||
if len(routeID) == 0 {
|
if len(routeID) == 0 {
|
||||||
http.Error(w, "invalid route ID", http.StatusBadRequest)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
foundRoute, err := h.accountManager.GetRoute(account.Id, routeID, user.Id)
|
foundRoute, err := h.accountManager.GetRoute(account.Id, routeID, user.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "route not found", http.StatusNotFound)
|
util.WriteError(status.Errorf(status.NotFound, "route not found"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
writeJSONObject(w, toRouteResponse(account, foundRoute))
|
util.WriteJSONObject(w, toRouteResponse(account, foundRoute))
|
||||||
}
|
}
|
||||||
|
|
||||||
func toRouteResponse(account *server.Account, serverRoute *route.Route) *api.Route {
|
func toRouteResponse(account *server.Account, serverRoute *route.Route) *api.Route {
|
||||||
|
@ -5,9 +5,8 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@ -63,7 +62,7 @@ func initRoutesTestData() *Routes {
|
|||||||
if routeID == existingRouteID {
|
if routeID == existingRouteID {
|
||||||
return baseExistingRoute, nil
|
return baseExistingRoute, nil
|
||||||
}
|
}
|
||||||
return nil, status.Errorf(codes.NotFound, "route with ID %s not found", routeID)
|
return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
|
||||||
},
|
},
|
||||||
CreateRouteFunc: func(accountID string, network, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error) {
|
CreateRouteFunc: func(accountID string, network, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error) {
|
||||||
networkType, p, _ := route.ParseNetwork(network)
|
networkType, p, _ := route.ParseNetwork(network)
|
||||||
@ -83,13 +82,13 @@ func initRoutesTestData() *Routes {
|
|||||||
},
|
},
|
||||||
DeleteRouteFunc: func(_ string, peerIP string) error {
|
DeleteRouteFunc: func(_ string, peerIP string) error {
|
||||||
if peerIP != existingRouteID {
|
if peerIP != existingRouteID {
|
||||||
return status.Errorf(codes.NotFound, "Peer with ID %s not found", peerIP)
|
return status.Errorf(status.NotFound, "Peer with ID %s not found", peerIP)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
GetPeerByIPFunc: func(_ string, peerIP string) (*server.Peer, error) {
|
GetPeerByIPFunc: func(_ string, peerIP string) (*server.Peer, error) {
|
||||||
if peerIP != existingPeerID {
|
if peerIP != existingPeerID {
|
||||||
return nil, status.Errorf(codes.NotFound, "Peer with ID %s not found", peerIP)
|
return nil, status.Errorf(status.NotFound, "Peer with ID %s not found", peerIP)
|
||||||
}
|
}
|
||||||
return &server.Peer{
|
return &server.Peer{
|
||||||
Key: existingPeerKey,
|
Key: existingPeerKey,
|
||||||
@ -99,7 +98,7 @@ func initRoutesTestData() *Routes {
|
|||||||
UpdateRouteFunc: func(_ string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error) {
|
UpdateRouteFunc: func(_ string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error) {
|
||||||
routeToUpdate := baseExistingRoute
|
routeToUpdate := baseExistingRoute
|
||||||
if routeID != routeToUpdate.ID {
|
if routeID != routeToUpdate.ID {
|
||||||
return nil, status.Errorf(codes.NotFound, "route %s no longer exists", routeID)
|
return nil, status.Errorf(status.NotFound, "route %s no longer exists", routeID)
|
||||||
}
|
}
|
||||||
for _, operation := range operations {
|
for _, operation := range operations {
|
||||||
switch operation.Type {
|
switch operation.Type {
|
||||||
@ -123,8 +122,8 @@ func initRoutesTestData() *Routes {
|
|||||||
}
|
}
|
||||||
return routeToUpdate, nil
|
return routeToUpdate, nil
|
||||||
},
|
},
|
||||||
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, error) {
|
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||||
return testingAccount, nil
|
return testingAccount, testingAccount.Users["test_user"], nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
authAudience: "",
|
authAudience: "",
|
||||||
@ -201,15 +200,15 @@ func TestRoutesHandlers(t *testing.T) {
|
|||||||
requestType: http.MethodPost,
|
requestType: http.MethodPost,
|
||||||
requestPath: "/api/routes",
|
requestPath: "/api/routes",
|
||||||
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\"}", notFoundPeerID)),
|
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\"}", notFoundPeerID)),
|
||||||
expectedStatus: http.StatusUnprocessableEntity,
|
expectedStatus: http.StatusNotFound,
|
||||||
expectedBody: false,
|
expectedBody: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "POST Not Invalid Network Identifier",
|
name: "POST Invalid Network Identifier",
|
||||||
requestType: http.MethodPost,
|
requestType: http.MethodPost,
|
||||||
requestPath: "/api/routes",
|
requestPath: "/api/routes",
|
||||||
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"12345678901234567890qwertyuiopqwertyuiop1\",\"Peer\":\"%s\"}", existingPeerID)),
|
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"12345678901234567890qwertyuiopqwertyuiop1\",\"Peer\":\"%s\"}", existingPeerID)),
|
||||||
expectedStatus: http.StatusBadRequest,
|
expectedStatus: http.StatusUnprocessableEntity,
|
||||||
expectedBody: false,
|
expectedBody: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -217,7 +216,7 @@ func TestRoutesHandlers(t *testing.T) {
|
|||||||
requestType: http.MethodPost,
|
requestType: http.MethodPost,
|
||||||
requestPath: "/api/routes",
|
requestPath: "/api/routes",
|
||||||
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/34\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\"}", existingPeerID)),
|
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/34\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\"}", existingPeerID)),
|
||||||
expectedStatus: http.StatusBadRequest,
|
expectedStatus: http.StatusUnprocessableEntity,
|
||||||
expectedBody: false,
|
expectedBody: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -251,7 +250,7 @@ func TestRoutesHandlers(t *testing.T) {
|
|||||||
requestType: http.MethodPut,
|
requestType: http.MethodPut,
|
||||||
requestPath: "/api/routes/" + existingRouteID,
|
requestPath: "/api/routes/" + existingRouteID,
|
||||||
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\"}", notFoundPeerID)),
|
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\"}", notFoundPeerID)),
|
||||||
expectedStatus: http.StatusUnprocessableEntity,
|
expectedStatus: http.StatusNotFound,
|
||||||
expectedBody: false,
|
expectedBody: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -259,7 +258,7 @@ func TestRoutesHandlers(t *testing.T) {
|
|||||||
requestType: http.MethodPut,
|
requestType: http.MethodPut,
|
||||||
requestPath: "/api/routes/" + existingRouteID,
|
requestPath: "/api/routes/" + existingRouteID,
|
||||||
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"12345678901234567890qwertyuiopqwertyuiop1\",\"Peer\":\"%s\"}", existingPeerID)),
|
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"12345678901234567890qwertyuiopqwertyuiop1\",\"Peer\":\"%s\"}", existingPeerID)),
|
||||||
expectedStatus: http.StatusBadRequest,
|
expectedStatus: http.StatusUnprocessableEntity,
|
||||||
expectedBody: false,
|
expectedBody: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -267,7 +266,7 @@ func TestRoutesHandlers(t *testing.T) {
|
|||||||
requestType: http.MethodPut,
|
requestType: http.MethodPut,
|
||||||
requestPath: "/api/routes/" + existingRouteID,
|
requestPath: "/api/routes/" + existingRouteID,
|
||||||
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/34\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\"}", existingPeerID)),
|
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/34\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\"}", existingPeerID)),
|
||||||
expectedStatus: http.StatusBadRequest,
|
expectedStatus: http.StatusUnprocessableEntity,
|
||||||
expectedBody: false,
|
expectedBody: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -312,7 +311,7 @@ func TestRoutesHandlers(t *testing.T) {
|
|||||||
requestType: http.MethodPatch,
|
requestType: http.MethodPatch,
|
||||||
requestPath: "/api/routes/" + existingRouteID,
|
requestPath: "/api/routes/" + existingRouteID,
|
||||||
requestBody: bytes.NewBufferString(fmt.Sprintf("[{\"op\":\"replace\",\"path\":\"peer\",\"value\":[\"%s\"]}]", notFoundPeerID)),
|
requestBody: bytes.NewBufferString(fmt.Sprintf("[{\"op\":\"replace\",\"path\":\"peer\",\"value\":[\"%s\"]}]", notFoundPeerID)),
|
||||||
expectedStatus: http.StatusUnprocessableEntity,
|
expectedStatus: http.StatusNotFound,
|
||||||
expectedBody: false,
|
expectedBody: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -2,15 +2,13 @@ package http
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -31,25 +29,16 @@ func NewRules(accountManager server.AccountManager, authAudience string) *Rules
|
|||||||
|
|
||||||
// GetAllRulesHandler list for the account
|
// GetAllRulesHandler list for the account
|
||||||
func (h *Rules) GetAllRulesHandler(w http.ResponseWriter, r *http.Request) {
|
func (h *Rules) GetAllRulesHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||||
|
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
util.WriteError(err, w)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
accountRules, err := h.accountManager.ListRules(account.Id, user.Id)
|
accountRules, err := h.accountManager.ListRules(account.Id, user.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
util.WriteError(err, w)
|
||||||
if e, ok := server.FromError(err); ok {
|
|
||||||
switch e.Type() {
|
|
||||||
case server.PermissionDenied:
|
|
||||||
http.Error(w, e.Error(), http.StatusForbidden)
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
rules := []*api.Rule{}
|
rules := []*api.Rule{}
|
||||||
@ -57,38 +46,39 @@ func (h *Rules) GetAllRulesHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
rules = append(rules, toRuleResponse(account, r))
|
rules = append(rules, toRuleResponse(account, r))
|
||||||
}
|
}
|
||||||
|
|
||||||
writeJSONObject(w, rules)
|
util.WriteJSONObject(w, rules)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateRuleHandler handles update to a rule identified by a given ID
|
// UpdateRuleHandler handles update to a rule identified by a given ID
|
||||||
func (h *Rules) UpdateRuleHandler(w http.ResponseWriter, r *http.Request) {
|
func (h *Rules) UpdateRuleHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||||
|
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
ruleID := vars["id"]
|
ruleID := vars["id"]
|
||||||
if len(ruleID) == 0 {
|
if len(ruleID) == 0 {
|
||||||
http.Error(w, "invalid rule Id", http.StatusBadRequest)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid rule ID"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_, ok := account.Rules[ruleID]
|
_, ok := account.Rules[ruleID]
|
||||||
if !ok {
|
if !ok {
|
||||||
http.Error(w, fmt.Sprintf("couldn't find rule id %s", ruleID), http.StatusNotFound)
|
util.WriteError(status.Errorf(status.NotFound, "couldn't find rule id %s", ruleID), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var req api.PutApiRulesIdJSONRequestBody
|
var req api.PutApiRulesIdJSONRequestBody
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
err = json.NewDecoder(r.Body).Decode(&req)
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
if err != nil {
|
||||||
return
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Name == "" {
|
if req.Name == "" {
|
||||||
http.Error(w, "Rule name shouldn't be empty", http.StatusUnprocessableEntity)
|
util.WriteError(status.Errorf(status.InvalidArgument, "rule name shouldn't be empty"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -115,50 +105,52 @@ func (h *Rules) UpdateRuleHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
case server.TrafficFlowBidirectString:
|
case server.TrafficFlowBidirectString:
|
||||||
rule.Flow = server.TrafficFlowBidirect
|
rule.Flow = server.TrafficFlowBidirect
|
||||||
default:
|
default:
|
||||||
http.Error(w, "unknown flow type", http.StatusBadRequest)
|
util.WriteError(status.Errorf(status.InvalidArgument, "unknown flow type"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.accountManager.SaveRule(account.Id, &rule); err != nil {
|
err = h.accountManager.SaveRule(account.Id, &rule)
|
||||||
log.Errorf("failed updating rule \"%s\" under account %s %v", ruleID, account.Id, err)
|
if err != nil {
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := toRuleResponse(account, &rule)
|
resp := toRuleResponse(account, &rule)
|
||||||
|
|
||||||
writeJSONObject(w, &resp)
|
util.WriteJSONObject(w, &resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// PatchRuleHandler handles patch updates to a rule identified by a given ID
|
// PatchRuleHandler handles patch updates to a rule identified by a given ID
|
||||||
func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
|
func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||||
|
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
ruleID := vars["id"]
|
ruleID := vars["id"]
|
||||||
if len(ruleID) == 0 {
|
if len(ruleID) == 0 {
|
||||||
http.Error(w, "invalid rule Id", http.StatusBadRequest)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid rule ID"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_, ok := account.Rules[ruleID]
|
_, ok := account.Rules[ruleID]
|
||||||
if !ok {
|
if !ok {
|
||||||
http.Error(w, fmt.Sprintf("couldn't find rule id %s", ruleID), http.StatusNotFound)
|
util.WriteError(status.Errorf(status.NotFound, "couldn't find rule ID %s", ruleID), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var req api.PatchApiRulesIdJSONRequestBody
|
var req api.PatchApiRulesIdJSONRequestBody
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
err = json.NewDecoder(r.Body).Decode(&req)
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
if err != nil {
|
||||||
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(req) == 0 {
|
if len(req) == 0 {
|
||||||
http.Error(w, "no patch instruction received", http.StatusBadRequest)
|
util.WriteError(status.Errorf(status.InvalidArgument, "no patch instruction received"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -168,12 +160,12 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
switch patch.Path {
|
switch patch.Path {
|
||||||
case api.RulePatchOperationPathName:
|
case api.RulePatchOperationPathName:
|
||||||
if patch.Op != api.RulePatchOperationOpReplace {
|
if patch.Op != api.RulePatchOperationOpReplace {
|
||||||
http.Error(w, fmt.Sprintf("Name field only accepts replace operation, got %s", patch.Op),
|
util.WriteError(status.Errorf(status.InvalidArgument,
|
||||||
http.StatusBadRequest)
|
"name field only accepts replace operation, got %s", patch.Op), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(patch.Value) == 0 || patch.Value[0] == "" {
|
if len(patch.Value) == 0 || patch.Value[0] == "" {
|
||||||
http.Error(w, "Rule name shouldn't be empty", http.StatusUnprocessableEntity)
|
util.WriteError(status.Errorf(status.InvalidArgument, "rule name shouldn't be empty"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
operations = append(operations, server.RuleUpdateOperation{
|
operations = append(operations, server.RuleUpdateOperation{
|
||||||
@ -182,8 +174,8 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
})
|
})
|
||||||
case api.RulePatchOperationPathDescription:
|
case api.RulePatchOperationPathDescription:
|
||||||
if patch.Op != api.RulePatchOperationOpReplace {
|
if patch.Op != api.RulePatchOperationOpReplace {
|
||||||
http.Error(w, fmt.Sprintf("Description field only accepts replace operation, got %s", patch.Op),
|
util.WriteError(status.Errorf(status.InvalidArgument,
|
||||||
http.StatusBadRequest)
|
"description field only accepts replace operation, got %s", patch.Op), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
operations = append(operations, server.RuleUpdateOperation{
|
operations = append(operations, server.RuleUpdateOperation{
|
||||||
@ -192,8 +184,8 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
})
|
})
|
||||||
case api.RulePatchOperationPathFlow:
|
case api.RulePatchOperationPathFlow:
|
||||||
if patch.Op != api.RulePatchOperationOpReplace {
|
if patch.Op != api.RulePatchOperationOpReplace {
|
||||||
http.Error(w, fmt.Sprintf("Flow field only accepts replace operation, got %s", patch.Op),
|
util.WriteError(status.Errorf(status.InvalidArgument,
|
||||||
http.StatusBadRequest)
|
"flow field only accepts replace operation, got %s", patch.Op), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
operations = append(operations, server.RuleUpdateOperation{
|
operations = append(operations, server.RuleUpdateOperation{
|
||||||
@ -202,8 +194,8 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
})
|
})
|
||||||
case api.RulePatchOperationPathDisabled:
|
case api.RulePatchOperationPathDisabled:
|
||||||
if patch.Op != api.RulePatchOperationOpReplace {
|
if patch.Op != api.RulePatchOperationOpReplace {
|
||||||
http.Error(w, fmt.Sprintf("Disabled field only accepts replace operation, got %s", patch.Op),
|
util.WriteError(status.Errorf(status.InvalidArgument,
|
||||||
http.StatusBadRequest)
|
"disabled field only accepts replace operation, got %s", patch.Op), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
operations = append(operations, server.RuleUpdateOperation{
|
operations = append(operations, server.RuleUpdateOperation{
|
||||||
@ -228,7 +220,8 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
Values: patch.Value,
|
Values: patch.Value,
|
||||||
})
|
})
|
||||||
default:
|
default:
|
||||||
http.Error(w, "invalid operation, \"%s\", for Source field", http.StatusBadRequest)
|
util.WriteError(status.Errorf(status.InvalidArgument,
|
||||||
|
"invalid operation \"%s\" on Source field", patch.Op), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
case api.RulePatchOperationPathDestinations:
|
case api.RulePatchOperationPathDestinations:
|
||||||
@ -249,11 +242,12 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
Values: patch.Value,
|
Values: patch.Value,
|
||||||
})
|
})
|
||||||
default:
|
default:
|
||||||
http.Error(w, "invalid operation, \"%s\", for Destination field", http.StatusBadRequest)
|
util.WriteError(status.Errorf(status.InvalidArgument,
|
||||||
|
"invalid operation \"%s\" on Destination field", patch.Op), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
http.Error(w, "invalid patch path", http.StatusBadRequest)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid patch path"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -261,48 +255,33 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
rule, err := h.accountManager.UpdateRule(account.Id, ruleID, operations)
|
rule, err := h.accountManager.UpdateRule(account.Id, ruleID, operations)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errStatus, ok := status.FromError(err)
|
util.WriteError(err, w)
|
||||||
if ok && errStatus.Code() == codes.Internal {
|
|
||||||
http.Error(w, errStatus.String(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if ok && errStatus.Code() == codes.NotFound {
|
|
||||||
http.Error(w, errStatus.String(), http.StatusNotFound)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if ok && errStatus.Code() == codes.InvalidArgument {
|
|
||||||
http.Error(w, errStatus.String(), http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Errorf("failed updating rule %s under account %s %v", ruleID, account.Id, err)
|
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := toRuleResponse(account, rule)
|
resp := toRuleResponse(account, rule)
|
||||||
|
|
||||||
writeJSONObject(w, &resp)
|
util.WriteJSONObject(w, &resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateRuleHandler handles rule creation request
|
// CreateRuleHandler handles rule creation request
|
||||||
func (h *Rules) CreateRuleHandler(w http.ResponseWriter, r *http.Request) {
|
func (h *Rules) CreateRuleHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||||
|
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var req api.PostApiRulesJSONRequestBody
|
var req api.PostApiRulesJSONRequestBody
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
err = json.NewDecoder(r.Body).Decode(&req)
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
if err != nil {
|
||||||
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Name == "" {
|
if req.Name == "" {
|
||||||
http.Error(w, "Rule name shouldn't be empty", http.StatusUnprocessableEntity)
|
util.WriteError(status.Errorf(status.InvalidArgument, "rule name shouldn't be empty"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -329,50 +308,52 @@ func (h *Rules) CreateRuleHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
case server.TrafficFlowBidirectString:
|
case server.TrafficFlowBidirectString:
|
||||||
rule.Flow = server.TrafficFlowBidirect
|
rule.Flow = server.TrafficFlowBidirect
|
||||||
default:
|
default:
|
||||||
http.Error(w, "unknown flow type", http.StatusBadRequest)
|
util.WriteError(status.Errorf(status.InvalidArgument, "unknown flow type"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.accountManager.SaveRule(account.Id, &rule); err != nil {
|
err = h.accountManager.SaveRule(account.Id, &rule)
|
||||||
log.Errorf("failed creating rule \"%s\" under account %s %v", req.Name, account.Id, err)
|
if err != nil {
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := toRuleResponse(account, &rule)
|
resp := toRuleResponse(account, &rule)
|
||||||
|
|
||||||
writeJSONObject(w, &resp)
|
util.WriteJSONObject(w, &resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteRuleHandler handles rule deletion request
|
// DeleteRuleHandler handles rule deletion request
|
||||||
func (h *Rules) DeleteRuleHandler(w http.ResponseWriter, r *http.Request) {
|
func (h *Rules) DeleteRuleHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||||
|
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
aID := account.Id
|
aID := account.Id
|
||||||
|
|
||||||
rID := mux.Vars(r)["id"]
|
rID := mux.Vars(r)["id"]
|
||||||
if len(rID) == 0 {
|
if len(rID) == 0 {
|
||||||
http.Error(w, "invalid rule ID", http.StatusBadRequest)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid rule ID"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.accountManager.DeleteRule(aID, rID); err != nil {
|
err = h.accountManager.DeleteRule(aID, rID)
|
||||||
log.Errorf("failed delete rule %s under account %s %v", rID, aID, err)
|
if err != nil {
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
writeJSONObject(w, "")
|
util.WriteJSONObject(w, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRuleHandler handles a group Get request identified by ID
|
// GetRuleHandler handles a group Get request identified by ID
|
||||||
func (h *Rules) GetRuleHandler(w http.ResponseWriter, r *http.Request) {
|
func (h *Rules) GetRuleHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||||
|
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -380,19 +361,19 @@ func (h *Rules) GetRuleHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
case http.MethodGet:
|
case http.MethodGet:
|
||||||
ruleID := mux.Vars(r)["id"]
|
ruleID := mux.Vars(r)["id"]
|
||||||
if len(ruleID) == 0 {
|
if len(ruleID) == 0 {
|
||||||
http.Error(w, "invalid rule ID", http.StatusBadRequest)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid rule ID"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
rule, err := h.accountManager.GetRule(account.Id, ruleID, user.Id)
|
rule, err := h.accountManager.GetRule(account.Id, ruleID, user.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "rule not found", http.StatusNotFound)
|
util.WriteError(status.Errorf(status.NotFound, "rule not found"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
writeJSONObject(w, toRuleResponse(account, rule))
|
util.WriteJSONObject(w, toRuleResponse(account, rule))
|
||||||
default:
|
default:
|
||||||
http.Error(w, "", http.StatusNotFound)
|
util.WriteError(status.Errorf(status.NotFound, "method not found"), w)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -66,7 +66,8 @@ func initRulesTestData(rules ...*server.Rule) *Rules {
|
|||||||
}
|
}
|
||||||
return &rule, nil
|
return &rule, nil
|
||||||
},
|
},
|
||||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
|
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||||
|
user := server.NewAdminUser("test_user")
|
||||||
return &server.Account{
|
return &server.Account{
|
||||||
Id: claims.AccountId,
|
Id: claims.AccountId,
|
||||||
Domain: "hotmail.com",
|
Domain: "hotmail.com",
|
||||||
@ -76,9 +77,9 @@ func initRulesTestData(rules ...*server.Rule) *Rules {
|
|||||||
"G": {ID: "G"},
|
"G": {ID: "G"},
|
||||||
},
|
},
|
||||||
Users: map[string]*server.User{
|
Users: map[string]*server.User{
|
||||||
"test_user": server.NewAdminUser("test_user"),
|
"test_user": user,
|
||||||
},
|
},
|
||||||
}, nil
|
}, user, nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
authAudience: "",
|
authAudience: "",
|
||||||
@ -238,7 +239,7 @@ func TestRulesWriteRule(t *testing.T) {
|
|||||||
requestPath: "/api/rules/id-existed",
|
requestPath: "/api/rules/id-existed",
|
||||||
requestBody: bytes.NewBuffer(
|
requestBody: bytes.NewBuffer(
|
||||||
[]byte(`[{"op":"insert","path":"name","value":[""]}]`)),
|
[]byte(`[{"op":"insert","path":"name","value":[""]}]`)),
|
||||||
expectedStatus: http.StatusBadRequest,
|
expectedStatus: http.StatusUnprocessableEntity,
|
||||||
expectedBody: false,
|
expectedBody: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -2,14 +2,12 @@ package http
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
log "github.com/sirupsen/logrus"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -31,29 +29,28 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authAudience stri
|
|||||||
|
|
||||||
// CreateSetupKeyHandler is a POST requests that creates a new SetupKey
|
// CreateSetupKeyHandler is a POST requests that creates a new SetupKey
|
||||||
func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
|
func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||||
|
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
util.WriteError(err, w)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
req := &api.PostApiSetupKeysJSONRequestBody{}
|
req := &api.PostApiSetupKeysJSONRequestBody{}
|
||||||
err = json.NewDecoder(r.Body).Decode(&req)
|
err = json.NewDecoder(r.Body).Decode(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Name == "" {
|
if req.Name == "" {
|
||||||
http.Error(w, "Setup key name shouldn't be empty", http.StatusUnprocessableEntity)
|
util.WriteError(status.Errorf(status.InvalidArgument, "setup key name shouldn't be empty"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !(server.SetupKeyType(req.Type) == server.SetupKeyReusable ||
|
if !(server.SetupKeyType(req.Type) == server.SetupKeyReusable ||
|
||||||
server.SetupKeyType(req.Type) == server.SetupKeyOneOff) {
|
server.SetupKeyType(req.Type) == server.SetupKeyOneOff) {
|
||||||
|
util.WriteError(status.Errorf(status.InvalidArgument, "unknown setup key type %s", string(req.Type)), w)
|
||||||
http.Error(w, "unknown setup key type "+string(req.Type), http.StatusBadRequest)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -62,17 +59,11 @@ func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request
|
|||||||
if req.AutoGroups == nil {
|
if req.AutoGroups == nil {
|
||||||
req.AutoGroups = []string{}
|
req.AutoGroups = []string{}
|
||||||
}
|
}
|
||||||
// newExpiresIn := time.Duration(req.ExpiresIn) * time.Second
|
|
||||||
// newKey.ExpiresAt = time.Now().Add(newExpiresIn)
|
|
||||||
setupKey, err := h.accountManager.CreateSetupKey(account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn,
|
setupKey, err := h.accountManager.CreateSetupKey(account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn,
|
||||||
req.AutoGroups)
|
req.AutoGroups)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errStatus, ok := status.FromError(err)
|
util.WriteError(err, w)
|
||||||
if ok && errStatus.Code() == codes.NotFound {
|
|
||||||
http.Error(w, "account not found", http.StatusNotFound)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
http.Error(w, "failed adding setup key", http.StatusInternalServerError)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -81,29 +72,23 @@ func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request
|
|||||||
|
|
||||||
// GetSetupKeyHandler is a GET request to get a SetupKey by ID
|
// GetSetupKeyHandler is a GET request to get a SetupKey by ID
|
||||||
func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
|
func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||||
|
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
util.WriteError(err, w)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
keyID := vars["id"]
|
keyID := vars["id"]
|
||||||
if len(keyID) == 0 {
|
if len(keyID) == 0 {
|
||||||
http.Error(w, "invalid key Id", http.StatusBadRequest)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid key ID"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
key, err := h.accountManager.GetSetupKey(account.Id, user.Id, keyID)
|
key, err := h.accountManager.GetSetupKey(account.Id, user.Id, keyID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errStatus, ok := status.FromError(err)
|
util.WriteError(err, w)
|
||||||
if ok && errStatus.Code() == codes.NotFound {
|
|
||||||
http.Error(w, fmt.Sprintf("setup key %s not found under account %s", keyID, account.Id), http.StatusNotFound)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Errorf("failed getting setup key %s under account %s %v", keyID, account.Id, err)
|
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -112,34 +97,34 @@ func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
// UpdateSetupKeyHandler is a PUT request to update server.SetupKey
|
// UpdateSetupKeyHandler is a PUT request to update server.SetupKey
|
||||||
func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
|
func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||||
|
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
util.WriteError(err, w)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
keyID := vars["id"]
|
keyID := vars["id"]
|
||||||
if len(keyID) == 0 {
|
if len(keyID) == 0 {
|
||||||
http.Error(w, "invalid key Id", http.StatusBadRequest)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid key ID"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
req := &api.PutApiSetupKeysIdJSONRequestBody{}
|
req := &api.PutApiSetupKeysIdJSONRequestBody{}
|
||||||
err = json.NewDecoder(r.Body).Decode(&req)
|
err = json.NewDecoder(r.Body).Decode(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Name == "" {
|
if req.Name == "" {
|
||||||
http.Error(w, fmt.Sprintf("setup key name field is invalid: %s", req.Name), http.StatusBadRequest)
|
util.WriteError(status.Errorf(status.InvalidArgument, "setup key name field is invalid: %s", req.Name), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.AutoGroups == nil {
|
if req.AutoGroups == nil {
|
||||||
http.Error(w, fmt.Sprintf("setup key AutoGroups field is invalid: %s", req.AutoGroups), http.StatusBadRequest)
|
util.WriteError(status.Errorf(status.InvalidArgument, "setup key AutoGroups field is invalid"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -150,16 +135,8 @@ func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request
|
|||||||
newKey.Id = keyID
|
newKey.Id = keyID
|
||||||
|
|
||||||
newKey, err = h.accountManager.SaveSetupKey(account.Id, newKey)
|
newKey, err = h.accountManager.SaveSetupKey(account.Id, newKey)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if e, ok := status.FromError(err); ok {
|
util.WriteError(err, w)
|
||||||
switch e.Code() {
|
|
||||||
case codes.NotFound:
|
|
||||||
http.Error(w, fmt.Sprintf("couldn't find setup key for ID %s", keyID), http.StatusNotFound)
|
|
||||||
default:
|
|
||||||
http.Error(w, "failed updating setup key", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
writeSuccess(w, newKey)
|
writeSuccess(w, newKey)
|
||||||
@ -168,25 +145,25 @@ func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request
|
|||||||
// GetAllSetupKeysHandler is a GET request that returns a list of SetupKey
|
// GetAllSetupKeysHandler is a GET request that returns a list of SetupKey
|
||||||
func (h *SetupKeys) GetAllSetupKeysHandler(w http.ResponseWriter, r *http.Request) {
|
func (h *SetupKeys) GetAllSetupKeysHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||||
|
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
util.WriteError(err, w)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
setupKeys, err := h.accountManager.ListSetupKeys(account.Id, user.Id)
|
setupKeys, err := h.accountManager.ListSetupKeys(account.Id, user.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
util.WriteError(err, w)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
apiSetupKeys := make([]*api.SetupKey, 0)
|
apiSetupKeys := make([]*api.SetupKey, 0)
|
||||||
for _, key := range setupKeys {
|
for _, key := range setupKeys {
|
||||||
apiSetupKeys = append(apiSetupKeys, toResponseBody(key))
|
apiSetupKeys = append(apiSetupKeys, toResponseBody(key))
|
||||||
}
|
}
|
||||||
|
|
||||||
writeJSONObject(w, apiSetupKeys)
|
util.WriteJSONObject(w, apiSetupKeys)
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeSuccess(w http.ResponseWriter, key *server.SetupKey) {
|
func writeSuccess(w http.ResponseWriter, key *server.SetupKey) {
|
||||||
@ -194,7 +171,7 @@ func writeSuccess(w http.ResponseWriter, key *server.SetupKey) {
|
|||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
err := json.NewEncoder(w).Encode(toResponseBody(key))
|
err := json.NewEncoder(w).Encode(toResponseBody(key))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "failed handling request", http.StatusInternalServerError)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -6,9 +6,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@ -32,7 +31,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
|
|||||||
user *server.User) *SetupKeys {
|
user *server.User) *SetupKeys {
|
||||||
return &SetupKeys{
|
return &SetupKeys{
|
||||||
accountManager: &mock_server.MockAccountManager{
|
accountManager: &mock_server.MockAccountManager{
|
||||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
|
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||||
return &server.Account{
|
return &server.Account{
|
||||||
Id: testAccountID,
|
Id: testAccountID,
|
||||||
Domain: "hotmail.com",
|
Domain: "hotmail.com",
|
||||||
@ -45,7 +44,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
|
|||||||
Groups: map[string]*server.Group{
|
Groups: map[string]*server.Group{
|
||||||
"group-1": {ID: "group-1", Peers: []string{"A", "B"}},
|
"group-1": {ID: "group-1", Peers: []string{"A", "B"}},
|
||||||
"id-all": {ID: "id-all", Name: "All"}},
|
"id-all": {ID: "id-all", Name: "All"}},
|
||||||
}, nil
|
}, user, nil
|
||||||
},
|
},
|
||||||
CreateSetupKeyFunc: func(_ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string) (*server.SetupKey, error) {
|
CreateSetupKeyFunc: func(_ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string) (*server.SetupKey, error) {
|
||||||
if keyName == newKey.Name || typ != newKey.Type {
|
if keyName == newKey.Name || typ != newKey.Type {
|
||||||
@ -60,7 +59,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
|
|||||||
case newKey.Id:
|
case newKey.Id:
|
||||||
return newKey, nil
|
return newKey, nil
|
||||||
default:
|
default:
|
||||||
return nil, status.Errorf(codes.NotFound, "key %s not found", keyID)
|
return nil, status.Errorf(status.NotFound, "key %s not found", keyID)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
@ -68,7 +67,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
|
|||||||
if key.Id == updatedSetupKey.Id {
|
if key.Id == updatedSetupKey.Id {
|
||||||
return updatedSetupKey, nil
|
return updatedSetupKey, nil
|
||||||
}
|
}
|
||||||
return nil, status.Errorf(codes.NotFound, "key %s not found", key.Id)
|
return nil, status.Errorf(status.NotFound, "key %s not found", key.Id)
|
||||||
},
|
},
|
||||||
|
|
||||||
ListSetupKeysFunc: func(accountID, userID string) ([]*server.SetupKey, error) {
|
ListSetupKeysFunc: func(accountID, userID string) ([]*server.SetupKey, error) {
|
||||||
|
@ -2,12 +2,10 @@ package http
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
log "github.com/sirupsen/logrus"
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
"google.golang.org/grpc/codes"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
@ -31,33 +29,34 @@ func NewUserHandler(accountManager server.AccountManager, authAudience string) *
|
|||||||
// UpdateUser is a PUT requests to update User data
|
// UpdateUser is a PUT requests to update User data
|
||||||
func (h *UserHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
|
func (h *UserHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method != http.MethodPut {
|
if r.Method != http.MethodPut {
|
||||||
http.Error(w, "", http.StatusBadRequest)
|
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||||
|
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
util.WriteError(err, w)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
userID := vars["id"]
|
userID := vars["id"]
|
||||||
if len(userID) == 0 {
|
if len(userID) == 0 {
|
||||||
http.Error(w, "invalid user ID", http.StatusBadRequest)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
req := &api.PutApiUsersIdJSONRequestBody{}
|
req := &api.PutApiUsersIdJSONRequestBody{}
|
||||||
err = json.NewDecoder(r.Body).Decode(&req)
|
err = json.NewDecoder(r.Body).Decode(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
userRole := server.StrRoleToUserRole(req.Role)
|
userRole := server.StrRoleToUserRole(req.Role)
|
||||||
if userRole == server.UserRoleUnknown {
|
if userRole == server.UserRoleUnknown {
|
||||||
http.Error(w, "invalid user role", http.StatusBadRequest)
|
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user role"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -67,40 +66,36 @@ func (h *UserHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
|
|||||||
AutoGroups: req.AutoGroups,
|
AutoGroups: req.AutoGroups,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if e, ok := status.FromError(err); ok {
|
util.WriteError(err, w)
|
||||||
switch e.Code() {
|
|
||||||
case codes.NotFound:
|
|
||||||
http.Error(w, fmt.Sprintf("couldn't find a user for ID %s", userID), http.StatusNotFound)
|
|
||||||
default:
|
|
||||||
http.Error(w, "failed to update user", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
writeJSONObject(w, toUserResponse(newUser))
|
util.WriteJSONObject(w, toUserResponse(newUser))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateUserHandler creates a User in the system with a status "invited" (effectively this is a user invite).
|
// CreateUserHandler creates a User in the system with a status "invited" (effectively this is a user invite).
|
||||||
func (h *UserHandler) CreateUserHandler(w http.ResponseWriter, r *http.Request) {
|
func (h *UserHandler) CreateUserHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method != http.MethodPost {
|
if r.Method != http.MethodPost {
|
||||||
http.Error(w, "", http.StatusNotFound)
|
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||||
|
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
util.WriteError(err, w)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
req := &api.PostApiUsersJSONRequestBody{}
|
req := &api.PostApiUsersJSONRequestBody{}
|
||||||
err = json.NewDecoder(r.Body).Decode(&req)
|
err = json.NewDecoder(r.Body).Decode(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if server.StrRoleToUserRole(req.Role) == server.UserRoleUnknown {
|
if server.StrRoleToUserRole(req.Role) == server.UserRoleUnknown {
|
||||||
http.Error(w, "unknown user role "+req.Role, http.StatusBadRequest)
|
util.WriteError(status.Errorf(status.InvalidArgument, "unknown user role %s", req.Role), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -111,37 +106,30 @@ func (h *UserHandler) CreateUserHandler(w http.ResponseWriter, r *http.Request)
|
|||||||
AutoGroups: req.AutoGroups,
|
AutoGroups: req.AutoGroups,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if e, ok := server.FromError(err); ok {
|
util.WriteError(err, w)
|
||||||
switch e.Type() {
|
|
||||||
case server.UserAlreadyExists:
|
|
||||||
http.Error(w, "You can't invite users with an existing NetBird account.", http.StatusPreconditionFailed)
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
http.Error(w, "failed to invite", http.StatusInternalServerError)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
writeJSONObject(w, toUserResponse(newUser))
|
util.WriteJSONObject(w, toUserResponse(newUser))
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUsers returns a list of users of the account this user belongs to.
|
// GetUsers returns a list of users of the account this user belongs to.
|
||||||
// It also gathers additional user data (like email and name) from the IDP manager.
|
// It also gathers additional user data (like email and name) from the IDP manager.
|
||||||
func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) {
|
func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method != http.MethodGet {
|
if r.Method != http.MethodGet {
|
||||||
http.Error(w, "", http.StatusBadRequest)
|
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||||
|
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
util.WriteError(err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := h.accountManager.GetUsersFromAccount(account.Id, user.Id)
|
data, err := h.accountManager.GetUsersFromAccount(account.Id, user.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
util.WriteError(err, w)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -150,7 +138,7 @@ func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) {
|
|||||||
users = append(users, toUserResponse(r))
|
users = append(users, toUserResponse(r))
|
||||||
}
|
}
|
||||||
|
|
||||||
writeJSONObject(w, users)
|
util.WriteJSONObject(w, users)
|
||||||
}
|
}
|
||||||
|
|
||||||
func toUserResponse(user *server.UserInfo) *api.User {
|
func toUserResponse(user *server.UserInfo) *api.User {
|
||||||
|
@ -16,7 +16,7 @@ import (
|
|||||||
func initUsers(user ...*server.User) *UserHandler {
|
func initUsers(user ...*server.User) *UserHandler {
|
||||||
return &UserHandler{
|
return &UserHandler{
|
||||||
accountManager: &mock_server.MockAccountManager{
|
accountManager: &mock_server.MockAccountManager{
|
||||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
|
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||||
users := make(map[string]*server.User, 0)
|
users := make(map[string]*server.User, 0)
|
||||||
for _, u := range user {
|
for _, u := range user {
|
||||||
users[u.Id] = u
|
users[u.Id] = u
|
||||||
@ -25,7 +25,7 @@ func initUsers(user ...*server.User) *UserHandler {
|
|||||||
Id: "12345",
|
Id: "12345",
|
||||||
Domain: "netbird.io",
|
Domain: "netbird.io",
|
||||||
Users: users,
|
Users: users,
|
||||||
}, nil
|
}, users[claims.UserId], nil
|
||||||
},
|
},
|
||||||
GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) {
|
GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) {
|
||||||
users := make([]*server.UserInfo, 0)
|
users := make([]*server.UserInfo, 0)
|
||||||
@ -66,7 +66,6 @@ func TestGetUsers(t *testing.T) {
|
|||||||
expectedResult []*server.User
|
expectedResult []*server.User
|
||||||
}{
|
}{
|
||||||
{name: "GetAllUsers", requestType: http.MethodGet, requestPath: "/api/users/", expectedStatus: http.StatusOK, expectedResult: users},
|
{name: "GetAllUsers", requestType: http.MethodGet, requestPath: "/api/users/", expectedStatus: http.StatusOK, expectedResult: users},
|
||||||
{name: "WrongRequestMethod", requestType: http.MethodPost, requestPath: "/api/users/", expectedStatus: http.StatusBadRequest},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range tt {
|
for _, tc := range tt {
|
||||||
|
@ -1,97 +0,0 @@
|
|||||||
package http
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
"net/http"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// writeJSONObject simply writes object to the HTTP reponse in JSON format
|
|
||||||
func writeJSONObject(w http.ResponseWriter, obj interface{}) {
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
|
|
||||||
err := json.NewEncoder(w).Encode(obj)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, "failed handling request", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Duration is used strictly for JSON requests/responses due to duration marshalling issues
|
|
||||||
type Duration struct {
|
|
||||||
time.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d Duration) MarshalJSON() ([]byte, error) {
|
|
||||||
return json.Marshal(d.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Duration) UnmarshalJSON(b []byte) error {
|
|
||||||
var v interface{}
|
|
||||||
if err := json.Unmarshal(b, &v); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
switch value := v.(type) {
|
|
||||||
case float64:
|
|
||||||
d.Duration = time.Duration(value)
|
|
||||||
return nil
|
|
||||||
case string:
|
|
||||||
var err error
|
|
||||||
d.Duration, err = time.ParseDuration(value)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
default:
|
|
||||||
return errors.New("invalid duration")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getJWTAccount(accountManager server.AccountManager,
|
|
||||||
jwtExtractor jwtclaims.ClaimsExtractor,
|
|
||||||
authAudience string, r *http.Request) (*server.Account, *server.User, error) {
|
|
||||||
|
|
||||||
claims := jwtExtractor.ExtractClaimsFromRequestContext(r, authAudience)
|
|
||||||
|
|
||||||
account, err := accountManager.GetAccountFromToken(claims)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("failed getting account of a user %s: %v", claims.UserId, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
user := account.Users[claims.UserId]
|
|
||||||
if user == nil {
|
|
||||||
// this is not really possible because we got an account by user ID
|
|
||||||
return nil, nil, fmt.Errorf("user %s not found", claims.UserId)
|
|
||||||
}
|
|
||||||
|
|
||||||
return account, user, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func toHTTPError(err error, w http.ResponseWriter) {
|
|
||||||
errStatus, ok := status.FromError(err)
|
|
||||||
if ok && errStatus.Code() == codes.Internal {
|
|
||||||
http.Error(w, errStatus.String(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if ok && errStatus.Code() == codes.NotFound {
|
|
||||||
http.Error(w, errStatus.String(), http.StatusNotFound)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if ok && errStatus.Code() == codes.InvalidArgument {
|
|
||||||
http.Error(w, errStatus.String(), http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
unhandledMSG := fmt.Sprintf("got unhandled error code, error: %s", errStatus.String())
|
|
||||||
log.Error(unhandledMSG)
|
|
||||||
http.Error(w, unhandledMSG, http.StatusInternalServerError)
|
|
||||||
}
|
|
105
management/server/http/util/util.go
Normal file
105
management/server/http/util/util.go
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WriteJSONObject simply writes object to the HTTP reponse in JSON format
|
||||||
|
func WriteJSONObject(w http.ResponseWriter, obj interface{}) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
|
||||||
|
err := json.NewEncoder(w).Encode(obj)
|
||||||
|
if err != nil {
|
||||||
|
WriteError(err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Duration is used strictly for JSON requests/responses due to duration marshalling issues
|
||||||
|
type Duration struct {
|
||||||
|
time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON marshals the duration
|
||||||
|
func (d Duration) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(d.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON unmarshals the duration
|
||||||
|
func (d *Duration) UnmarshalJSON(b []byte) error {
|
||||||
|
var v interface{}
|
||||||
|
if err := json.Unmarshal(b, &v); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
switch value := v.(type) {
|
||||||
|
case float64:
|
||||||
|
d.Duration = time.Duration(value)
|
||||||
|
return nil
|
||||||
|
case string:
|
||||||
|
var err error
|
||||||
|
d.Duration, err = time.ParseDuration(value)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return errors.New("invalid duration")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteErrorResponse prepares and writes an error response i nJSON
|
||||||
|
func WriteErrorResponse(errMsg string, httpStatus int, w http.ResponseWriter) {
|
||||||
|
type errorResponse struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
Code int `json:"code"`
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(httpStatus)
|
||||||
|
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
|
||||||
|
err := json.NewEncoder(w).Encode(&errorResponse{
|
||||||
|
Message: errMsg,
|
||||||
|
Code: httpStatus,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "failed handling request", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteError converts an error to an JSON error response.
|
||||||
|
// If it is known internal error of type server.Error then it sets the messages from the error, a generic message otherwise
|
||||||
|
func WriteError(err error, w http.ResponseWriter) {
|
||||||
|
errStatus, ok := status.FromError(err)
|
||||||
|
httpStatus := http.StatusInternalServerError
|
||||||
|
msg := "internal server error"
|
||||||
|
if ok {
|
||||||
|
switch errStatus.Type() {
|
||||||
|
case status.UserAlreadyExists:
|
||||||
|
httpStatus = http.StatusConflict
|
||||||
|
case status.AlreadyExists:
|
||||||
|
httpStatus = http.StatusConflict
|
||||||
|
case status.PreconditionFailed:
|
||||||
|
httpStatus = http.StatusPreconditionFailed
|
||||||
|
case status.PermissionDenied:
|
||||||
|
httpStatus = http.StatusForbidden
|
||||||
|
case status.NotFound:
|
||||||
|
httpStatus = http.StatusNotFound
|
||||||
|
case status.Internal:
|
||||||
|
httpStatus = http.StatusInternalServerError
|
||||||
|
case status.InvalidArgument:
|
||||||
|
httpStatus = http.StatusUnprocessableEntity
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
msg = err.Error()
|
||||||
|
} else {
|
||||||
|
unhandledMSG := fmt.Sprintf("got unhandled error code, error: %s", err.Error())
|
||||||
|
log.Error(unhandledMSG)
|
||||||
|
}
|
||||||
|
|
||||||
|
WriteErrorResponse(msg, httpStatus, w)
|
||||||
|
}
|
@ -59,7 +59,7 @@ type MockAccountManager struct {
|
|||||||
DeleteNameServerGroupFunc func(accountID, nsGroupID string) error
|
DeleteNameServerGroupFunc func(accountID, nsGroupID string) error
|
||||||
ListNameServerGroupsFunc func(accountID string) ([]*nbdns.NameServerGroup, error)
|
ListNameServerGroupsFunc func(accountID string) ([]*nbdns.NameServerGroup, error)
|
||||||
CreateUserFunc func(accountID string, key *server.UserInfo) (*server.UserInfo, error)
|
CreateUserFunc func(accountID string, key *server.UserInfo) (*server.UserInfo, error)
|
||||||
GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, error)
|
GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface
|
// GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface
|
||||||
@ -113,7 +113,7 @@ func (am *MockAccountManager) CreateSetupKey(
|
|||||||
return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountByUserOrAccountId mock implementation of GetAccountByUserOrAccountId from server.AccountManager interface
|
// GetAccountByUserOrAccountID mock implementation of GetAccountByUserOrAccountID from server.AccountManager interface
|
||||||
func (am *MockAccountManager) GetAccountByUserOrAccountID(
|
func (am *MockAccountManager) GetAccountByUserOrAccountID(
|
||||||
userId, accountId, domain string,
|
userId, accountId, domain string,
|
||||||
) (*server.Account, error) {
|
) (*server.Account, error) {
|
||||||
@ -462,11 +462,12 @@ func (am *MockAccountManager) CreateUser(accountID string, invite *server.UserIn
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountFromToken mocks GetAccountFromToken of the AccountManager interface
|
// GetAccountFromToken mocks GetAccountFromToken of the AccountManager interface
|
||||||
func (am *MockAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
|
func (am *MockAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User,
|
||||||
|
error) {
|
||||||
if am.GetAccountFromTokenFunc != nil {
|
if am.GetAccountFromTokenFunc != nil {
|
||||||
return am.GetAccountFromTokenFunc(claims)
|
return am.GetAccountFromTokenFunc(claims)
|
||||||
}
|
}
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented")
|
return nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPeers mocks GetPeers of the AccountManager interface
|
// GetPeers mocks GetPeers of the AccountManager interface
|
||||||
|
@ -3,10 +3,9 @@ package server
|
|||||||
import (
|
import (
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
)
|
)
|
||||||
@ -66,7 +65,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string)
|
|||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
nsGroup, found := account.NameServerGroups[nsGroupID]
|
nsGroup, found := account.NameServerGroups[nsGroupID]
|
||||||
@ -74,7 +73,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string)
|
|||||||
return nsGroup.Copy(), nil
|
return nsGroup.Copy(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, status.Errorf(codes.NotFound, "nameserver group with ID %s not found", nsGroupID)
|
return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateNameServerGroup creates and saves a new nameserver group
|
// CreateNameServerGroup creates and saves a new nameserver group
|
||||||
@ -85,7 +84,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, d
|
|||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
newNSGroup := &nbdns.NameServerGroup{
|
newNSGroup := &nbdns.NameServerGroup{
|
||||||
@ -119,7 +118,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, d
|
|||||||
err = am.updateAccountPeers(account)
|
err = am.updateAccountPeers(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
return newNSGroup.Copy(), status.Errorf(codes.Unavailable, "failed to update peers after create nameserver %s", name)
|
return newNSGroup.Copy(), status.Errorf(status.Internal, "failed to update peers after create nameserver %s", name)
|
||||||
}
|
}
|
||||||
|
|
||||||
return newNSGroup.Copy(), nil
|
return newNSGroup.Copy(), nil
|
||||||
@ -132,12 +131,12 @@ func (am *DefaultAccountManager) SaveNameServerGroup(accountID string, nsGroupTo
|
|||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
if nsGroupToSave == nil {
|
if nsGroupToSave == nil {
|
||||||
return status.Errorf(codes.InvalidArgument, "nameserver group provided is nil")
|
return status.Errorf(status.InvalidArgument, "nameserver group provided is nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return status.Errorf(codes.NotFound, "account not found")
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = validateNameServerGroup(true, nsGroupToSave, account)
|
err = validateNameServerGroup(true, nsGroupToSave, account)
|
||||||
@ -156,7 +155,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(accountID string, nsGroupTo
|
|||||||
err = am.updateAccountPeers(account)
|
err = am.updateAccountPeers(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
return status.Errorf(codes.Unavailable, "failed to update peers after update nameserver %s", nsGroupToSave.Name)
|
return status.Errorf(status.Internal, "failed to update peers after update nameserver %s", nsGroupToSave.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -170,16 +169,16 @@ func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID stri
|
|||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(operations) == 0 {
|
if len(operations) == 0 {
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "operations shouldn't be empty")
|
return nil, status.Errorf(status.InvalidArgument, "operations shouldn't be empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
nsGroupToUpdate, ok := account.NameServerGroups[nsGroupID]
|
nsGroupToUpdate, ok := account.NameServerGroups[nsGroupID]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, status.Errorf(codes.NotFound, "nameserver group ID %s no longer exists", nsGroupID)
|
return nil, status.Errorf(status.NotFound, "nameserver group ID %s no longer exists", nsGroupID)
|
||||||
}
|
}
|
||||||
|
|
||||||
newNSGroup := nsGroupToUpdate.Copy()
|
newNSGroup := nsGroupToUpdate.Copy()
|
||||||
@ -187,12 +186,12 @@ func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID stri
|
|||||||
for _, operation := range operations {
|
for _, operation := range operations {
|
||||||
valuesCount := len(operation.Values)
|
valuesCount := len(operation.Values)
|
||||||
if valuesCount < 1 {
|
if valuesCount < 1 {
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "operation %s contains invalid number of values, it should be at least 1", operation.Type.String())
|
return nil, status.Errorf(status.InvalidArgument, "operation %s contains invalid number of values, it should be at least 1", operation.Type.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, value := range operation.Values {
|
for _, value := range operation.Values {
|
||||||
if value == "" {
|
if value == "" {
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "operation %s contains invalid empty string value", operation.Type.String())
|
return nil, status.Errorf(status.InvalidArgument, "operation %s contains invalid empty string value", operation.Type.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
switch operation.Type {
|
switch operation.Type {
|
||||||
@ -200,7 +199,7 @@ func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID stri
|
|||||||
newNSGroup.Description = operation.Values[0]
|
newNSGroup.Description = operation.Values[0]
|
||||||
case UpdateNameServerGroupName:
|
case UpdateNameServerGroupName:
|
||||||
if valuesCount > 1 {
|
if valuesCount > 1 {
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "failed to parse name values, expected 1 value got %d", valuesCount)
|
return nil, status.Errorf(status.InvalidArgument, "failed to parse name values, expected 1 value got %d", valuesCount)
|
||||||
}
|
}
|
||||||
err = validateNSGroupName(operation.Values[0], nsGroupID, account.NameServerGroups)
|
err = validateNSGroupName(operation.Values[0], nsGroupID, account.NameServerGroups)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -230,13 +229,13 @@ func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID stri
|
|||||||
case UpdateNameServerGroupEnabled:
|
case UpdateNameServerGroupEnabled:
|
||||||
enabled, err := strconv.ParseBool(operation.Values[0])
|
enabled, err := strconv.ParseBool(operation.Values[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0])
|
return nil, status.Errorf(status.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0])
|
||||||
}
|
}
|
||||||
newNSGroup.Enabled = enabled
|
newNSGroup.Enabled = enabled
|
||||||
case UpdateNameServerGroupPrimary:
|
case UpdateNameServerGroupPrimary:
|
||||||
primary, err := strconv.ParseBool(operation.Values[0])
|
primary, err := strconv.ParseBool(operation.Values[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "failed to parse primary status %s, not boolean", operation.Values[0])
|
return nil, status.Errorf(status.InvalidArgument, "failed to parse primary status %s, not boolean", operation.Values[0])
|
||||||
}
|
}
|
||||||
newNSGroup.Primary = primary
|
newNSGroup.Primary = primary
|
||||||
case UpdateNameServerGroupDomains:
|
case UpdateNameServerGroupDomains:
|
||||||
@ -259,7 +258,7 @@ func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID stri
|
|||||||
err = am.updateAccountPeers(account)
|
err = am.updateAccountPeers(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
return newNSGroup.Copy(), status.Errorf(codes.Unavailable, "failed to update peers after update nameserver %s", newNSGroup.Name)
|
return newNSGroup.Copy(), status.Errorf(status.Internal, "failed to update peers after update nameserver %s", newNSGroup.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
return newNSGroup.Copy(), nil
|
return newNSGroup.Copy(), nil
|
||||||
@ -273,7 +272,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID stri
|
|||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return status.Errorf(codes.NotFound, "account not found")
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(account.NameServerGroups, nsGroupID)
|
delete(account.NameServerGroups, nsGroupID)
|
||||||
@ -287,7 +286,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID stri
|
|||||||
err = am.updateAccountPeers(account)
|
err = am.updateAccountPeers(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
return status.Errorf(codes.Unavailable, "failed to update peers after deleting nameserver %s", nsGroupID)
|
return status.Errorf(status.Internal, "failed to update peers after deleting nameserver %s", nsGroupID)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -301,7 +300,7 @@ func (am *DefaultAccountManager) ListNameServerGroups(accountID string) ([]*nbdn
|
|||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
nsGroups := make([]*nbdns.NameServerGroup, 0, len(account.NameServerGroups))
|
nsGroups := make([]*nbdns.NameServerGroup, 0, len(account.NameServerGroups))
|
||||||
@ -318,7 +317,7 @@ func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServ
|
|||||||
nsGroupID = nameserverGroup.ID
|
nsGroupID = nameserverGroup.ID
|
||||||
_, found := account.NameServerGroups[nsGroupID]
|
_, found := account.NameServerGroups[nsGroupID]
|
||||||
if !found {
|
if !found {
|
||||||
return status.Errorf(codes.NotFound, "nameserver group with ID %s was not found", nsGroupID)
|
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -347,17 +346,17 @@ func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServ
|
|||||||
|
|
||||||
func validateDomainInput(primary bool, domains []string) error {
|
func validateDomainInput(primary bool, domains []string) error {
|
||||||
if !primary && len(domains) == 0 {
|
if !primary && len(domains) == 0 {
|
||||||
return status.Errorf(codes.InvalidArgument, "nameserver group primary status is false and domains are empty,"+
|
return status.Errorf(status.InvalidArgument, "nameserver group primary status is false and domains are empty,"+
|
||||||
" it should be primary or have at least one domain")
|
" it should be primary or have at least one domain")
|
||||||
}
|
}
|
||||||
if primary && len(domains) != 0 {
|
if primary && len(domains) != 0 {
|
||||||
return status.Errorf(codes.InvalidArgument, "nameserver group primary status is true and domains are not empty,"+
|
return status.Errorf(status.InvalidArgument, "nameserver group primary status is true and domains are not empty,"+
|
||||||
" you should set either primary or domain")
|
" you should set either primary or domain")
|
||||||
}
|
}
|
||||||
for _, domain := range domains {
|
for _, domain := range domains {
|
||||||
_, valid := dns.IsDomainName(domain)
|
_, valid := dns.IsDomainName(domain)
|
||||||
if !valid {
|
if !valid {
|
||||||
return status.Errorf(codes.InvalidArgument, "nameserver group got an invalid domain: %s", domain)
|
return status.Errorf(status.InvalidArgument, "nameserver group got an invalid domain: %s", domain)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@ -365,12 +364,12 @@ func validateDomainInput(primary bool, domains []string) error {
|
|||||||
|
|
||||||
func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.NameServerGroup) error {
|
func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.NameServerGroup) error {
|
||||||
if utf8.RuneCountInString(name) > nbdns.MaxGroupNameChar || name == "" {
|
if utf8.RuneCountInString(name) > nbdns.MaxGroupNameChar || name == "" {
|
||||||
return status.Errorf(codes.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar)
|
return status.Errorf(status.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, nsGroup := range nsGroupMap {
|
for _, nsGroup := range nsGroupMap {
|
||||||
if name == nsGroup.Name && nsGroup.ID != nsGroupID {
|
if name == nsGroup.Name && nsGroup.ID != nsGroupID {
|
||||||
return status.Errorf(codes.InvalidArgument, "a nameserver group with name %s already exist", name)
|
return status.Errorf(status.InvalidArgument, "a nameserver group with name %s already exist", name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -380,19 +379,19 @@ func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.Na
|
|||||||
func validateNSList(list []nbdns.NameServer) error {
|
func validateNSList(list []nbdns.NameServer) error {
|
||||||
nsListLenght := len(list)
|
nsListLenght := len(list)
|
||||||
if nsListLenght == 0 || nsListLenght > 2 {
|
if nsListLenght == 0 || nsListLenght > 2 {
|
||||||
return status.Errorf(codes.InvalidArgument, "the list of nameservers should be 1 or 2, got %d", len(list))
|
return status.Errorf(status.InvalidArgument, "the list of nameservers should be 1 or 2, got %d", len(list))
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateGroups(list []string, groups map[string]*Group) error {
|
func validateGroups(list []string, groups map[string]*Group) error {
|
||||||
if len(list) == 0 {
|
if len(list) == 0 {
|
||||||
return status.Errorf(codes.InvalidArgument, "the list of group IDs should not be empty")
|
return status.Errorf(status.InvalidArgument, "the list of group IDs should not be empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, id := range list {
|
for _, id := range list {
|
||||||
if id == "" {
|
if id == "" {
|
||||||
return status.Errorf(codes.InvalidArgument, "group ID should not be empty string")
|
return status.Errorf(status.InvalidArgument, "group ID should not be empty string")
|
||||||
}
|
}
|
||||||
found := false
|
found := false
|
||||||
for groupID := range groups {
|
for groupID := range groups {
|
||||||
@ -402,7 +401,7 @@ func validateGroups(list []string, groups map[string]*Group) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !found {
|
if !found {
|
||||||
return status.Errorf(codes.InvalidArgument, "group id %s not found", id)
|
return status.Errorf(status.InvalidArgument, "group id %s not found", id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3,6 +3,7 @@ package server
|
|||||||
import (
|
import (
|
||||||
"github.com/c-robinson/iplib"
|
"github.com/c-robinson/iplib"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
@ -93,7 +94,7 @@ func AllocatePeerIP(ipNet net.IPNet, takenIps []net.IP) (net.IP, error) {
|
|||||||
ips, _ := generateIPs(&ipNet, takenIPMap)
|
ips, _ := generateIPs(&ipNet, takenIPMap)
|
||||||
|
|
||||||
if len(ips) == 0 {
|
if len(ips) == 0 {
|
||||||
return nil, Errorf(PreconditionFailed, "failed allocating new IP for the ipNet %s - network is out of IPs", ipNet.String())
|
return nil, status.Errorf(status.PreconditionFailed, "failed allocating new IP for the ipNet %s - network is out of IPs", ipNet.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
// pick a random IP
|
// pick a random IP
|
||||||
|
@ -2,6 +2,7 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@ -9,8 +10,6 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/proto"
|
"github.com/netbirdio/netbird/management/proto"
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// PeerSystemMeta is a metadata of a Peer machine system
|
// PeerSystemMeta is a metadata of a Peer machine system
|
||||||
@ -162,7 +161,7 @@ func (am *DefaultAccountManager) UpdatePeer(accountID string, update *Peer) (*Pe
|
|||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO Peer.ID migration: we will need to replace search by ID here
|
//TODO Peer.ID migration: we will need to replace search by ID here
|
||||||
@ -208,7 +207,7 @@ func (am *DefaultAccountManager) DeletePeer(accountID string, peerPubKey string)
|
|||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
peer, err := account.FindPeerByPubKey(peerPubKey)
|
peer, err := account.FindPeerByPubKey(peerPubKey)
|
||||||
@ -258,7 +257,7 @@ func (am *DefaultAccountManager) GetPeerByIP(accountID string, peerIP string) (*
|
|||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, peer := range account.Peers {
|
for _, peer := range account.Peers {
|
||||||
@ -267,7 +266,7 @@ func (am *DefaultAccountManager) GetPeerByIP(accountID string, peerIP string) (*
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, status.Errorf(codes.NotFound, "peer with IP %s not found", peerIP)
|
return nil, status.Errorf(status.NotFound, "peer with IP %s not found", peerIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)
|
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)
|
||||||
@ -275,7 +274,7 @@ func (am *DefaultAccountManager) GetNetworkMap(peerPubKey string) (*NetworkMap,
|
|||||||
|
|
||||||
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
|
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.Internal, "Invalid peer key %s", peerPubKey)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
aclPeers := am.getPeersByACL(account, peerPubKey)
|
aclPeers := am.getPeersByACL(account, peerPubKey)
|
||||||
@ -306,7 +305,7 @@ func (am *DefaultAccountManager) GetPeerNetwork(peerPubKey string) (*Network, er
|
|||||||
|
|
||||||
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
|
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.Internal, "invalid peer key %s", peerPubKey)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return account.Network.Copy(), err
|
return account.Network.Copy(), err
|
||||||
@ -332,7 +331,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (*
|
|||||||
account, err = am.Store.GetAccountBySetupKey(setupKey)
|
account, err = am.Store.GetAccountBySetupKey(setupKey)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, Errorf(AccountNotFound, "failed adding new peer: account not found")
|
return nil, status.Errorf(status.NotFound, "failed adding new peer: account not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
unlock := am.Store.AcquireAccountLock(account.Id)
|
unlock := am.Store.AcquireAccountLock(account.Id)
|
||||||
@ -352,7 +351,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (*
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !sk.IsValid() {
|
if !sk.IsValid() {
|
||||||
return nil, Errorf(PreconditionFailed, "couldn't add peer: setup key is invalid")
|
return nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
|
||||||
}
|
}
|
||||||
|
|
||||||
account.SetupKeys[sk.Key] = sk.IncrementUsage()
|
account.SetupKeys[sk.Key] = sk.IncrementUsage()
|
||||||
@ -418,7 +417,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (*
|
|||||||
account.Network.IncSerial()
|
account.Network.IncSerial()
|
||||||
err = am.Store.SaveAccount(account)
|
err = am.Store.SaveAccount(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.Internal, "failed adding peer")
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return newPeer, nil
|
return newPeer, nil
|
||||||
@ -563,13 +562,13 @@ func (am *DefaultAccountManager) updateAccountPeers(account *Account) error {
|
|||||||
for _, peer := range peers {
|
for _, peer := range peers {
|
||||||
remotePeerNetworkMap, err := am.GetNetworkMap(peer.Key)
|
remotePeerNetworkMap, err := am.GetNetworkMap(peer.Key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return status.Errorf(codes.Internal, "unable to fetch network map for peer %s, error: %v", peer.Key, err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap)
|
update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap)
|
||||||
err = am.peersUpdateManager.SendUpdate(peer.Key, &UpdateMessage{Update: update})
|
err = am.peersUpdateManager.SendUpdate(peer.Key, &UpdateMessage{Update: update})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return status.Errorf(codes.Internal, "unable to send update for peer %s, error: %v", peer.Key, err)
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,11 +2,10 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/netbirdio/netbird/management/proto"
|
"github.com/netbirdio/netbird/management/proto"
|
||||||
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strconv"
|
"strconv"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
@ -66,7 +65,7 @@ func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*r
|
|||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
user, err := account.FindUser(userID)
|
||||||
@ -75,7 +74,7 @@ func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*r
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !user.IsAdmin() {
|
if !user.IsAdmin() {
|
||||||
return nil, Errorf(PermissionDenied, "Only administrators can view Network Routes")
|
return nil, status.Errorf(status.PermissionDenied, "Only administrators can view Network Routes")
|
||||||
}
|
}
|
||||||
|
|
||||||
wantedRoute, found := account.Routes[routeID]
|
wantedRoute, found := account.Routes[routeID]
|
||||||
@ -83,7 +82,7 @@ func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*r
|
|||||||
return wantedRoute, nil
|
return wantedRoute, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, status.Errorf(codes.NotFound, "route with ID %s not found", routeID)
|
return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkPrefixPeerExists checks the combination of prefix and peer id, if it exists returns an error, otherwise returns nil
|
// checkPrefixPeerExists checks the combination of prefix and peer id, if it exists returns an error, otherwise returns nil
|
||||||
@ -101,14 +100,14 @@ func (am *DefaultAccountManager) checkPrefixPeerExists(accountID, peer string, p
|
|||||||
routesWithPrefix := account.GetRoutesByPrefix(prefix)
|
routesWithPrefix := account.GetRoutesByPrefix(prefix)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return status.Errorf(codes.InvalidArgument, "failed to parse prefix %s", prefix.String())
|
return status.Errorf(status.InvalidArgument, "failed to parse prefix %s", prefix.String())
|
||||||
}
|
}
|
||||||
for _, prefixRoute := range routesWithPrefix {
|
for _, prefixRoute := range routesWithPrefix {
|
||||||
if prefixRoute.Peer == peer {
|
if prefixRoute.Peer == peer {
|
||||||
return status.Errorf(codes.AlreadyExists, "failed a route with prefix %s and peer already exist", prefix.String())
|
return status.Errorf(status.AlreadyExists, "failed a route with prefix %s and peer already exist", prefix.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@ -121,13 +120,13 @@ func (am *DefaultAccountManager) CreateRoute(accountID string, network, peer, de
|
|||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var newRoute route.Route
|
var newRoute route.Route
|
||||||
prefixType, newPrefix, err := route.ParseNetwork(network)
|
prefixType, newPrefix, err := route.ParseNetwork(network)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "failed to parse IP %s", network)
|
return nil, status.Errorf(status.InvalidArgument, "failed to parse IP %s", network)
|
||||||
}
|
}
|
||||||
err = am.checkPrefixPeerExists(accountID, peer, newPrefix)
|
err = am.checkPrefixPeerExists(accountID, peer, newPrefix)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -137,16 +136,16 @@ func (am *DefaultAccountManager) CreateRoute(accountID string, network, peer, de
|
|||||||
if peer != "" {
|
if peer != "" {
|
||||||
_, peerExist := account.Peers[peer]
|
_, peerExist := account.Peers[peer]
|
||||||
if !peerExist {
|
if !peerExist {
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "failed to find Peer %s", peer)
|
return nil, status.Errorf(status.InvalidArgument, "failed to find Peer %s", peer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if metric < route.MinMetric || metric > route.MaxMetric {
|
if metric < route.MinMetric || metric > route.MaxMetric {
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
|
return nil, status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
|
||||||
}
|
}
|
||||||
|
|
||||||
if utf8.RuneCountInString(netID) > route.MaxNetIDChar || netID == "" {
|
if utf8.RuneCountInString(netID) > route.MaxNetIDChar || netID == "" {
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
|
return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
|
||||||
}
|
}
|
||||||
|
|
||||||
newRoute.Peer = peer
|
newRoute.Peer = peer
|
||||||
@ -173,7 +172,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID string, network, peer, de
|
|||||||
err = am.updateAccountPeers(account)
|
err = am.updateAccountPeers(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
return &newRoute, status.Errorf(codes.Unavailable, "failed to update peers after create route %s", newPrefix)
|
return &newRoute, status.Errorf(status.Internal, "failed to update peers after create route %s", newPrefix)
|
||||||
}
|
}
|
||||||
return &newRoute, nil
|
return &newRoute, nil
|
||||||
}
|
}
|
||||||
@ -184,30 +183,30 @@ func (am *DefaultAccountManager) SaveRoute(accountID string, routeToSave *route.
|
|||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
if routeToSave == nil {
|
if routeToSave == nil {
|
||||||
return status.Errorf(codes.InvalidArgument, "route provided is nil")
|
return status.Errorf(status.InvalidArgument, "route provided is nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !routeToSave.Network.IsValid() {
|
if !routeToSave.Network.IsValid() {
|
||||||
return status.Errorf(codes.InvalidArgument, "invalid Prefix %s", routeToSave.Network.String())
|
return status.Errorf(status.InvalidArgument, "invalid Prefix %s", routeToSave.Network.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
if routeToSave.Metric < route.MinMetric || routeToSave.Metric > route.MaxMetric {
|
if routeToSave.Metric < route.MinMetric || routeToSave.Metric > route.MaxMetric {
|
||||||
return status.Errorf(codes.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
|
return status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
|
||||||
}
|
}
|
||||||
|
|
||||||
if utf8.RuneCountInString(routeToSave.NetID) > route.MaxNetIDChar || routeToSave.NetID == "" {
|
if utf8.RuneCountInString(routeToSave.NetID) > route.MaxNetIDChar || routeToSave.NetID == "" {
|
||||||
return status.Errorf(codes.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
|
return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return status.Errorf(codes.NotFound, "account not found")
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if routeToSave.Peer != "" {
|
if routeToSave.Peer != "" {
|
||||||
_, peerExist := account.Peers[routeToSave.Peer]
|
_, peerExist := account.Peers[routeToSave.Peer]
|
||||||
if !peerExist {
|
if !peerExist {
|
||||||
return status.Errorf(codes.InvalidArgument, "failed to find Peer %s", routeToSave.Peer)
|
return status.Errorf(status.InvalidArgument, "failed to find Peer %s", routeToSave.Peer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -228,12 +227,12 @@ func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operatio
|
|||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
routeToUpdate, ok := account.Routes[routeID]
|
routeToUpdate, ok := account.Routes[routeID]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, status.Errorf(codes.NotFound, "route %s no longer exists", routeID)
|
return nil, status.Errorf(status.NotFound, "route %s no longer exists", routeID)
|
||||||
}
|
}
|
||||||
|
|
||||||
newRoute := routeToUpdate.Copy()
|
newRoute := routeToUpdate.Copy()
|
||||||
@ -241,7 +240,7 @@ func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operatio
|
|||||||
for _, operation := range operations {
|
for _, operation := range operations {
|
||||||
|
|
||||||
if len(operation.Values) != 1 {
|
if len(operation.Values) != 1 {
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "operation %s contains invalid number of values, it should be 1", operation.Type.String())
|
return nil, status.Errorf(status.InvalidArgument, "operation %s contains invalid number of values, it should be 1", operation.Type.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
switch operation.Type {
|
switch operation.Type {
|
||||||
@ -249,13 +248,13 @@ func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operatio
|
|||||||
newRoute.Description = operation.Values[0]
|
newRoute.Description = operation.Values[0]
|
||||||
case UpdateRouteNetworkIdentifier:
|
case UpdateRouteNetworkIdentifier:
|
||||||
if utf8.RuneCountInString(operation.Values[0]) > route.MaxNetIDChar || operation.Values[0] == "" {
|
if utf8.RuneCountInString(operation.Values[0]) > route.MaxNetIDChar || operation.Values[0] == "" {
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
|
return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
|
||||||
}
|
}
|
||||||
newRoute.NetID = operation.Values[0]
|
newRoute.NetID = operation.Values[0]
|
||||||
case UpdateRouteNetwork:
|
case UpdateRouteNetwork:
|
||||||
prefixType, prefix, err := route.ParseNetwork(operation.Values[0])
|
prefixType, prefix, err := route.ParseNetwork(operation.Values[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "failed to parse IP %s", operation.Values[0])
|
return nil, status.Errorf(status.InvalidArgument, "failed to parse IP %s", operation.Values[0])
|
||||||
}
|
}
|
||||||
err = am.checkPrefixPeerExists(accountID, routeToUpdate.Peer, prefix)
|
err = am.checkPrefixPeerExists(accountID, routeToUpdate.Peer, prefix)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -267,7 +266,7 @@ func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operatio
|
|||||||
if operation.Values[0] != "" {
|
if operation.Values[0] != "" {
|
||||||
_, peerExist := account.Peers[operation.Values[0]]
|
_, peerExist := account.Peers[operation.Values[0]]
|
||||||
if !peerExist {
|
if !peerExist {
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "failed to find Peer %s", operation.Values[0])
|
return nil, status.Errorf(status.InvalidArgument, "failed to find Peer %s", operation.Values[0])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -279,10 +278,10 @@ func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operatio
|
|||||||
case UpdateRouteMetric:
|
case UpdateRouteMetric:
|
||||||
metric, err := strconv.Atoi(operation.Values[0])
|
metric, err := strconv.Atoi(operation.Values[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "failed to parse metric %s, not int", operation.Values[0])
|
return nil, status.Errorf(status.InvalidArgument, "failed to parse metric %s, not int", operation.Values[0])
|
||||||
}
|
}
|
||||||
if metric < route.MinMetric || metric > route.MaxMetric {
|
if metric < route.MinMetric || metric > route.MaxMetric {
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "failed to parse metric %s, value should be %d > N < %d",
|
return nil, status.Errorf(status.InvalidArgument, "failed to parse metric %s, value should be %d > N < %d",
|
||||||
operation.Values[0],
|
operation.Values[0],
|
||||||
route.MinMetric,
|
route.MinMetric,
|
||||||
route.MaxMetric,
|
route.MaxMetric,
|
||||||
@ -292,13 +291,13 @@ func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operatio
|
|||||||
case UpdateRouteMasquerade:
|
case UpdateRouteMasquerade:
|
||||||
masquerade, err := strconv.ParseBool(operation.Values[0])
|
masquerade, err := strconv.ParseBool(operation.Values[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "failed to parse masquerade %s, not boolean", operation.Values[0])
|
return nil, status.Errorf(status.InvalidArgument, "failed to parse masquerade %s, not boolean", operation.Values[0])
|
||||||
}
|
}
|
||||||
newRoute.Masquerade = masquerade
|
newRoute.Masquerade = masquerade
|
||||||
case UpdateRouteEnabled:
|
case UpdateRouteEnabled:
|
||||||
enabled, err := strconv.ParseBool(operation.Values[0])
|
enabled, err := strconv.ParseBool(operation.Values[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0])
|
return nil, status.Errorf(status.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0])
|
||||||
}
|
}
|
||||||
newRoute.Enabled = enabled
|
newRoute.Enabled = enabled
|
||||||
}
|
}
|
||||||
@ -313,7 +312,7 @@ func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operatio
|
|||||||
|
|
||||||
err = am.updateAccountPeers(account)
|
err = am.updateAccountPeers(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.Internal, "failed to update account peers")
|
return nil, status.Errorf(status.Internal, "failed to update account peers")
|
||||||
}
|
}
|
||||||
return newRoute, nil
|
return newRoute, nil
|
||||||
}
|
}
|
||||||
@ -325,7 +324,7 @@ func (am *DefaultAccountManager) DeleteRoute(accountID, routeID string) error {
|
|||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return status.Errorf(codes.NotFound, "account not found")
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(account.Routes, routeID)
|
delete(account.Routes, routeID)
|
||||||
@ -345,7 +344,7 @@ func (am *DefaultAccountManager) ListRoutes(accountID, userID string) ([]*route.
|
|||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
user, err := account.FindUser(userID)
|
||||||
@ -354,7 +353,7 @@ func (am *DefaultAccountManager) ListRoutes(accountID, userID string) ([]*route.
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !user.IsAdmin() {
|
if !user.IsAdmin() {
|
||||||
return nil, Errorf(PermissionDenied, "Only administrators can view Network Routes")
|
return nil, status.Errorf(status.PermissionDenied, "Only administrators can view Network Routes")
|
||||||
}
|
}
|
||||||
|
|
||||||
routes := make([]*route.Route, 0, len(account.Routes))
|
routes := make([]*route.Route, 0, len(account.Routes))
|
||||||
|
@ -1,8 +1,7 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"google.golang.org/grpc/codes"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -95,7 +94,7 @@ func (am *DefaultAccountManager) GetRule(accountID, ruleID, userID string) (*Rul
|
|||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
user, err := account.FindUser(userID)
|
||||||
@ -104,7 +103,7 @@ func (am *DefaultAccountManager) GetRule(accountID, ruleID, userID string) (*Rul
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !user.IsAdmin() {
|
if !user.IsAdmin() {
|
||||||
return nil, Errorf(PermissionDenied, "only admins are allowed to view rules")
|
return nil, status.Errorf(status.PermissionDenied, "only admins are allowed to view rules")
|
||||||
}
|
}
|
||||||
|
|
||||||
rule, ok := account.Rules[ruleID]
|
rule, ok := account.Rules[ruleID]
|
||||||
@ -112,7 +111,7 @@ func (am *DefaultAccountManager) GetRule(accountID, ruleID, userID string) (*Rul
|
|||||||
return rule, nil
|
return rule, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, status.Errorf(codes.NotFound, "rule with ID %s not found", ruleID)
|
return nil, status.Errorf(status.NotFound, "rule with ID %s not found", ruleID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveRule of ACL in the store
|
// SaveRule of ACL in the store
|
||||||
@ -122,7 +121,7 @@ func (am *DefaultAccountManager) SaveRule(accountID string, rule *Rule) error {
|
|||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return status.Errorf(codes.NotFound, "account not found")
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
account.Rules[rule.ID] = rule
|
account.Rules[rule.ID] = rule
|
||||||
@ -143,12 +142,12 @@ func (am *DefaultAccountManager) UpdateRule(accountID string, ruleID string,
|
|||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
ruleToUpdate, ok := account.Rules[ruleID]
|
ruleToUpdate, ok := account.Rules[ruleID]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, status.Errorf(codes.NotFound, "rule %s no longer exists", ruleID)
|
return nil, status.Errorf(status.NotFound, "rule %s no longer exists", ruleID)
|
||||||
}
|
}
|
||||||
|
|
||||||
rule := ruleToUpdate.Copy()
|
rule := ruleToUpdate.Copy()
|
||||||
@ -161,7 +160,7 @@ func (am *DefaultAccountManager) UpdateRule(accountID string, ruleID string,
|
|||||||
rule.Description = operation.Values[0]
|
rule.Description = operation.Values[0]
|
||||||
case UpdateRuleFlow:
|
case UpdateRuleFlow:
|
||||||
if operation.Values[0] != TrafficFlowBidirectString {
|
if operation.Values[0] != TrafficFlowBidirectString {
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "failed to parse flow")
|
return nil, status.Errorf(status.InvalidArgument, "failed to parse flow")
|
||||||
}
|
}
|
||||||
rule.Flow = TrafficFlowBidirect
|
rule.Flow = TrafficFlowBidirect
|
||||||
case UpdateRuleStatus:
|
case UpdateRuleStatus:
|
||||||
@ -170,7 +169,7 @@ func (am *DefaultAccountManager) UpdateRule(accountID string, ruleID string,
|
|||||||
} else if strings.ToLower(operation.Values[0]) == "false" {
|
} else if strings.ToLower(operation.Values[0]) == "false" {
|
||||||
rule.Disabled = false
|
rule.Disabled = false
|
||||||
} else {
|
} else {
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "failed to parse status")
|
return nil, status.Errorf(status.InvalidArgument, "failed to parse status")
|
||||||
}
|
}
|
||||||
case UpdateSourceGroups:
|
case UpdateSourceGroups:
|
||||||
rule.Source = operation.Values
|
rule.Source = operation.Values
|
||||||
@ -204,7 +203,7 @@ func (am *DefaultAccountManager) UpdateRule(accountID string, ruleID string,
|
|||||||
|
|
||||||
err = am.updateAccountPeers(account)
|
err = am.updateAccountPeers(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.Internal, "failed to update account peers")
|
return nil, status.Errorf(status.Internal, "failed to update account peers")
|
||||||
}
|
}
|
||||||
|
|
||||||
return rule, nil
|
return rule, nil
|
||||||
@ -217,7 +216,7 @@ func (am *DefaultAccountManager) DeleteRule(accountID, ruleID string) error {
|
|||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return status.Errorf(codes.NotFound, "account not found")
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(account.Rules, ruleID)
|
delete(account.Rules, ruleID)
|
||||||
@ -237,7 +236,7 @@ func (am *DefaultAccountManager) ListRules(accountID, userID string) ([]*Rule, e
|
|||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
user, err := account.FindUser(userID)
|
||||||
@ -246,7 +245,7 @@ func (am *DefaultAccountManager) ListRules(accountID, userID string) ([]*Rule, e
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !user.IsAdmin() {
|
if !user.IsAdmin() {
|
||||||
return nil, Errorf(PermissionDenied, "Only Administrators can view Access Rules")
|
return nil, status.Errorf(status.PermissionDenied, "Only Administrators can view Access Rules")
|
||||||
}
|
}
|
||||||
|
|
||||||
rules := make([]*Rule, 0, len(account.Rules))
|
rules := make([]*Rule, 0, len(account.Rules))
|
||||||
|
@ -1,10 +1,8 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"google.golang.org/grpc/codes"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
"hash/fnv"
|
"hash/fnv"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@ -183,12 +181,12 @@ func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string
|
|||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, group := range autoGroups {
|
for _, group := range autoGroups {
|
||||||
if _, ok := account.Groups[group]; !ok {
|
if _, ok := account.Groups[group]; !ok {
|
||||||
return nil, fmt.Errorf("group %s doesn't exist", group)
|
return nil, status.Errorf(status.NotFound, "group %s doesn't exist", group)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -197,7 +195,7 @@ func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string
|
|||||||
|
|
||||||
err = am.Store.SaveAccount(account)
|
err = am.Store.SaveAccount(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.Internal, "failed adding account key")
|
return nil, status.Errorf(status.Internal, "failed adding account key")
|
||||||
}
|
}
|
||||||
|
|
||||||
return setupKey, nil
|
return setupKey, nil
|
||||||
@ -212,12 +210,12 @@ func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *Setup
|
|||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
if keyToSave == nil {
|
if keyToSave == nil {
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "provided setup key to update is nil")
|
return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var oldKey *SetupKey
|
var oldKey *SetupKey
|
||||||
@ -228,7 +226,7 @@ func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *Setup
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if oldKey == nil {
|
if oldKey == nil {
|
||||||
return nil, status.Errorf(codes.NotFound, "setup key not found")
|
return nil, status.Errorf(status.NotFound, "setup key not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
// only auto groups, revoked status, and name can be updated for now
|
// only auto groups, revoked status, and name can be updated for now
|
||||||
@ -253,7 +251,7 @@ func (am *DefaultAccountManager) ListSetupKeys(accountID, userID string) ([]*Set
|
|||||||
defer unlock()
|
defer unlock()
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
user, err := account.FindUser(userID)
|
||||||
@ -282,7 +280,7 @@ func (am *DefaultAccountManager) GetSetupKey(accountID, userID, keyID string) (*
|
|||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
user, err := account.FindUser(userID)
|
||||||
@ -298,7 +296,7 @@ func (am *DefaultAccountManager) GetSetupKey(accountID, userID, keyID string) (*
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if foundKey == nil {
|
if foundKey == nil {
|
||||||
return nil, status.Errorf(codes.NotFound, "setup key not found")
|
return nil, status.Errorf(status.NotFound, "setup key not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
// the UpdatedAt field was introduced later, so there might be that some keys have a Zero value (e.g, null in the store file)
|
// the UpdatedAt field was introduced later, so there might be that some keys have a Zero value (e.g, null in the store file)
|
||||||
|
72
management/server/status/error.go
Normal file
72
management/server/status/error.go
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
package status
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// UserAlreadyExists indicates that user already exists
|
||||||
|
UserAlreadyExists Type = 1
|
||||||
|
|
||||||
|
// PreconditionFailed indicates that some pre-condition for the operation hasn't been fulfilled
|
||||||
|
PreconditionFailed Type = 2
|
||||||
|
|
||||||
|
// PermissionDenied indicates that user has no permissions to view data
|
||||||
|
PermissionDenied Type = 3
|
||||||
|
|
||||||
|
// NotFound indicates that the object wasn't found in the system (or under a given Account)
|
||||||
|
NotFound Type = 4
|
||||||
|
|
||||||
|
// Internal indicates some generic internal error
|
||||||
|
Internal Type = 5
|
||||||
|
|
||||||
|
// InvalidArgument indicates some generic invalid argument error
|
||||||
|
InvalidArgument Type = 6
|
||||||
|
|
||||||
|
// AlreadyExists indicates a generic error when an object already exists in the system
|
||||||
|
AlreadyExists Type = 7
|
||||||
|
|
||||||
|
// Unauthorized indicates that user is not authorized
|
||||||
|
Unauthorized Type = 8
|
||||||
|
|
||||||
|
// BadRequest indicates that user is not authorized
|
||||||
|
BadRequest Type = 9
|
||||||
|
)
|
||||||
|
|
||||||
|
// Type is a type of the Error
|
||||||
|
type Type int32
|
||||||
|
|
||||||
|
// Error is an internal error
|
||||||
|
type Error struct {
|
||||||
|
ErrorType Type
|
||||||
|
Message string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Type returns the Type of the error
|
||||||
|
func (e *Error) Type() Type {
|
||||||
|
return e.ErrorType
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error is an error string
|
||||||
|
func (e *Error) Error() string {
|
||||||
|
return e.Message
|
||||||
|
}
|
||||||
|
|
||||||
|
// Errorf returns Error(ErrorType, fmt.Sprintf(format, a...)).
|
||||||
|
func Errorf(errorType Type, format string, a ...interface{}) error {
|
||||||
|
return &Error{
|
||||||
|
ErrorType: errorType,
|
||||||
|
Message: fmt.Sprintf(format, a...),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FromError returns Error, true if the provided error is of type of Error. nil, false otherwise
|
||||||
|
func FromError(err error) (s *Error, ok bool) {
|
||||||
|
if err == nil {
|
||||||
|
return nil, true
|
||||||
|
}
|
||||||
|
if e, ok := err.(*Error); ok {
|
||||||
|
return e, true
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
@ -3,8 +3,7 @@ package server
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/netbirdio/netbird/management/server/idp"
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
"google.golang.org/grpc/codes"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
@ -123,7 +122,7 @@ func (am *DefaultAccountManager) CreateUser(accountID string, invite *UserInfo)
|
|||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
if am.idpManager == nil {
|
if am.idpManager == nil {
|
||||||
return nil, Errorf(PreconditionFailed, "IdP manager must be enabled to send user invites")
|
return nil, status.Errorf(status.PreconditionFailed, "IdP manager must be enabled to send user invites")
|
||||||
}
|
}
|
||||||
|
|
||||||
if invite == nil {
|
if invite == nil {
|
||||||
@ -132,7 +131,7 @@ func (am *DefaultAccountManager) CreateUser(accountID string, invite *UserInfo)
|
|||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, Errorf(AccountNotFound, "account %s doesn't exist", accountID)
|
return nil, status.Errorf(status.NotFound, "account %s doesn't exist", accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if the user is already registered with this email => reject
|
// check if the user is already registered with this email => reject
|
||||||
@ -142,7 +141,7 @@ func (am *DefaultAccountManager) CreateUser(accountID string, invite *UserInfo)
|
|||||||
}
|
}
|
||||||
|
|
||||||
if user != nil {
|
if user != nil {
|
||||||
return nil, Errorf(UserAlreadyExists, "user has an existing account")
|
return nil, status.Errorf(status.UserAlreadyExists, "user has an existing account")
|
||||||
}
|
}
|
||||||
|
|
||||||
users, err := am.idpManager.GetUserByEmail(invite.Email)
|
users, err := am.idpManager.GetUserByEmail(invite.Email)
|
||||||
@ -151,7 +150,7 @@ func (am *DefaultAccountManager) CreateUser(accountID string, invite *UserInfo)
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(users) > 0 {
|
if len(users) > 0 {
|
||||||
return nil, Errorf(UserAlreadyExists, "user has an existing account")
|
return nil, status.Errorf(status.UserAlreadyExists, "user has an existing account")
|
||||||
}
|
}
|
||||||
|
|
||||||
idpUser, err := am.idpManager.CreateUser(invite.Email, invite.Name, accountID)
|
idpUser, err := am.idpManager.CreateUser(invite.Email, invite.Name, accountID)
|
||||||
@ -188,25 +187,24 @@ func (am *DefaultAccountManager) SaveUser(accountID string, update *User) (*User
|
|||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
if update == nil {
|
if update == nil {
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "provided user update is nil")
|
return nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountID)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, newGroupID := range update.AutoGroups {
|
for _, newGroupID := range update.AutoGroups {
|
||||||
if _, ok := account.Groups[newGroupID]; !ok {
|
if _, ok := account.Groups[newGroupID]; !ok {
|
||||||
return nil,
|
return nil, status.Errorf(status.InvalidArgument, "provided group ID %s in the user %s update doesn't exist",
|
||||||
status.Errorf(codes.InvalidArgument, "provided group ID %s in the user %s update doesn't exist",
|
newGroupID, update.Id)
|
||||||
newGroupID, update.Id)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
oldUser := account.Users[update.Id]
|
oldUser := account.Users[update.Id]
|
||||||
if oldUser == nil {
|
if oldUser == nil {
|
||||||
return nil, status.Errorf(codes.NotFound, "update not found")
|
return nil, status.Errorf(status.NotFound, "update not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
// only auto groups, revoked status, and name can be updated for now
|
// only auto groups, revoked status, and name can be updated for now
|
||||||
@ -226,7 +224,7 @@ func (am *DefaultAccountManager) SaveUser(accountID string, update *User) (*User
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if userData == nil {
|
if userData == nil {
|
||||||
return nil, status.Errorf(codes.NotFound, "user %s not found in the IdP", newUser.Id)
|
return nil, status.Errorf(status.NotFound, "user %s not found in the IdP", newUser.Id)
|
||||||
}
|
}
|
||||||
return newUser.toUserInfo(userData)
|
return newUser.toUserInfo(userData)
|
||||||
}
|
}
|
||||||
@ -242,14 +240,14 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userID, domain string)
|
|||||||
|
|
||||||
account, err := am.Store.GetAccountByUser(userID)
|
account, err := am.Store.GetAccountByUser(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||||
account, err = am.newAccount(userID, lowerDomain)
|
account, err = am.newAccount(userID, lowerDomain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
err = am.Store.SaveAccount(account)
|
err = am.Store.SaveAccount(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.Internal, "failed creating account")
|
return nil, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// other error
|
// other error
|
||||||
@ -263,7 +261,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userID, domain string)
|
|||||||
account.Domain = lowerDomain
|
account.Domain = lowerDomain
|
||||||
err = am.Store.SaveAccount(account)
|
err = am.Store.SaveAccount(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.Internal, "failed updating account with domain")
|
return nil, status.Errorf(status.Internal, "failed updating account with domain")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -272,14 +270,14 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userID, domain string)
|
|||||||
|
|
||||||
// IsUserAdmin flag for current user authenticated by JWT token
|
// IsUserAdmin flag for current user authenticated by JWT token
|
||||||
func (am *DefaultAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) {
|
func (am *DefaultAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) {
|
||||||
account, err := am.GetAccountFromToken(claims)
|
account, _, err := am.GetAccountFromToken(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("get account: %v", err)
|
return false, fmt.Errorf("get account: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
user, ok := account.Users[claims.UserId]
|
user, ok := account.Users[claims.UserId]
|
||||||
if !ok {
|
if !ok {
|
||||||
return false, fmt.Errorf("no such user")
|
return false, status.Errorf(status.NotFound, "user not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
return user.Role == UserRoleAdmin, nil
|
return user.Role == UserRoleAdmin, nil
|
||||||
|
@ -1,8 +1,7 @@
|
|||||||
package route
|
package route
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"google.golang.org/grpc/codes"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -108,13 +107,13 @@ func (r *Route) IsEqual(other *Route) bool {
|
|||||||
func ParseNetwork(networkString string) (NetworkType, netip.Prefix, error) {
|
func ParseNetwork(networkString string) (NetworkType, netip.Prefix, error) {
|
||||||
prefix, err := netip.ParsePrefix(networkString)
|
prefix, err := netip.ParsePrefix(networkString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return InvalidNetwork, netip.Prefix{}, err
|
return InvalidNetwork, netip.Prefix{}, status.Errorf(status.InvalidArgument, "invalid network %s", networkString)
|
||||||
}
|
}
|
||||||
|
|
||||||
masked := prefix.Masked()
|
masked := prefix.Masked()
|
||||||
|
|
||||||
if !masked.IsValid() {
|
if !masked.IsValid() {
|
||||||
return InvalidNetwork, netip.Prefix{}, status.Errorf(codes.InvalidArgument, "invalid range %s", networkString)
|
return InvalidNetwork, netip.Prefix{}, status.Errorf(status.InvalidArgument, "invalid range %s", networkString)
|
||||||
}
|
}
|
||||||
|
|
||||||
if masked.Addr().Is6() {
|
if masked.Addr().Is6() {
|
||||||
|
Loading…
Reference in New Issue
Block a user