From 219888254e398746746a08acf224689bb80b46a6 Mon Sep 17 00:00:00 2001 From: Givi Khojanashvili Date: Tue, 3 May 2022 18:02:51 +0400 Subject: [PATCH] Feat peer groups (#304) * feat(management): add groups * squash * feat(management): add handlers for groups * feat(management): add handlers for groups * chore(management): add tests for the get group of the management * chore(management): add tests for save group --- management/server/account.go | 86 ++++++-- management/server/file_store.go | 40 ++-- management/server/group.go | 163 ++++++++++++++ management/server/http/handler/groups.go | 135 ++++++++++++ management/server/http/handler/groups_test.go | 202 ++++++++++++++++++ management/server/http/handler/peers_test.go | 12 +- management/server/http/server.go | 53 ++++- management/server/mock_server/account_mock.go | 111 +++++++++- 8 files changed, 737 insertions(+), 65 deletions(-) create mode 100644 management/server/group.go create mode 100644 management/server/http/handler/groups.go create mode 100644 management/server/http/handler/groups_test.go diff --git a/management/server/account.go b/management/server/account.go index eebcc51c9..3c58a7340 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1,6 +1,9 @@ package server import ( + "strings" + "sync" + "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/util" @@ -8,8 +11,6 @@ import ( log "github.com/sirupsen/logrus" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - "strings" - "sync" ) const ( @@ -21,7 +22,12 @@ const ( type AccountManager interface { GetOrCreateAccountByUser(userId, domain string) (*Account, error) GetAccountByUser(userId string) (*Account, error) - AddSetupKey(accountId string, keyName string, keyType SetupKeyType, expiresIn *util.Duration) (*SetupKey, error) + AddSetupKey( + accountId string, + keyName string, + keyType SetupKeyType, + expiresIn *util.Duration, + ) (*SetupKey, error) RevokeSetupKey(accountId string, keyId string) (*SetupKey, error) RenameSetupKey(accountId string, keyId string, newName string) (*SetupKey, error) GetAccountById(accountId string) (*Account, error) @@ -36,6 +42,13 @@ type AccountManager interface { GetPeerByIP(accountId string, peerIP string) (*Peer, error) GetNetworkMap(peerKey string) (*NetworkMap, error) AddPeer(setupKey string, peer *Peer) (*Peer, error) + GetGroup(accountId, groupID string) (*Group, error) + SaveGroup(accountId string, group *Group) error + DeleteGroup(accountId, groupID string) error + ListGroups(accountId string) ([]*Group, error) + GroupAddPeer(accountId, groupID, peerKey string) error + GroupDeletePeer(accountId, groupID, peerKey string) error + GroupListPeers(accountId, groupID string) ([]*Peer, error) } type DefaultAccountManager struct { @@ -58,6 +71,7 @@ type Account struct { Network *Network Peers map[string]*Peer Users map[string]*User + Groups map[string]*Group } // NewAccount creates a new Account with a generated ID and generated default setup keys @@ -93,7 +107,11 @@ func (a *Account) Copy() *Account { } // NewManager creates a new DefaultAccountManager with a provided Store -func NewManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager) *DefaultAccountManager { +func NewManager( + store Store, + peersUpdateManager *PeersUpdateManager, + idpManager idp.Manager, +) *DefaultAccountManager { return &DefaultAccountManager{ Store: store, mux: sync.Mutex{}, @@ -102,8 +120,13 @@ func NewManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager } } -//AddSetupKey generates a new setup key with a given name and type, and adds it to the specified account -func (am *DefaultAccountManager) AddSetupKey(accountId string, keyName string, keyType SetupKeyType, expiresIn *util.Duration) (*SetupKey, error) { +// AddSetupKey generates a new setup key with a given name and type, and adds it to the specified account +func (am *DefaultAccountManager) AddSetupKey( + accountId string, + keyName string, + keyType SetupKeyType, + expiresIn *util.Duration, +) (*SetupKey, error) { am.mux.Lock() defer am.mux.Unlock() @@ -128,7 +151,7 @@ func (am *DefaultAccountManager) AddSetupKey(accountId string, keyName string, k return setupKey, nil } -//RevokeSetupKey marks SetupKey as revoked - becomes not valid anymore +// RevokeSetupKey marks SetupKey as revoked - becomes not valid anymore func (am *DefaultAccountManager) RevokeSetupKey(accountId string, keyId string) (*SetupKey, error) { am.mux.Lock() defer am.mux.Unlock() @@ -154,8 +177,12 @@ func (am *DefaultAccountManager) RevokeSetupKey(accountId string, keyId string) return keyCopy, nil } -//RenameSetupKey renames existing setup key of the specified account. -func (am *DefaultAccountManager) RenameSetupKey(accountId string, keyId string, newName string) (*SetupKey, error) { +// RenameSetupKey renames existing setup key of the specified account. +func (am *DefaultAccountManager) RenameSetupKey( + accountId string, + keyId string, + newName string, +) (*SetupKey, error) { am.mux.Lock() defer am.mux.Unlock() @@ -180,7 +207,7 @@ func (am *DefaultAccountManager) RenameSetupKey(accountId string, keyId string, return keyCopy, nil } -//GetAccountById returns an existing account using its ID or error (NotFound) if doesn't exist +// GetAccountById returns an existing account using its ID or error (NotFound) if doesn't exist func (am *DefaultAccountManager) GetAccountById(accountId string) (*Account, error) { am.mux.Lock() defer am.mux.Unlock() @@ -193,10 +220,11 @@ func (am *DefaultAccountManager) GetAccountById(accountId string) (*Account, err return account, nil } -//GetAccountByUserOrAccountId look for an account by user or account Id, if no account is provided and +// GetAccountByUserOrAccountId look for an account by user or account Id, if no account is provided and // user id doesn't have an account associated with it, one account is created -func (am *DefaultAccountManager) GetAccountByUserOrAccountId(userId, accountId, domain string) (*Account, error) { - +func (am *DefaultAccountManager) GetAccountByUserOrAccountId( + userId, accountId, domain string, +) (*Account, error) { if accountId != "" { return am.GetAccountById(accountId) } else if userId != "" { @@ -219,14 +247,22 @@ func (am *DefaultAccountManager) updateIDPMetadata(userId, accountID string) err if am.idpManager != nil { err := am.idpManager.UpdateUserAppMetadata(userId, idp.AppMetadata{WTAccountId: accountID}) if err != nil { - return status.Errorf(codes.Internal, "updating user's app metadata failed with: %v", err) + return status.Errorf( + codes.Internal, + "updating user's app metadata failed with: %v", + err, + ) } } return nil } // updateAccountDomainAttributes updates the account domain attributes and then, saves the account -func (am *DefaultAccountManager) updateAccountDomainAttributes(account *Account, claims jwtclaims.AuthorizationClaims, primaryDomain bool) error { +func (am *DefaultAccountManager) updateAccountDomainAttributes( + account *Account, + claims jwtclaims.AuthorizationClaims, + primaryDomain bool, +) error { account.IsDomainPrimaryAccount = primaryDomain account.Domain = strings.ToLower(claims.Domain) account.DomainCategory = claims.DomainCategory @@ -245,7 +281,11 @@ func (am *DefaultAccountManager) updateAccountDomainAttributes(account *Account, // non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain // was previously unclassified or classified as public so N users that logged int that time, has they own account // and peers that shouldn't be lost. -func (am *DefaultAccountManager) handleExistingUserAccount(existingAcc *Account, domainAcc *Account, claims jwtclaims.AuthorizationClaims) error { +func (am *DefaultAccountManager) handleExistingUserAccount( + existingAcc *Account, + domainAcc *Account, + claims jwtclaims.AuthorizationClaims, +) error { var err error if domainAcc != nil && existingAcc.Id != domainAcc.Id { @@ -271,7 +311,10 @@ func (am *DefaultAccountManager) handleExistingUserAccount(existingAcc *Account, // handleNewUserAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account, // otherwise it will create a new account and make it primary account for the domain. -func (am *DefaultAccountManager) handleNewUserAccount(domainAcc *Account, claims jwtclaims.AuthorizationClaims) (*Account, error) { +func (am *DefaultAccountManager) handleNewUserAccount( + domainAcc *Account, + claims jwtclaims.AuthorizationClaims, +) (*Account, error) { var ( account *Account err error @@ -315,7 +358,9 @@ func (am *DefaultAccountManager) handleNewUserAccount(domainAcc *Account, claims // Existing user + Existing account + Existing Indexed Domain -> Nothing changes // // Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain) -func (am *DefaultAccountManager) GetAccountWithAuthorizationClaims(claims jwtclaims.AuthorizationClaims) (*Account, error) { +func (am *DefaultAccountManager) GetAccountWithAuthorizationClaims( + claims jwtclaims.AuthorizationClaims, +) (*Account, error) { // if Account ID is part of the claims // it means that we've already classified the domain and user has an account if claims.DomainCategory != PrivateCategory { @@ -355,7 +400,7 @@ func (am *DefaultAccountManager) GetAccountWithAuthorizationClaims(claims jwtcla } } -//AccountExists checks whether account exists (returns true) or not (returns false) +// AccountExists checks whether account exists (returns true) or not (returns false) func (am *DefaultAccountManager) AccountExists(accountId string) (*bool, error) { am.mux.Lock() defer am.mux.Unlock() @@ -377,12 +422,10 @@ func (am *DefaultAccountManager) AccountExists(accountId string) (*bool, error) // AddAccount generates a new Account with a provided accountId and userId, saves to the Store func (am *DefaultAccountManager) AddAccount(accountId, userId, domain string) (*Account, error) { - am.mux.Lock() defer am.mux.Unlock() return am.createAccount(accountId, userId, domain) - } func (am *DefaultAccountManager) createAccount(accountId, userId, domain string) (*Account, error) { @@ -398,7 +441,6 @@ func (am *DefaultAccountManager) createAccount(accountId, userId, domain string) // newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id func newAccountWithId(accountId, userId, domain string) *Account { - log.Debugf("creating new account") setupKeys := make(map[string]*SetupKey) diff --git a/management/server/file_store.go b/management/server/file_store.go index 37277854e..4f89a32ca 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -28,8 +28,7 @@ type FileStore struct { storeFile string `json:"-"` } -type StoredAccount struct { -} +type StoredAccount struct{} // NewStore restores a store from the file located in the datadir func NewStore(dataDir string) (*FileStore, error) { @@ -39,7 +38,6 @@ func NewStore(dataDir string) (*FileStore, error) { // restore restores the state of the store from the file. // Creates a new empty store file if doesn't exist func restore(file string) (*FileStore, error) { - if _, err := os.Stat(file); os.IsNotExist(err) { // create a new FileStore if previously didn't exist (e.g. first run) s := &FileStore{ @@ -109,12 +107,7 @@ func (s *FileStore) SavePeer(accountId string, peer *Peer) error { } account.Peers[peer.Key] = peer - err = s.persist(s.storeFile) - if err != nil { - return err - } - - return nil + return s.persist(s.storeFile) } // DeletePeer deletes peer from the Store @@ -140,7 +133,7 @@ func (s *FileStore) DeletePeer(accountId string, peerKey string) (*Peer, error) return nil, err } - return peer, err + return peer, nil } // GetPeer returns a peer from a Store @@ -191,16 +184,10 @@ func (s *FileStore) SaveAccount(account *Account) error { s.PrivateDomain2AccountId[account.Domain] = account.Id } - err := s.persist(s.storeFile) - if err != nil { - return err - } - - return nil + return s.persist(s.storeFile) } func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) { - accountId, accountIdFound := s.PrivateDomain2AccountId[strings.ToLower(domain)] if !accountIdFound { return nil, status.Errorf(codes.NotFound, "provided domain is not registered or is not private") @@ -215,7 +202,6 @@ func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) { } func (s *FileStore) GetAccountBySetupKey(setupKey string) (*Account, error) { - accountId, accountIdFound := s.SetupKeyId2AccountId[strings.ToUpper(setupKey)] if !accountIdFound { return nil, status.Errorf(codes.NotFound, "provided setup key doesn't exists") @@ -228,6 +214,7 @@ func (s *FileStore) GetAccountBySetupKey(setupKey string) (*Account, error) { return account, nil } + func (s *FileStore) GetAccountPeers(accountId string) ([]*Peer, error) { s.mux.Lock() defer s.mux.Unlock() @@ -246,7 +233,6 @@ func (s *FileStore) GetAccountPeers(accountId string) ([]*Peer, error) { } func (s *FileStore) GetAccount(accountId string) (*Account, error) { - account, accountFound := s.Accounts[accountId] if !accountFound { return nil, status.Errorf(codes.NotFound, "account not found") @@ -278,3 +264,19 @@ func (s *FileStore) GetPeerAccount(peerKey string) (*Account, error) { return s.GetAccount(accountId) } + +func (s *FileStore) GetGroup(groupID string) (*Group, error) { + return nil, nil +} + +func (s *FileStore) SaveGroup(group *Group) error { + return nil +} + +func (s *FileStore) DeleteGroup(groupID string) error { + return nil +} + +func (s *FileStore) ListGroups() ([]*Group, error) { + return nil, nil +} diff --git a/management/server/group.go b/management/server/group.go new file mode 100644 index 000000000..807acbd4a --- /dev/null +++ b/management/server/group.go @@ -0,0 +1,163 @@ +package server + +import ( + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// Group of the peers for ACL +type Group struct { + // ID of the group + ID string + + // Name visible in the UI + Name string + + // Peers list of the group + Peers []string +} + +// GetGroup object of the peers +func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, error) { + am.mux.Lock() + defer am.mux.Unlock() + + account, err := am.Store.GetAccount(accountID) + if err != nil { + return nil, status.Errorf(codes.NotFound, "account not found") + } + + group, ok := account.Groups[groupID] + if ok { + return group, nil + } + + return nil, status.Errorf(codes.NotFound, "group with ID %s not found", groupID) +} + +// SaveGroup object of the peers +func (am *DefaultAccountManager) SaveGroup(accountID string, group *Group) error { + am.mux.Lock() + defer am.mux.Unlock() + + account, err := am.Store.GetAccount(accountID) + if err != nil { + return status.Errorf(codes.NotFound, "account not found") + } + + account.Groups[group.ID] = group + return am.Store.SaveAccount(account) +} + +// DeleteGroup object of the peers +func (am *DefaultAccountManager) DeleteGroup(accountID, groupID string) error { + am.mux.Lock() + defer am.mux.Unlock() + + account, err := am.Store.GetAccount(accountID) + if err != nil { + return status.Errorf(codes.NotFound, "account not found") + } + + delete(account.Groups, groupID) + + return am.Store.SaveAccount(account) +} + +// ListGroups objects of the peers +func (am *DefaultAccountManager) ListGroups(accountID string) ([]*Group, error) { + am.mux.Lock() + defer am.mux.Unlock() + + account, err := am.Store.GetAccount(accountID) + if err != nil { + return nil, status.Errorf(codes.NotFound, "account not found") + } + + groups := make([]*Group, 0, len(account.Groups)) + for _, item := range account.Groups { + groups = append(groups, item) + } + + return groups, nil +} + +// GroupAddPeer appends peer to the group +func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerKey string) error { + am.mux.Lock() + defer am.mux.Unlock() + + account, err := am.Store.GetAccount(accountID) + if err != nil { + return status.Errorf(codes.NotFound, "account not found") + } + + group, ok := account.Groups[groupID] + if !ok { + return status.Errorf(codes.NotFound, "group with ID %s not found", groupID) + } + + add := true + for _, itemID := range group.Peers { + if itemID == peerKey { + add = false + break + } + } + if add { + group.Peers = append(group.Peers, peerKey) + } + + return am.Store.SaveAccount(account) +} + +// GroupDeletePeer removes peer from the group +func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerKey string) error { + am.mux.Lock() + defer am.mux.Unlock() + + account, err := am.Store.GetAccount(accountID) + if err != nil { + return status.Errorf(codes.NotFound, "account not found") + } + + group, ok := account.Groups[groupID] + if !ok { + return status.Errorf(codes.NotFound, "group with ID %s not found", groupID) + } + + for i, itemID := range group.Peers { + if itemID == peerKey { + group.Peers = append(group.Peers[:i], group.Peers[i+1:]...) + return am.Store.SaveAccount(account) + } + } + + return nil +} + +// GroupListPeers returns list of the peers from the group +func (am *DefaultAccountManager) GroupListPeers(accountID, groupID string) ([]*Peer, error) { + am.mux.Lock() + defer am.mux.Unlock() + + account, err := am.Store.GetAccount(accountID) + if err != nil { + return nil, status.Errorf(codes.NotFound, "account not found") + } + + group, ok := account.Groups[groupID] + if !ok { + return nil, status.Errorf(codes.NotFound, "group with ID %s not found", groupID) + } + + peers := make([]*Peer, 0, len(account.Groups)) + for _, peerID := range group.Peers { + p, ok := account.Peers[peerID] + if ok { + peers = append(peers, p) + } + } + + return peers, nil +} diff --git a/management/server/http/handler/groups.go b/management/server/http/handler/groups.go new file mode 100644 index 000000000..4bab4e7a8 --- /dev/null +++ b/management/server/http/handler/groups.go @@ -0,0 +1,135 @@ +package handler + +import ( + "encoding/json" + "fmt" + "net/http" + + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/rs/xid" + + "github.com/gorilla/mux" + log "github.com/sirupsen/logrus" +) + +// Groups is a handler that returns groups of the account +type Groups struct { + accountManager server.AccountManager + authAudience string + jwtExtractor jwtclaims.ClaimsExtractor +} + +// GroupResponse is a response sent to the client +type GroupResponse struct { + ID string + Name string + Peers []string +} + +func NewGroups(accountManager server.AccountManager, authAudience string) *Groups { + return &Groups{ + accountManager: accountManager, + authAudience: authAudience, + jwtExtractor: *jwtclaims.NewClaimsExtractor(nil), + } +} + +// GetAllGroupsHandler list for the account +func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) { + account, err := h.getGroupAccount(r) + if err != nil { + log.Error(err) + http.Redirect(w, r, "/", http.StatusInternalServerError) + return + } + + writeJSONObject(w, account.Groups) +} + +func (h *Groups) CreateOrUpdateGroupHandler(w http.ResponseWriter, r *http.Request) { + account, err := h.getGroupAccount(r) + if err != nil { + http.Redirect(w, r, "/", http.StatusInternalServerError) + return + } + + var req server.Group + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + if r.Method == http.MethodPost { + req.ID = xid.New().String() + } + + if err := h.accountManager.SaveGroup(account.Id, &req); err != nil { + log.Errorf("failed updating group %s under account %s %v", req.ID, account.Id, err) + http.Redirect(w, r, "/", http.StatusInternalServerError) + return + } + + writeJSONObject(w, &req) +} + +func (h *Groups) DeleteGroupHandler(w http.ResponseWriter, r *http.Request) { + account, err := h.getGroupAccount(r) + if err != nil { + http.Redirect(w, r, "/", http.StatusInternalServerError) + return + } + aID := account.Id + + gID := mux.Vars(r)["id"] + if len(gID) == 0 { + http.Error(w, "invalid group ID", http.StatusBadRequest) + return + } + + if err := h.accountManager.DeleteGroup(aID, gID); err != nil { + log.Errorf("failed delete group %s under account %s %v", gID, aID, err) + http.Redirect(w, r, "/", http.StatusInternalServerError) + return + } + + writeJSONObject(w, "") +} + +func (h *Groups) GetGroupHandler(w http.ResponseWriter, r *http.Request) { + account, err := h.getGroupAccount(r) + if err != nil { + http.Redirect(w, r, "/", http.StatusInternalServerError) + return + } + + switch r.Method { + case http.MethodGet: + groupID := mux.Vars(r)["id"] + if len(groupID) == 0 { + http.Error(w, "invalid group ID", http.StatusBadRequest) + return + } + + group, err := h.accountManager.GetGroup(account.Id, groupID) + if err != nil { + http.Error(w, "group not found", http.StatusNotFound) + return + } + + writeJSONObject(w, group) + default: + http.Error(w, "", http.StatusNotFound) + } +} + +func (h *Groups) getGroupAccount(r *http.Request) (*server.Account, error) { + jwtClaims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + + account, err := h.accountManager.GetAccountWithAuthorizationClaims(jwtClaims) + if err != nil { + return nil, fmt.Errorf("failed getting account of a user %s: %v", jwtClaims.UserId, err) + } + + return account, nil +} diff --git a/management/server/http/handler/groups_test.go b/management/server/http/handler/groups_test.go new file mode 100644 index 000000000..ed8d46568 --- /dev/null +++ b/management/server/http/handler/groups_test.go @@ -0,0 +1,202 @@ +package handler + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gorilla/mux" + "github.com/netbirdio/netbird/management/server/jwtclaims" + + "github.com/magiconair/properties/assert" + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/mock_server" +) + +func initGroupTestData(groups ...*server.Group) *Groups { + return &Groups{ + accountManager: &mock_server.MockAccountManager{ + SaveGroupFunc: func(accountID string, group *server.Group) error { + if !strings.HasPrefix(group.ID, "id-") { + group.ID = "id-was-set" + } + return nil + }, + GetGroupFunc: func(_, groupID string) (*server.Group, error) { + if groupID != "idofthegroup" { + return nil, fmt.Errorf("not found") + } + return &server.Group{ + ID: "idofthegroup", + Name: "Group", + }, nil + }, + GetAccountWithAuthorizationClaimsFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { + return &server.Account{ + Id: claims.AccountId, + Domain: "hotmail.com", + }, nil + }, + }, + authAudience: "", + jwtExtractor: jwtclaims.ClaimsExtractor{ + ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims { + return jwtclaims.AuthorizationClaims{ + UserId: "test_user", + Domain: "hotmail.com", + AccountId: "test_id", + } + }, + }, + } +} + +func TestGetGroup(t *testing.T) { + tt := []struct { + name string + expectedStatus int + expectedBody bool + requestType string + requestPath string + requestBody io.Reader + }{ + { + name: "GetGroup OK", + expectedBody: true, + requestType: http.MethodGet, + requestPath: "/api/groups/idofthegroup", + expectedStatus: http.StatusOK, + }, + { + name: "GetGroup not found", + requestType: http.MethodGet, + requestPath: "/api/groups/notexists", + expectedStatus: http.StatusNotFound, + }, + } + + group := &server.Group{ + ID: "idofthegroup", + Name: "Group", + } + + p := initGroupTestData(group) + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + recorder := httptest.NewRecorder() + req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) + + router := mux.NewRouter() + router.HandleFunc("/api/groups/{id}", p.GetGroupHandler).Methods("GET") + router.ServeHTTP(recorder, req) + + res := recorder.Result() + defer res.Body.Close() + + if status := recorder.Code; status != tc.expectedStatus { + t.Errorf("handler returned wrong status code: got %v want %v", + status, tc.expectedStatus) + return + } + + if !tc.expectedBody { + return + } + + content, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("I don't know what I expected; %v", err) + } + + got := &server.Group{} + if err = json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.Equal(t, got.ID, group.ID) + assert.Equal(t, got.Name, group.Name) + }) + } +} + +func TestSaveGroup(t *testing.T) { + tt := []struct { + name string + expectedStatus int + expectedBody bool + expectedGroup *server.Group + requestType string + requestPath string + requestBody io.Reader + }{ + { + name: "SaveGroup POST OK", + requestType: http.MethodPost, + requestPath: "/api/groups", + requestBody: bytes.NewBuffer( + []byte(`{"Name":"Default POSTed Group"}`)), + expectedStatus: http.StatusOK, + expectedBody: true, + expectedGroup: &server.Group{ + ID: "id-was-set", + Name: "Default POSTed Group", + }, + }, + { + name: "SaveGroup PUT OK", + requestType: http.MethodPut, + requestPath: "/api/groups", + requestBody: bytes.NewBuffer( + []byte(`{"ID":"id-existed","Name":"Default POSTed Group"}`)), + expectedStatus: http.StatusOK, + expectedGroup: &server.Group{ + ID: "id-existed", + Name: "Default POSTed Group", + }, + }, + } + + p := initGroupTestData() + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + recorder := httptest.NewRecorder() + req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) + + router := mux.NewRouter() + router.HandleFunc("/api/groups", p.CreateOrUpdateGroupHandler).Methods("PUT", "POST") + router.ServeHTTP(recorder, req) + + res := recorder.Result() + defer res.Body.Close() + + content, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("I don't know what I expected; %v", err) + } + + if status := recorder.Code; status != tc.expectedStatus { + t.Errorf("handler returned wrong status code: got %v want %v, content: %s", + status, tc.expectedStatus, string(content)) + return + } + + if !tc.expectedBody { + return + } + + got := &server.Group{} + if err = json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.Equal(t, got, tc.expectedGroup) + }) + } +} diff --git a/management/server/http/handler/peers_test.go b/management/server/http/handler/peers_test.go index 9967c6551..c3ade8ff8 100644 --- a/management/server/http/handler/peers_test.go +++ b/management/server/http/handler/peers_test.go @@ -2,13 +2,14 @@ package handler import ( "encoding/json" - "github.com/netbirdio/netbird/management/server/jwtclaims" "io" "net" "net/http" "net/http/httptest" "testing" + "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/magiconair/properties/assert" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/mock_server" @@ -43,14 +44,19 @@ func initTestMetaData(peer ...*server.Peer) *Peers { // Tests the GetPeers endpoint reachable in the route /api/peers // Use the metadata generated by initTestMetaData() to check for values func TestGetPeers(t *testing.T) { - var tt = []struct { + tt := []struct { name string expectedStatus int requestType string requestPath string requestBody io.Reader }{ - {name: "GetPeersMetaData", requestType: http.MethodGet, requestPath: "/api/peers/", expectedStatus: http.StatusOK}, + { + name: "GetPeersMetaData", + requestType: http.MethodGet, + requestPath: "/api/peers/", + expectedStatus: http.StatusOK, + }, } rr := httptest.NewRecorder() diff --git a/management/server/http/server.go b/management/server/http/server.go index 1be916eff..83030847f 100644 --- a/management/server/http/server.go +++ b/management/server/http/server.go @@ -25,26 +25,44 @@ type Server struct { // NewHttpsServer creates a new HTTPs server (with HTTPS support) and a certManager that is responsible for generating and renewing Let's Encrypt certificate // The listening address will be :443 no matter what was specified in s.HttpServerConfig.Address -func NewHttpsServer(config *s.HttpServerConfig, certManager *autocert.Manager, accountManager s.AccountManager) *Server { +func NewHttpsServer( + config *s.HttpServerConfig, + certManager *autocert.Manager, + accountManager s.AccountManager, +) *Server { server := &http.Server{ Addr: config.Address, WriteTimeout: time.Second * 15, ReadTimeout: time.Second * 15, IdleTimeout: time.Second * 60, } - return &Server{server: server, config: config, certManager: certManager, accountManager: accountManager} + return &Server{ + server: server, + config: config, + certManager: certManager, + accountManager: accountManager, + } } // NewHttpsServerWithTLSConfig creates a new HTTPs server with a provided tls.Config. // Usually used when you already have a certificate -func NewHttpsServerWithTLSConfig(config *s.HttpServerConfig, tlsConfig *tls.Config, accountManager s.AccountManager) *Server { +func NewHttpsServerWithTLSConfig( + config *s.HttpServerConfig, + tlsConfig *tls.Config, + accountManager s.AccountManager, +) *Server { server := &http.Server{ Addr: config.Address, WriteTimeout: time.Second * 15, ReadTimeout: time.Second * 15, IdleTimeout: time.Second * 60, } - return &Server{server: server, config: config, tlsConfig: tlsConfig, accountManager: accountManager} + return &Server{ + server: server, + config: config, + tlsConfig: tlsConfig, + accountManager: accountManager, + } } // NewHttpServer creates a new HTTP server (without HTTPS) @@ -63,8 +81,11 @@ func (s *Server) Stop(ctx context.Context) error { // Start defines http handlers and starts the http server. Blocks until server is shutdown. func (s *Server) Start() error { - - jwtMiddleware, err := middleware.NewJwtMiddleware(s.config.AuthIssuer, s.config.AuthAudience, s.config.AuthKeysLocation) + jwtMiddleware, err := middleware.NewJwtMiddleware( + s.config.AuthIssuer, + s.config.AuthAudience, + s.config.AuthKeysLocation, + ) if err != nil { return err } @@ -74,19 +95,31 @@ func (s *Server) Start() error { r := mux.NewRouter() r.Use(jwtMiddleware.Handler, corsMiddleware.Handler) + groupsHandler := handler.NewGroups(s.accountManager, s.config.AuthAudience) peersHandler := handler.NewPeers(s.accountManager, s.config.AuthAudience) keysHandler := handler.NewSetupKeysHandler(s.accountManager, s.config.AuthAudience) r.HandleFunc("/api/peers", peersHandler.GetPeers).Methods("GET", "OPTIONS") - r.HandleFunc("/api/peers/{id}", peersHandler.HandlePeer).Methods("GET", "PUT", "DELETE", "OPTIONS") + r.HandleFunc("/api/peers/{id}", peersHandler.HandlePeer). + Methods("GET", "PUT", "DELETE", "OPTIONS") - r.HandleFunc("/api/setup-keys", keysHandler.GetKeys).Methods("GET", "POST", "OPTIONS") - r.HandleFunc("/api/setup-keys/{id}", keysHandler.HandleKey).Methods("GET", "PUT", "OPTIONS") + r.HandleFunc("/api/setup-keys", keysHandler.GetKeys).Methods("POST", "OPTIONS") + r.HandleFunc("/api/setup-keys/{id}", keysHandler.HandleKey). + Methods("GET", "PUT", "DELETE", "OPTIONS") + + r.HandleFunc("/api/groups", groupsHandler.GetAllGroupsHandler).Methods("GET", "OPTIONS") + r.HandleFunc("/api/groups", groupsHandler.CreateOrUpdateGroupHandler). + Methods("POST", "PUT", "OPTIONS") + r.HandleFunc("/api/groups/{id}", groupsHandler.GetGroupHandler).Methods("GET", "OPTIONS") + r.HandleFunc("/api/groups/{id}", groupsHandler.DeleteGroupHandler).Methods("DELETE", "OPTIONS") http.Handle("/", r) if s.certManager != nil { // if HTTPS is enabled we reuse the listener from the cert manager listener := s.certManager.Listener() - log.Infof("HTTPs server listening on %s with Let's Encrypt autocert configured", listener.Addr()) + log.Infof( + "HTTPs server listening on %s with Let's Encrypt autocert configured", + listener.Addr(), + ) if err = http.Serve(listener, s.certManager.HTTPHandler(r)); err != nil { log.Errorf("failed to serve https server: %v", err) return err diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 8a727896f..b7aabe3c2 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -26,13 +26,25 @@ type MockAccountManager struct { GetPeerByIPFunc func(accountId string, peerIP string) (*server.Peer, error) GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error) AddPeerFunc func(setupKey string, peer *server.Peer) (*server.Peer, error) + GetGroupFunc func(accountID, groupID string) (*server.Group, error) + SaveGroupFunc func(accountID string, group *server.Group) error + DeleteGroupFunc func(accountID, groupID string) error + ListGroupsFunc func(accountID string) ([]*server.Group, error) + GroupAddPeerFunc func(accountID, groupID, peerKey string) error + GroupDeletePeerFunc func(accountID, groupID, peerKey string) error + GroupListPeersFunc func(accountID, groupID string) ([]*server.Peer, error) } -func (am *MockAccountManager) GetOrCreateAccountByUser(userId, domain string) (*server.Account, error) { +func (am *MockAccountManager) GetOrCreateAccountByUser( + userId, domain string, +) (*server.Account, error) { if am.GetOrCreateAccountByUserFunc != nil { return am.GetOrCreateAccountByUserFunc(userId, domain) } - return nil, status.Errorf(codes.Unimplemented, "method GetOrCreateAccountByUser not implemented") + return nil, status.Errorf( + codes.Unimplemented, + "method GetOrCreateAccountByUser not implemented", + ) } func (am *MockAccountManager) GetAccountByUser(userId string) (*server.Account, error) { @@ -42,21 +54,33 @@ func (am *MockAccountManager) GetAccountByUser(userId string) (*server.Account, return nil, status.Errorf(codes.Unimplemented, "method GetAccountByUser not implemented") } -func (am *MockAccountManager) AddSetupKey(accountId string, keyName string, keyType server.SetupKeyType, expiresIn *util.Duration) (*server.SetupKey, error) { +func (am *MockAccountManager) AddSetupKey( + accountId string, + keyName string, + keyType server.SetupKeyType, + expiresIn *util.Duration, +) (*server.SetupKey, error) { if am.AddSetupKeyFunc != nil { return am.AddSetupKeyFunc(accountId, keyName, keyType, expiresIn) } return nil, status.Errorf(codes.Unimplemented, "method AddSetupKey not implemented") } -func (am *MockAccountManager) RevokeSetupKey(accountId string, keyId string) (*server.SetupKey, error) { +func (am *MockAccountManager) RevokeSetupKey( + accountId string, + keyId string, +) (*server.SetupKey, error) { if am.RevokeSetupKeyFunc != nil { return am.RevokeSetupKeyFunc(accountId, keyId) } return nil, status.Errorf(codes.Unimplemented, "method RevokeSetupKey not implemented") } -func (am *MockAccountManager) RenameSetupKey(accountId string, keyId string, newName string) (*server.SetupKey, error) { +func (am *MockAccountManager) RenameSetupKey( + accountId string, + keyId string, + newName string, +) (*server.SetupKey, error) { if am.RenameSetupKeyFunc != nil { return am.RenameSetupKeyFunc(accountId, keyId, newName) } @@ -70,18 +94,28 @@ func (am *MockAccountManager) GetAccountById(accountId string) (*server.Account, return nil, status.Errorf(codes.Unimplemented, "method GetAccountById not implemented") } -func (am *MockAccountManager) GetAccountByUserOrAccountId(userId, accountId, domain string) (*server.Account, error) { +func (am *MockAccountManager) GetAccountByUserOrAccountId( + userId, accountId, domain string, +) (*server.Account, error) { if am.GetAccountByUserOrAccountIdFunc != nil { return am.GetAccountByUserOrAccountIdFunc(userId, accountId, domain) } - return nil, status.Errorf(codes.Unimplemented, "method GetAccountByUserOrAccountId not implemented") + return nil, status.Errorf( + codes.Unimplemented, + "method GetAccountByUserOrAccountId not implemented", + ) } -func (am *MockAccountManager) GetAccountWithAuthorizationClaims(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { +func (am *MockAccountManager) GetAccountWithAuthorizationClaims( + claims jwtclaims.AuthorizationClaims, +) (*server.Account, error) { if am.GetAccountWithAuthorizationClaimsFunc != nil { return am.GetAccountWithAuthorizationClaimsFunc(claims) } - return nil, status.Errorf(codes.Unimplemented, "method GetAccountWithAuthorizationClaims not implemented") + return nil, status.Errorf( + codes.Unimplemented, + "method GetAccountWithAuthorizationClaims not implemented", + ) } func (am *MockAccountManager) AccountExists(accountId string) (*bool, error) { @@ -91,7 +125,9 @@ func (am *MockAccountManager) AccountExists(accountId string) (*bool, error) { return nil, status.Errorf(codes.Unimplemented, "method AccountExists not implemented") } -func (am *MockAccountManager) AddAccount(accountId, userId, domain string) (*server.Account, error) { +func (am *MockAccountManager) AddAccount( + accountId, userId, domain string, +) (*server.Account, error) { if am.AddAccountFunc != nil { return am.AddAccountFunc(accountId, userId, domain) } @@ -112,7 +148,11 @@ func (am *MockAccountManager) MarkPeerConnected(peerKey string, connected bool) return status.Errorf(codes.Unimplemented, "method MarkPeerConnected not implemented") } -func (am *MockAccountManager) RenamePeer(accountId string, peerKey string, newName string) (*server.Peer, error) { +func (am *MockAccountManager) RenamePeer( + accountId string, + peerKey string, + newName string, +) (*server.Peer, error) { if am.RenamePeerFunc != nil { return am.RenamePeerFunc(accountId, peerKey, newName) } @@ -146,3 +186,52 @@ func (am *MockAccountManager) AddPeer(setupKey string, peer *server.Peer) (*serv } return nil, status.Errorf(codes.Unimplemented, "method AddPeer not implemented") } + +func (am *MockAccountManager) GetGroup(accountID, groupID string) (*server.Group, error) { + if am.GetGroupFunc != nil { + return am.GetGroupFunc(accountID, groupID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetGroup not implemented") +} + +func (am *MockAccountManager) SaveGroup(accountID string, group *server.Group) error { + if am.SaveGroupFunc != nil { + return am.SaveGroupFunc(accountID, group) + } + return status.Errorf(codes.Unimplemented, "method UpdateGroup not implemented") +} + +func (am *MockAccountManager) DeleteGroup(accountID, groupID string) error { + if am.DeleteGroupFunc != nil { + return am.DeleteGroupFunc(accountID, groupID) + } + return status.Errorf(codes.Unimplemented, "method DeleteGroup not implemented") +} + +func (am *MockAccountManager) ListGroups(accountID string) ([]*server.Group, error) { + if am.ListGroupsFunc != nil { + return am.ListGroupsFunc(accountID) + } + return nil, status.Errorf(codes.Unimplemented, "method ListGroups not implemented") +} + +func (am *MockAccountManager) GroupAddPeer(accountID, groupID, peerKey string) error { + if am.GroupAddPeerFunc != nil { + return am.GroupAddPeerFunc(accountID, groupID, peerKey) + } + return status.Errorf(codes.Unimplemented, "method GroupAddPeer not implemented") +} + +func (am *MockAccountManager) GroupDeletePeer(accountID, groupID, peerKey string) error { + if am.GroupDeletePeerFunc != nil { + return am.GroupDeletePeerFunc(accountID, groupID, peerKey) + } + return status.Errorf(codes.Unimplemented, "method GroupDeletePeer not implemented") +} + +func (am *MockAccountManager) GroupListPeers(accountID, groupID string) ([]*server.Peer, error) { + if am.GroupListPeersFunc != nil { + return am.GroupListPeersFunc(accountID, groupID) + } + return nil, status.Errorf(codes.Unimplemented, "method GroupListPeers not implemented") +}