Feat peer groups (#304)

* feat(management): add groups

* squash

* feat(management): add handlers for groups

* feat(management): add handlers for groups

* chore(management): add tests for the get group of the management

* chore(management): add tests for save group
This commit is contained in:
Givi Khojanashvili 2022-05-03 18:02:51 +04:00 committed by GitHub
parent 70ffc9d625
commit 219888254e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 737 additions and 65 deletions

View File

@ -1,6 +1,9 @@
package server
import (
"strings"
"sync"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/util"
@ -8,8 +11,6 @@ import (
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"strings"
"sync"
)
const (
@ -21,7 +22,12 @@ const (
type AccountManager interface {
GetOrCreateAccountByUser(userId, domain string) (*Account, error)
GetAccountByUser(userId string) (*Account, error)
AddSetupKey(accountId string, keyName string, keyType SetupKeyType, expiresIn *util.Duration) (*SetupKey, error)
AddSetupKey(
accountId string,
keyName string,
keyType SetupKeyType,
expiresIn *util.Duration,
) (*SetupKey, error)
RevokeSetupKey(accountId string, keyId string) (*SetupKey, error)
RenameSetupKey(accountId string, keyId string, newName string) (*SetupKey, error)
GetAccountById(accountId string) (*Account, error)
@ -36,6 +42,13 @@ type AccountManager interface {
GetPeerByIP(accountId string, peerIP string) (*Peer, error)
GetNetworkMap(peerKey string) (*NetworkMap, error)
AddPeer(setupKey string, peer *Peer) (*Peer, error)
GetGroup(accountId, groupID string) (*Group, error)
SaveGroup(accountId string, group *Group) error
DeleteGroup(accountId, groupID string) error
ListGroups(accountId string) ([]*Group, error)
GroupAddPeer(accountId, groupID, peerKey string) error
GroupDeletePeer(accountId, groupID, peerKey string) error
GroupListPeers(accountId, groupID string) ([]*Peer, error)
}
type DefaultAccountManager struct {
@ -58,6 +71,7 @@ type Account struct {
Network *Network
Peers map[string]*Peer
Users map[string]*User
Groups map[string]*Group
}
// NewAccount creates a new Account with a generated ID and generated default setup keys
@ -93,7 +107,11 @@ func (a *Account) Copy() *Account {
}
// NewManager creates a new DefaultAccountManager with a provided Store
func NewManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager) *DefaultAccountManager {
func NewManager(
store Store,
peersUpdateManager *PeersUpdateManager,
idpManager idp.Manager,
) *DefaultAccountManager {
return &DefaultAccountManager{
Store: store,
mux: sync.Mutex{},
@ -102,8 +120,13 @@ func NewManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager
}
}
//AddSetupKey generates a new setup key with a given name and type, and adds it to the specified account
func (am *DefaultAccountManager) AddSetupKey(accountId string, keyName string, keyType SetupKeyType, expiresIn *util.Duration) (*SetupKey, error) {
// AddSetupKey generates a new setup key with a given name and type, and adds it to the specified account
func (am *DefaultAccountManager) AddSetupKey(
accountId string,
keyName string,
keyType SetupKeyType,
expiresIn *util.Duration,
) (*SetupKey, error) {
am.mux.Lock()
defer am.mux.Unlock()
@ -128,7 +151,7 @@ func (am *DefaultAccountManager) AddSetupKey(accountId string, keyName string, k
return setupKey, nil
}
//RevokeSetupKey marks SetupKey as revoked - becomes not valid anymore
// RevokeSetupKey marks SetupKey as revoked - becomes not valid anymore
func (am *DefaultAccountManager) RevokeSetupKey(accountId string, keyId string) (*SetupKey, error) {
am.mux.Lock()
defer am.mux.Unlock()
@ -154,8 +177,12 @@ func (am *DefaultAccountManager) RevokeSetupKey(accountId string, keyId string)
return keyCopy, nil
}
//RenameSetupKey renames existing setup key of the specified account.
func (am *DefaultAccountManager) RenameSetupKey(accountId string, keyId string, newName string) (*SetupKey, error) {
// RenameSetupKey renames existing setup key of the specified account.
func (am *DefaultAccountManager) RenameSetupKey(
accountId string,
keyId string,
newName string,
) (*SetupKey, error) {
am.mux.Lock()
defer am.mux.Unlock()
@ -180,7 +207,7 @@ func (am *DefaultAccountManager) RenameSetupKey(accountId string, keyId string,
return keyCopy, nil
}
//GetAccountById returns an existing account using its ID or error (NotFound) if doesn't exist
// GetAccountById returns an existing account using its ID or error (NotFound) if doesn't exist
func (am *DefaultAccountManager) GetAccountById(accountId string) (*Account, error) {
am.mux.Lock()
defer am.mux.Unlock()
@ -193,10 +220,11 @@ func (am *DefaultAccountManager) GetAccountById(accountId string) (*Account, err
return account, nil
}
//GetAccountByUserOrAccountId look for an account by user or account Id, if no account is provided and
// GetAccountByUserOrAccountId look for an account by user or account Id, if no account is provided and
// user id doesn't have an account associated with it, one account is created
func (am *DefaultAccountManager) GetAccountByUserOrAccountId(userId, accountId, domain string) (*Account, error) {
func (am *DefaultAccountManager) GetAccountByUserOrAccountId(
userId, accountId, domain string,
) (*Account, error) {
if accountId != "" {
return am.GetAccountById(accountId)
} else if userId != "" {
@ -219,14 +247,22 @@ func (am *DefaultAccountManager) updateIDPMetadata(userId, accountID string) err
if am.idpManager != nil {
err := am.idpManager.UpdateUserAppMetadata(userId, idp.AppMetadata{WTAccountId: accountID})
if err != nil {
return status.Errorf(codes.Internal, "updating user's app metadata failed with: %v", err)
return status.Errorf(
codes.Internal,
"updating user's app metadata failed with: %v",
err,
)
}
}
return nil
}
// updateAccountDomainAttributes updates the account domain attributes and then, saves the account
func (am *DefaultAccountManager) updateAccountDomainAttributes(account *Account, claims jwtclaims.AuthorizationClaims, primaryDomain bool) error {
func (am *DefaultAccountManager) updateAccountDomainAttributes(
account *Account,
claims jwtclaims.AuthorizationClaims,
primaryDomain bool,
) error {
account.IsDomainPrimaryAccount = primaryDomain
account.Domain = strings.ToLower(claims.Domain)
account.DomainCategory = claims.DomainCategory
@ -245,7 +281,11 @@ func (am *DefaultAccountManager) updateAccountDomainAttributes(account *Account,
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
// was previously unclassified or classified as public so N users that logged int that time, has they own account
// and peers that shouldn't be lost.
func (am *DefaultAccountManager) handleExistingUserAccount(existingAcc *Account, domainAcc *Account, claims jwtclaims.AuthorizationClaims) error {
func (am *DefaultAccountManager) handleExistingUserAccount(
existingAcc *Account,
domainAcc *Account,
claims jwtclaims.AuthorizationClaims,
) error {
var err error
if domainAcc != nil && existingAcc.Id != domainAcc.Id {
@ -271,7 +311,10 @@ func (am *DefaultAccountManager) handleExistingUserAccount(existingAcc *Account,
// handleNewUserAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account,
// otherwise it will create a new account and make it primary account for the domain.
func (am *DefaultAccountManager) handleNewUserAccount(domainAcc *Account, claims jwtclaims.AuthorizationClaims) (*Account, error) {
func (am *DefaultAccountManager) handleNewUserAccount(
domainAcc *Account,
claims jwtclaims.AuthorizationClaims,
) (*Account, error) {
var (
account *Account
err error
@ -315,7 +358,9 @@ func (am *DefaultAccountManager) handleNewUserAccount(domainAcc *Account, claims
// Existing user + Existing account + Existing Indexed Domain -> Nothing changes
//
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
func (am *DefaultAccountManager) GetAccountWithAuthorizationClaims(claims jwtclaims.AuthorizationClaims) (*Account, error) {
func (am *DefaultAccountManager) GetAccountWithAuthorizationClaims(
claims jwtclaims.AuthorizationClaims,
) (*Account, error) {
// if Account ID is part of the claims
// it means that we've already classified the domain and user has an account
if claims.DomainCategory != PrivateCategory {
@ -355,7 +400,7 @@ func (am *DefaultAccountManager) GetAccountWithAuthorizationClaims(claims jwtcla
}
}
//AccountExists checks whether account exists (returns true) or not (returns false)
// AccountExists checks whether account exists (returns true) or not (returns false)
func (am *DefaultAccountManager) AccountExists(accountId string) (*bool, error) {
am.mux.Lock()
defer am.mux.Unlock()
@ -377,12 +422,10 @@ func (am *DefaultAccountManager) AccountExists(accountId string) (*bool, error)
// AddAccount generates a new Account with a provided accountId and userId, saves to the Store
func (am *DefaultAccountManager) AddAccount(accountId, userId, domain string) (*Account, error) {
am.mux.Lock()
defer am.mux.Unlock()
return am.createAccount(accountId, userId, domain)
}
func (am *DefaultAccountManager) createAccount(accountId, userId, domain string) (*Account, error) {
@ -398,7 +441,6 @@ func (am *DefaultAccountManager) createAccount(accountId, userId, domain string)
// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id
func newAccountWithId(accountId, userId, domain string) *Account {
log.Debugf("creating new account")
setupKeys := make(map[string]*SetupKey)

View File

@ -28,8 +28,7 @@ type FileStore struct {
storeFile string `json:"-"`
}
type StoredAccount struct {
}
type StoredAccount struct{}
// NewStore restores a store from the file located in the datadir
func NewStore(dataDir string) (*FileStore, error) {
@ -39,7 +38,6 @@ func NewStore(dataDir string) (*FileStore, error) {
// restore restores the state of the store from the file.
// Creates a new empty store file if doesn't exist
func restore(file string) (*FileStore, error) {
if _, err := os.Stat(file); os.IsNotExist(err) {
// create a new FileStore if previously didn't exist (e.g. first run)
s := &FileStore{
@ -109,12 +107,7 @@ func (s *FileStore) SavePeer(accountId string, peer *Peer) error {
}
account.Peers[peer.Key] = peer
err = s.persist(s.storeFile)
if err != nil {
return err
}
return nil
return s.persist(s.storeFile)
}
// DeletePeer deletes peer from the Store
@ -140,7 +133,7 @@ func (s *FileStore) DeletePeer(accountId string, peerKey string) (*Peer, error)
return nil, err
}
return peer, err
return peer, nil
}
// GetPeer returns a peer from a Store
@ -191,16 +184,10 @@ func (s *FileStore) SaveAccount(account *Account) error {
s.PrivateDomain2AccountId[account.Domain] = account.Id
}
err := s.persist(s.storeFile)
if err != nil {
return err
}
return nil
return s.persist(s.storeFile)
}
func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) {
accountId, accountIdFound := s.PrivateDomain2AccountId[strings.ToLower(domain)]
if !accountIdFound {
return nil, status.Errorf(codes.NotFound, "provided domain is not registered or is not private")
@ -215,7 +202,6 @@ func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) {
}
func (s *FileStore) GetAccountBySetupKey(setupKey string) (*Account, error) {
accountId, accountIdFound := s.SetupKeyId2AccountId[strings.ToUpper(setupKey)]
if !accountIdFound {
return nil, status.Errorf(codes.NotFound, "provided setup key doesn't exists")
@ -228,6 +214,7 @@ func (s *FileStore) GetAccountBySetupKey(setupKey string) (*Account, error) {
return account, nil
}
func (s *FileStore) GetAccountPeers(accountId string) ([]*Peer, error) {
s.mux.Lock()
defer s.mux.Unlock()
@ -246,7 +233,6 @@ func (s *FileStore) GetAccountPeers(accountId string) ([]*Peer, error) {
}
func (s *FileStore) GetAccount(accountId string) (*Account, error) {
account, accountFound := s.Accounts[accountId]
if !accountFound {
return nil, status.Errorf(codes.NotFound, "account not found")
@ -278,3 +264,19 @@ func (s *FileStore) GetPeerAccount(peerKey string) (*Account, error) {
return s.GetAccount(accountId)
}
func (s *FileStore) GetGroup(groupID string) (*Group, error) {
return nil, nil
}
func (s *FileStore) SaveGroup(group *Group) error {
return nil
}
func (s *FileStore) DeleteGroup(groupID string) error {
return nil
}
func (s *FileStore) ListGroups() ([]*Group, error) {
return nil, nil
}

163
management/server/group.go Normal file
View File

@ -0,0 +1,163 @@
package server
import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// Group of the peers for ACL
type Group struct {
// ID of the group
ID string
// Name visible in the UI
Name string
// Peers list of the group
Peers []string
}
// GetGroup object of the peers
func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
group, ok := account.Groups[groupID]
if ok {
return group, nil
}
return nil, status.Errorf(codes.NotFound, "group with ID %s not found", groupID)
}
// SaveGroup object of the peers
func (am *DefaultAccountManager) SaveGroup(accountID string, group *Group) error {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return status.Errorf(codes.NotFound, "account not found")
}
account.Groups[group.ID] = group
return am.Store.SaveAccount(account)
}
// DeleteGroup object of the peers
func (am *DefaultAccountManager) DeleteGroup(accountID, groupID string) error {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return status.Errorf(codes.NotFound, "account not found")
}
delete(account.Groups, groupID)
return am.Store.SaveAccount(account)
}
// ListGroups objects of the peers
func (am *DefaultAccountManager) ListGroups(accountID string) ([]*Group, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
groups := make([]*Group, 0, len(account.Groups))
for _, item := range account.Groups {
groups = append(groups, item)
}
return groups, nil
}
// GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerKey string) error {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return status.Errorf(codes.NotFound, "account not found")
}
group, ok := account.Groups[groupID]
if !ok {
return status.Errorf(codes.NotFound, "group with ID %s not found", groupID)
}
add := true
for _, itemID := range group.Peers {
if itemID == peerKey {
add = false
break
}
}
if add {
group.Peers = append(group.Peers, peerKey)
}
return am.Store.SaveAccount(account)
}
// GroupDeletePeer removes peer from the group
func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerKey string) error {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return status.Errorf(codes.NotFound, "account not found")
}
group, ok := account.Groups[groupID]
if !ok {
return status.Errorf(codes.NotFound, "group with ID %s not found", groupID)
}
for i, itemID := range group.Peers {
if itemID == peerKey {
group.Peers = append(group.Peers[:i], group.Peers[i+1:]...)
return am.Store.SaveAccount(account)
}
}
return nil
}
// GroupListPeers returns list of the peers from the group
func (am *DefaultAccountManager) GroupListPeers(accountID, groupID string) ([]*Peer, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
group, ok := account.Groups[groupID]
if !ok {
return nil, status.Errorf(codes.NotFound, "group with ID %s not found", groupID)
}
peers := make([]*Peer, 0, len(account.Groups))
for _, peerID := range group.Peers {
p, ok := account.Peers[peerID]
if ok {
peers = append(peers, p)
}
}
return peers, nil
}

View File

@ -0,0 +1,135 @@
package handler
import (
"encoding/json"
"fmt"
"net/http"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/rs/xid"
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
)
// Groups is a handler that returns groups of the account
type Groups struct {
accountManager server.AccountManager
authAudience string
jwtExtractor jwtclaims.ClaimsExtractor
}
// GroupResponse is a response sent to the client
type GroupResponse struct {
ID string
Name string
Peers []string
}
func NewGroups(accountManager server.AccountManager, authAudience string) *Groups {
return &Groups{
accountManager: accountManager,
authAudience: authAudience,
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil),
}
}
// GetAllGroupsHandler list for the account
func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) {
account, err := h.getGroupAccount(r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
writeJSONObject(w, account.Groups)
}
func (h *Groups) CreateOrUpdateGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := h.getGroupAccount(r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
var req server.Group
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if r.Method == http.MethodPost {
req.ID = xid.New().String()
}
if err := h.accountManager.SaveGroup(account.Id, &req); err != nil {
log.Errorf("failed updating group %s under account %s %v", req.ID, account.Id, err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
writeJSONObject(w, &req)
}
func (h *Groups) DeleteGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := h.getGroupAccount(r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
aID := account.Id
gID := mux.Vars(r)["id"]
if len(gID) == 0 {
http.Error(w, "invalid group ID", http.StatusBadRequest)
return
}
if err := h.accountManager.DeleteGroup(aID, gID); err != nil {
log.Errorf("failed delete group %s under account %s %v", gID, aID, err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
writeJSONObject(w, "")
}
func (h *Groups) GetGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := h.getGroupAccount(r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
switch r.Method {
case http.MethodGet:
groupID := mux.Vars(r)["id"]
if len(groupID) == 0 {
http.Error(w, "invalid group ID", http.StatusBadRequest)
return
}
group, err := h.accountManager.GetGroup(account.Id, groupID)
if err != nil {
http.Error(w, "group not found", http.StatusNotFound)
return
}
writeJSONObject(w, group)
default:
http.Error(w, "", http.StatusNotFound)
}
}
func (h *Groups) getGroupAccount(r *http.Request) (*server.Account, error) {
jwtClaims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, err := h.accountManager.GetAccountWithAuthorizationClaims(jwtClaims)
if err != nil {
return nil, fmt.Errorf("failed getting account of a user %s: %v", jwtClaims.UserId, err)
}
return account, nil
}

View File

@ -0,0 +1,202 @@
package handler
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/magiconair/properties/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/mock_server"
)
func initGroupTestData(groups ...*server.Group) *Groups {
return &Groups{
accountManager: &mock_server.MockAccountManager{
SaveGroupFunc: func(accountID string, group *server.Group) error {
if !strings.HasPrefix(group.ID, "id-") {
group.ID = "id-was-set"
}
return nil
},
GetGroupFunc: func(_, groupID string) (*server.Group, error) {
if groupID != "idofthegroup" {
return nil, fmt.Errorf("not found")
}
return &server.Group{
ID: "idofthegroup",
Name: "Group",
}, nil
},
GetAccountWithAuthorizationClaimsFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
return &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",
}, nil
},
},
authAudience: "",
jwtExtractor: jwtclaims.ClaimsExtractor{
ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
}
},
},
}
}
func TestGetGroup(t *testing.T) {
tt := []struct {
name string
expectedStatus int
expectedBody bool
requestType string
requestPath string
requestBody io.Reader
}{
{
name: "GetGroup OK",
expectedBody: true,
requestType: http.MethodGet,
requestPath: "/api/groups/idofthegroup",
expectedStatus: http.StatusOK,
},
{
name: "GetGroup not found",
requestType: http.MethodGet,
requestPath: "/api/groups/notexists",
expectedStatus: http.StatusNotFound,
},
}
group := &server.Group{
ID: "idofthegroup",
Name: "Group",
}
p := initGroupTestData(group)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter()
router.HandleFunc("/api/groups/{id}", p.GetGroupHandler).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
if status := recorder.Code; status != tc.expectedStatus {
t.Errorf("handler returned wrong status code: got %v want %v",
status, tc.expectedStatus)
return
}
if !tc.expectedBody {
return
}
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("I don't know what I expected; %v", err)
}
got := &server.Group{}
if err = json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, got.ID, group.ID)
assert.Equal(t, got.Name, group.Name)
})
}
}
func TestSaveGroup(t *testing.T) {
tt := []struct {
name string
expectedStatus int
expectedBody bool
expectedGroup *server.Group
requestType string
requestPath string
requestBody io.Reader
}{
{
name: "SaveGroup POST OK",
requestType: http.MethodPost,
requestPath: "/api/groups",
requestBody: bytes.NewBuffer(
[]byte(`{"Name":"Default POSTed Group"}`)),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedGroup: &server.Group{
ID: "id-was-set",
Name: "Default POSTed Group",
},
},
{
name: "SaveGroup PUT OK",
requestType: http.MethodPut,
requestPath: "/api/groups",
requestBody: bytes.NewBuffer(
[]byte(`{"ID":"id-existed","Name":"Default POSTed Group"}`)),
expectedStatus: http.StatusOK,
expectedGroup: &server.Group{
ID: "id-existed",
Name: "Default POSTed Group",
},
},
}
p := initGroupTestData()
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter()
router.HandleFunc("/api/groups", p.CreateOrUpdateGroupHandler).Methods("PUT", "POST")
router.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("I don't know what I expected; %v", err)
}
if status := recorder.Code; status != tc.expectedStatus {
t.Errorf("handler returned wrong status code: got %v want %v, content: %s",
status, tc.expectedStatus, string(content))
return
}
if !tc.expectedBody {
return
}
got := &server.Group{}
if err = json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, got, tc.expectedGroup)
})
}
}

View File

@ -2,13 +2,14 @@ package handler
import (
"encoding/json"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"io"
"net"
"net/http"
"net/http/httptest"
"testing"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/magiconair/properties/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/mock_server"
@ -43,14 +44,19 @@ func initTestMetaData(peer ...*server.Peer) *Peers {
// Tests the GetPeers endpoint reachable in the route /api/peers
// Use the metadata generated by initTestMetaData() to check for values
func TestGetPeers(t *testing.T) {
var tt = []struct {
tt := []struct {
name string
expectedStatus int
requestType string
requestPath string
requestBody io.Reader
}{
{name: "GetPeersMetaData", requestType: http.MethodGet, requestPath: "/api/peers/", expectedStatus: http.StatusOK},
{
name: "GetPeersMetaData",
requestType: http.MethodGet,
requestPath: "/api/peers/",
expectedStatus: http.StatusOK,
},
}
rr := httptest.NewRecorder()

View File

@ -25,26 +25,44 @@ type Server struct {
// NewHttpsServer creates a new HTTPs server (with HTTPS support) and a certManager that is responsible for generating and renewing Let's Encrypt certificate
// The listening address will be :443 no matter what was specified in s.HttpServerConfig.Address
func NewHttpsServer(config *s.HttpServerConfig, certManager *autocert.Manager, accountManager s.AccountManager) *Server {
func NewHttpsServer(
config *s.HttpServerConfig,
certManager *autocert.Manager,
accountManager s.AccountManager,
) *Server {
server := &http.Server{
Addr: config.Address,
WriteTimeout: time.Second * 15,
ReadTimeout: time.Second * 15,
IdleTimeout: time.Second * 60,
}
return &Server{server: server, config: config, certManager: certManager, accountManager: accountManager}
return &Server{
server: server,
config: config,
certManager: certManager,
accountManager: accountManager,
}
}
// NewHttpsServerWithTLSConfig creates a new HTTPs server with a provided tls.Config.
// Usually used when you already have a certificate
func NewHttpsServerWithTLSConfig(config *s.HttpServerConfig, tlsConfig *tls.Config, accountManager s.AccountManager) *Server {
func NewHttpsServerWithTLSConfig(
config *s.HttpServerConfig,
tlsConfig *tls.Config,
accountManager s.AccountManager,
) *Server {
server := &http.Server{
Addr: config.Address,
WriteTimeout: time.Second * 15,
ReadTimeout: time.Second * 15,
IdleTimeout: time.Second * 60,
}
return &Server{server: server, config: config, tlsConfig: tlsConfig, accountManager: accountManager}
return &Server{
server: server,
config: config,
tlsConfig: tlsConfig,
accountManager: accountManager,
}
}
// NewHttpServer creates a new HTTP server (without HTTPS)
@ -63,8 +81,11 @@ func (s *Server) Stop(ctx context.Context) error {
// Start defines http handlers and starts the http server. Blocks until server is shutdown.
func (s *Server) Start() error {
jwtMiddleware, err := middleware.NewJwtMiddleware(s.config.AuthIssuer, s.config.AuthAudience, s.config.AuthKeysLocation)
jwtMiddleware, err := middleware.NewJwtMiddleware(
s.config.AuthIssuer,
s.config.AuthAudience,
s.config.AuthKeysLocation,
)
if err != nil {
return err
}
@ -74,19 +95,31 @@ func (s *Server) Start() error {
r := mux.NewRouter()
r.Use(jwtMiddleware.Handler, corsMiddleware.Handler)
groupsHandler := handler.NewGroups(s.accountManager, s.config.AuthAudience)
peersHandler := handler.NewPeers(s.accountManager, s.config.AuthAudience)
keysHandler := handler.NewSetupKeysHandler(s.accountManager, s.config.AuthAudience)
r.HandleFunc("/api/peers", peersHandler.GetPeers).Methods("GET", "OPTIONS")
r.HandleFunc("/api/peers/{id}", peersHandler.HandlePeer).Methods("GET", "PUT", "DELETE", "OPTIONS")
r.HandleFunc("/api/peers/{id}", peersHandler.HandlePeer).
Methods("GET", "PUT", "DELETE", "OPTIONS")
r.HandleFunc("/api/setup-keys", keysHandler.GetKeys).Methods("GET", "POST", "OPTIONS")
r.HandleFunc("/api/setup-keys/{id}", keysHandler.HandleKey).Methods("GET", "PUT", "OPTIONS")
r.HandleFunc("/api/setup-keys", keysHandler.GetKeys).Methods("POST", "OPTIONS")
r.HandleFunc("/api/setup-keys/{id}", keysHandler.HandleKey).
Methods("GET", "PUT", "DELETE", "OPTIONS")
r.HandleFunc("/api/groups", groupsHandler.GetAllGroupsHandler).Methods("GET", "OPTIONS")
r.HandleFunc("/api/groups", groupsHandler.CreateOrUpdateGroupHandler).
Methods("POST", "PUT", "OPTIONS")
r.HandleFunc("/api/groups/{id}", groupsHandler.GetGroupHandler).Methods("GET", "OPTIONS")
r.HandleFunc("/api/groups/{id}", groupsHandler.DeleteGroupHandler).Methods("DELETE", "OPTIONS")
http.Handle("/", r)
if s.certManager != nil {
// if HTTPS is enabled we reuse the listener from the cert manager
listener := s.certManager.Listener()
log.Infof("HTTPs server listening on %s with Let's Encrypt autocert configured", listener.Addr())
log.Infof(
"HTTPs server listening on %s with Let's Encrypt autocert configured",
listener.Addr(),
)
if err = http.Serve(listener, s.certManager.HTTPHandler(r)); err != nil {
log.Errorf("failed to serve https server: %v", err)
return err

View File

@ -26,13 +26,25 @@ type MockAccountManager struct {
GetPeerByIPFunc func(accountId string, peerIP string) (*server.Peer, error)
GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error)
AddPeerFunc func(setupKey string, peer *server.Peer) (*server.Peer, error)
GetGroupFunc func(accountID, groupID string) (*server.Group, error)
SaveGroupFunc func(accountID string, group *server.Group) error
DeleteGroupFunc func(accountID, groupID string) error
ListGroupsFunc func(accountID string) ([]*server.Group, error)
GroupAddPeerFunc func(accountID, groupID, peerKey string) error
GroupDeletePeerFunc func(accountID, groupID, peerKey string) error
GroupListPeersFunc func(accountID, groupID string) ([]*server.Peer, error)
}
func (am *MockAccountManager) GetOrCreateAccountByUser(userId, domain string) (*server.Account, error) {
func (am *MockAccountManager) GetOrCreateAccountByUser(
userId, domain string,
) (*server.Account, error) {
if am.GetOrCreateAccountByUserFunc != nil {
return am.GetOrCreateAccountByUserFunc(userId, domain)
}
return nil, status.Errorf(codes.Unimplemented, "method GetOrCreateAccountByUser not implemented")
return nil, status.Errorf(
codes.Unimplemented,
"method GetOrCreateAccountByUser not implemented",
)
}
func (am *MockAccountManager) GetAccountByUser(userId string) (*server.Account, error) {
@ -42,21 +54,33 @@ func (am *MockAccountManager) GetAccountByUser(userId string) (*server.Account,
return nil, status.Errorf(codes.Unimplemented, "method GetAccountByUser not implemented")
}
func (am *MockAccountManager) AddSetupKey(accountId string, keyName string, keyType server.SetupKeyType, expiresIn *util.Duration) (*server.SetupKey, error) {
func (am *MockAccountManager) AddSetupKey(
accountId string,
keyName string,
keyType server.SetupKeyType,
expiresIn *util.Duration,
) (*server.SetupKey, error) {
if am.AddSetupKeyFunc != nil {
return am.AddSetupKeyFunc(accountId, keyName, keyType, expiresIn)
}
return nil, status.Errorf(codes.Unimplemented, "method AddSetupKey not implemented")
}
func (am *MockAccountManager) RevokeSetupKey(accountId string, keyId string) (*server.SetupKey, error) {
func (am *MockAccountManager) RevokeSetupKey(
accountId string,
keyId string,
) (*server.SetupKey, error) {
if am.RevokeSetupKeyFunc != nil {
return am.RevokeSetupKeyFunc(accountId, keyId)
}
return nil, status.Errorf(codes.Unimplemented, "method RevokeSetupKey not implemented")
}
func (am *MockAccountManager) RenameSetupKey(accountId string, keyId string, newName string) (*server.SetupKey, error) {
func (am *MockAccountManager) RenameSetupKey(
accountId string,
keyId string,
newName string,
) (*server.SetupKey, error) {
if am.RenameSetupKeyFunc != nil {
return am.RenameSetupKeyFunc(accountId, keyId, newName)
}
@ -70,18 +94,28 @@ func (am *MockAccountManager) GetAccountById(accountId string) (*server.Account,
return nil, status.Errorf(codes.Unimplemented, "method GetAccountById not implemented")
}
func (am *MockAccountManager) GetAccountByUserOrAccountId(userId, accountId, domain string) (*server.Account, error) {
func (am *MockAccountManager) GetAccountByUserOrAccountId(
userId, accountId, domain string,
) (*server.Account, error) {
if am.GetAccountByUserOrAccountIdFunc != nil {
return am.GetAccountByUserOrAccountIdFunc(userId, accountId, domain)
}
return nil, status.Errorf(codes.Unimplemented, "method GetAccountByUserOrAccountId not implemented")
return nil, status.Errorf(
codes.Unimplemented,
"method GetAccountByUserOrAccountId not implemented",
)
}
func (am *MockAccountManager) GetAccountWithAuthorizationClaims(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
func (am *MockAccountManager) GetAccountWithAuthorizationClaims(
claims jwtclaims.AuthorizationClaims,
) (*server.Account, error) {
if am.GetAccountWithAuthorizationClaimsFunc != nil {
return am.GetAccountWithAuthorizationClaimsFunc(claims)
}
return nil, status.Errorf(codes.Unimplemented, "method GetAccountWithAuthorizationClaims not implemented")
return nil, status.Errorf(
codes.Unimplemented,
"method GetAccountWithAuthorizationClaims not implemented",
)
}
func (am *MockAccountManager) AccountExists(accountId string) (*bool, error) {
@ -91,7 +125,9 @@ func (am *MockAccountManager) AccountExists(accountId string) (*bool, error) {
return nil, status.Errorf(codes.Unimplemented, "method AccountExists not implemented")
}
func (am *MockAccountManager) AddAccount(accountId, userId, domain string) (*server.Account, error) {
func (am *MockAccountManager) AddAccount(
accountId, userId, domain string,
) (*server.Account, error) {
if am.AddAccountFunc != nil {
return am.AddAccountFunc(accountId, userId, domain)
}
@ -112,7 +148,11 @@ func (am *MockAccountManager) MarkPeerConnected(peerKey string, connected bool)
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected not implemented")
}
func (am *MockAccountManager) RenamePeer(accountId string, peerKey string, newName string) (*server.Peer, error) {
func (am *MockAccountManager) RenamePeer(
accountId string,
peerKey string,
newName string,
) (*server.Peer, error) {
if am.RenamePeerFunc != nil {
return am.RenamePeerFunc(accountId, peerKey, newName)
}
@ -146,3 +186,52 @@ func (am *MockAccountManager) AddPeer(setupKey string, peer *server.Peer) (*serv
}
return nil, status.Errorf(codes.Unimplemented, "method AddPeer not implemented")
}
func (am *MockAccountManager) GetGroup(accountID, groupID string) (*server.Group, error) {
if am.GetGroupFunc != nil {
return am.GetGroupFunc(accountID, groupID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetGroup not implemented")
}
func (am *MockAccountManager) SaveGroup(accountID string, group *server.Group) error {
if am.SaveGroupFunc != nil {
return am.SaveGroupFunc(accountID, group)
}
return status.Errorf(codes.Unimplemented, "method UpdateGroup not implemented")
}
func (am *MockAccountManager) DeleteGroup(accountID, groupID string) error {
if am.DeleteGroupFunc != nil {
return am.DeleteGroupFunc(accountID, groupID)
}
return status.Errorf(codes.Unimplemented, "method DeleteGroup not implemented")
}
func (am *MockAccountManager) ListGroups(accountID string) ([]*server.Group, error) {
if am.ListGroupsFunc != nil {
return am.ListGroupsFunc(accountID)
}
return nil, status.Errorf(codes.Unimplemented, "method ListGroups not implemented")
}
func (am *MockAccountManager) GroupAddPeer(accountID, groupID, peerKey string) error {
if am.GroupAddPeerFunc != nil {
return am.GroupAddPeerFunc(accountID, groupID, peerKey)
}
return status.Errorf(codes.Unimplemented, "method GroupAddPeer not implemented")
}
func (am *MockAccountManager) GroupDeletePeer(accountID, groupID, peerKey string) error {
if am.GroupDeletePeerFunc != nil {
return am.GroupDeletePeerFunc(accountID, groupID, peerKey)
}
return status.Errorf(codes.Unimplemented, "method GroupDeletePeer not implemented")
}
func (am *MockAccountManager) GroupListPeers(accountID, groupID string) ([]*server.Peer, error) {
if am.GroupListPeersFunc != nil {
return am.GroupListPeersFunc(accountID, groupID)
}
return nil, status.Errorf(codes.Unimplemented, "method GroupListPeers not implemented")
}