diff --git a/management/server/account.go b/management/server/account.go index 25ee756ad..ecc7f2260 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -34,6 +34,8 @@ const ( PublicCategory = "public" PrivateCategory = "private" UnknownCategory = "unknown" + GroupIssuedAPI = "api" + GroupIssuedJWT = "jwt" CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days DefaultPeerLoginExpiration = 24 * time.Hour @@ -139,6 +141,13 @@ type Settings struct { // PeerLoginExpiration is a setting that indicates when peer login expires. // Applies to all peers that have Peer.LoginExpirationEnabled set to true. PeerLoginExpiration time.Duration + + // JWTGroupsEnabled allows extract groups from JWT claim, which name defined in the JWTGroupsClaimName + // and add it to account groups. + JWTGroupsEnabled bool + + // JWTGroupsClaimName from which we extract groups name to add it to account groups + JWTGroupsClaimName string } // Copy copies the Settings struct @@ -146,6 +155,8 @@ func (s *Settings) Copy() *Settings { return &Settings{ PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled, PeerLoginExpiration: s.PeerLoginExpiration, + JWTGroupsEnabled: s.JWTGroupsEnabled, + JWTGroupsClaimName: s.JWTGroupsClaimName, } } @@ -612,6 +623,28 @@ func (a *Account) GetPeer(peerID string) *Peer { return a.Peers[peerID] } +// AddJWTGroups to existed groups if they does not exists +func (a *Account) AddJWTGroups(groups []string) (int, error) { + existedGroups := make(map[string]*Group) + for _, g := range a.Groups { + existedGroups[g.Name] = g + } + + var count int + for _, name := range groups { + if _, ok := existedGroups[name]; !ok { + id := xid.New().String() + a.Groups[id] = &Group{ + ID: id, + Name: name, + Issued: GroupIssuedJWT, + } + count++ + } + } + return count, nil +} + // BuildManager creates a new DefaultAccountManager with a provided Store func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager, singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, @@ -1241,6 +1274,38 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat } } + if account.Settings.JWTGroupsEnabled { + if account.Settings.JWTGroupsClaimName == "" { + log.Errorf("JWT groups are enabled but no claim name is set") + return account, user, nil + } + if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok { + if slice, ok := claim.([]interface{}); ok { + var groups []string + for _, item := range slice { + if g, ok := item.(string); ok { + groups = append(groups, g) + } else { + log.Errorf("JWT claim %q is not a string: %v", account.Settings.JWTGroupsClaimName, item) + } + } + n, err := account.AddJWTGroups(groups) + if err != nil { + log.Errorf("failed to add JWT groups: %v", err) + } + if n > 0 { + if err := am.Store.SaveAccount(account); err != nil { + log.Errorf("failed to save account: %v", err) + } + } + } else { + log.Debugf("JWT claim %q is not a string array", account.Settings.JWTGroupsClaimName) + } + } else { + log.Debugf("JWT claim %q not found", account.Settings.JWTGroupsClaimName) + } + } + return account, user, nil } @@ -1344,8 +1409,9 @@ func (am *DefaultAccountManager) GetDNSDomain() string { func addAllGroup(account *Account) error { if len(account.Groups) == 0 { allGroup := &Group{ - ID: xid.New().String(), - Name: "All", + ID: xid.New().String(), + Name: "All", + Issued: GroupIssuedAPI, } for _, peer := range account.Peers { allGroup.Peers = append(allGroup.Peers, peer.ID) diff --git a/management/server/account_test.go b/management/server/account_test.go index 495e93892..55aa0d1d0 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "github.com/golang-jwt/jwt" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/route" @@ -460,6 +461,69 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) { } } +func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { + userId := "user-id" + domain := "test.domain" + + initAccount := newAccountWithId("", userId, domain) + manager, err := createManager(t) + require.NoError(t, err, "unable to create account manager") + + accountID := initAccount.Id + _, err = manager.GetAccountByUserOrAccountID(userId, accountID, domain) + require.NoError(t, err, "create init user failed") + + claims := jwtclaims.AuthorizationClaims{ + AccountId: accountID, + Domain: domain, + UserId: userId, + DomainCategory: "test-category", + Raw: jwt.MapClaims{"idp-groups": []interface{}{"group1", "group2"}}, + } + + t.Run("JWT groups disabled", func(t *testing.T) { + account, _, err := manager.GetAccountFromToken(claims) + require.NoError(t, err, "get account by token failed") + require.Len(t, account.Groups, 1, "only ALL group should exists") + }) + + t.Run("JWT groups enabled without claim name", func(t *testing.T) { + initAccount.Settings.JWTGroupsEnabled = true + err := manager.Store.SaveAccount(initAccount) + require.NoError(t, err, "save account failed") + + account, _, err := manager.GetAccountFromToken(claims) + require.NoError(t, err, "get account by token failed") + require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT") + }) + + t.Run("JWT groups enabled", func(t *testing.T) { + initAccount.Settings.JWTGroupsEnabled = true + initAccount.Settings.JWTGroupsClaimName = "idp-groups" + err := manager.Store.SaveAccount(initAccount) + require.NoError(t, err, "save account failed") + + account, _, err := manager.GetAccountFromToken(claims) + require.NoError(t, err, "get account by token failed") + require.Len(t, account.Groups, 3, "groups should be added to the account") + + groupsByNames := map[string]*Group{} + for _, g := range account.Groups { + groupsByNames[g.Name] = g + } + + g1, ok := groupsByNames["group1"] + require.True(t, ok, "group1 should be added to the account") + require.Equal(t, g1.Name, "group1", "group1 name should match") + require.Equal(t, g1.Issued, GroupIssuedJWT, "group1 issued should match") + + g2, ok := groupsByNames["group2"] + require.True(t, ok, "group2 should be added to the account") + require.Equal(t, g2.Name, "group2", "group2 name should match") + require.Equal(t, g2.Issued, GroupIssuedJWT, "group2 issued should match") + }) +} + func TestAccountManager_GetAccountFromPAT(t *testing.T) { store := newStore(t) account := newAccountWithId("account_id", "testuser", "") diff --git a/management/server/file_store.go b/management/server/file_store.go index c39af154a..4bbe95a10 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -157,6 +157,14 @@ func restore(file string) (*FileStore, error) { addPeerLabelsToAccount(account, existingLabels) } + // TODO: delete this block after migration + // Set API as issuer for groups which has not this field + for _, group := range account.Groups { + if group.Issued == "" { + group.Issued = GroupIssuedAPI + } + } + allGroup, err := account.GetGroupAll() if err != nil { log.Errorf("unable to find the All group, this should happen only when migrate from a version that didn't support groups. Error: %v", err) diff --git a/management/server/file_store_test.go b/management/server/file_store_test.go index ce035ea83..e2f07acda 100644 --- a/management/server/file_store_test.go +++ b/management/server/file_store_test.go @@ -262,6 +262,7 @@ func TestRestore(t *testing.T) { require.Len(t, store.TokenID2UserID, 1, "failed to restore a FileStore wrong TokenID2UserID mapping length") } +// TODO: outdated, delete this func TestRestorePolicies_Migration(t *testing.T) { storeDir := t.TempDir() @@ -296,6 +297,40 @@ func TestRestorePolicies_Migration(t *testing.T) { "failed to restore a FileStore file - missing Account Policies Sources") } +func TestRestoreGroups_Migration(t *testing.T) { + storeDir := t.TempDir() + + err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json")) + if err != nil { + t.Fatal(err) + } + + store, err := NewFileStore(storeDir, nil) + if err != nil { + return + } + + // create default group + account := store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"] + account.Groups = map[string]*Group{ + "cfefqs706sqkneg59g3g": { + ID: "cfefqs706sqkneg59g3g", + Name: "All", + }, + } + err = store.SaveAccount(account) + require.NoError(t, err, "failed to save account") + + // restore account with default group with empty Issue field + if store, err = NewFileStore(storeDir, nil); err != nil { + return + } + account = store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"] + + require.Contains(t, account.Groups, "cfefqs706sqkneg59g3g", "failed to restore a FileStore file - missing Account Groups") + require.Equal(t, GroupIssuedAPI, account.Groups["cfefqs706sqkneg59g3g"].Issued, "default group should has API issued mark") +} + func TestGetAccountByPrivateDomain(t *testing.T) { storeDir := t.TempDir() diff --git a/management/server/group.go b/management/server/group.go index 688d6dafa..dd1229c86 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -14,6 +14,9 @@ type Group struct { // Name visible in the UI Name string + // Issued of the group + Issued string + // Peers list of the group Peers []string } @@ -45,9 +48,10 @@ func (g *Group) EventMeta() map[string]any { func (g *Group) Copy() *Group { return &Group{ - ID: g.ID, - Name: g.Name, - Peers: g.Peers[:], + ID: g.ID, + Name: g.Name, + Issued: g.Issued, + Peers: g.Peers[:], } } diff --git a/management/server/http/accounts_handler.go b/management/server/http/accounts_handler.go index 13d5909ce..d94bea3fb 100644 --- a/management/server/http/accounts_handler.go +++ b/management/server/http/accounts_handler.go @@ -72,10 +72,19 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) return } - updatedAccount, err := h.accountManager.UpdateAccountSettings(accountID, user.Id, &server.Settings{ + settings := &server.Settings{ PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled, PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)), - }) + } + + if req.Settings.JwtGroupsEnabled != nil { + settings.JWTGroupsEnabled = *req.Settings.JwtGroupsEnabled + } + if req.Settings.JwtGroupsClaimName != nil { + settings.JWTGroupsClaimName = *req.Settings.JwtGroupsClaimName + } + + updatedAccount, err := h.accountManager.UpdateAccountSettings(accountID, user.Id, settings) if err != nil { util.WriteError(err, w) @@ -93,6 +102,8 @@ func toAccountResponse(account *server.Account) *api.Account { Settings: api.AccountSettings{ PeerLoginExpiration: int(account.Settings.PeerLoginExpiration.Seconds()), PeerLoginExpirationEnabled: account.Settings.PeerLoginExpirationEnabled, + JwtGroupsEnabled: &account.Settings.JWTGroupsEnabled, + JwtGroupsClaimName: &account.Settings.JWTGroupsClaimName, }, } } diff --git a/management/server/http/accounts_handler_test.go b/management/server/http/accounts_handler_test.go index 405dd94f4..5051f45e1 100644 --- a/management/server/http/accounts_handler_test.go +++ b/management/server/http/accounts_handler_test.go @@ -58,6 +58,9 @@ func TestAccounts_AccountsHandler(t *testing.T) { accountID := "test_account" adminUser := server.NewAdminUser("test_user") + sr := func(v string) *string { return &v } + br := func(v bool) *bool { return &v } + handler := initAccountsTestData(&server.Account{ Id: accountID, Domain: "hotmail.com", @@ -91,6 +94,8 @@ func TestAccounts_AccountsHandler(t *testing.T) { expectedSettings: api.AccountSettings{ PeerLoginExpiration: int(time.Hour.Seconds()), PeerLoginExpirationEnabled: false, + JwtGroupsClaimName: sr(""), + JwtGroupsEnabled: br(false), }, expectedArray: true, expectedID: accountID, @@ -105,6 +110,24 @@ func TestAccounts_AccountsHandler(t *testing.T) { expectedSettings: api.AccountSettings{ PeerLoginExpiration: 15552000, PeerLoginExpirationEnabled: true, + JwtGroupsClaimName: sr(""), + JwtGroupsEnabled: br(false), + }, + expectedArray: false, + expectedID: accountID, + }, + { + name: "PutAccount OK wiht JWT", + expectedBody: true, + requestType: http.MethodPut, + requestPath: "/api/accounts/" + accountID, + requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\"}}"), + expectedStatus: http.StatusOK, + expectedSettings: api.AccountSettings{ + PeerLoginExpiration: 15552000, + PeerLoginExpirationEnabled: false, + JwtGroupsClaimName: sr("roles"), + JwtGroupsEnabled: br(true), }, expectedArray: false, expectedID: accountID, diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 5a22ac1f9..b795b4608 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -54,6 +54,14 @@ components: description: Period of time after which peer login expires (seconds). type: integer example: 43200 + jwt_groups_enabled: + description: Allows extract groups from JWT claim and add it to account groups. + type: boolean + example: true + jwt_groups_claim_name: + description: Name of the claim from which we extract groups names to add it to account groups. + type: string + example: "roles" required: - peer_login_expiration_enabled - peer_login_expiration @@ -462,6 +470,10 @@ components: description: Count of peers associated to the group type: integer example: 2 + issued: + description: How group was issued by API or from JWT token + type: string + example: api required: - id - name diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index 8ef1bc935..ef4a2b682 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -129,6 +129,12 @@ type AccountRequest struct { // AccountSettings defines model for AccountSettings. type AccountSettings struct { + // JwtGroupsClaimName Name of the claim from which we extract groups names to add it to account groups. + JwtGroupsClaimName *string `json:"jwt_groups_claim_name,omitempty"` + + // JwtGroupsEnabled Allows extract groups from JWT claim and add it to account groups. + JwtGroupsEnabled *bool `json:"jwt_groups_enabled,omitempty"` + // PeerLoginExpiration Period of time after which peer login expires (seconds). PeerLoginExpiration int `json:"peer_login_expiration"` @@ -174,6 +180,9 @@ type Group struct { // Id Group ID Id string `json:"id"` + // Issued How group was issued by API or from JWT token + Issued *string `json:"issued,omitempty"` + // Name Group Name identifier Name string `json:"name"` @@ -189,6 +198,9 @@ type GroupMinimum struct { // Id Group ID Id string `json:"id"` + // Issued How group was issued by API or from JWT token + Issued *string `json:"issued,omitempty"` + // Name Group Name identifier Name string `json:"name"` diff --git a/management/server/http/groups_handler.go b/management/server/http/groups_handler.go index 4fcf6ce5d..966c3f678 100644 --- a/management/server/http/groups_handler.go +++ b/management/server/http/groups_handler.go @@ -72,7 +72,7 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { return } - _, ok = account.Groups[groupID] + eg, ok := account.Groups[groupID] if !ok { util.WriteError(status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w) return @@ -107,9 +107,10 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { peers = *req.Peers } group := server.Group{ - ID: groupID, - Name: req.Name, - Peers: peers, + ID: groupID, + Name: req.Name, + Peers: peers, + Issued: eg.Issued, } if err := h.accountManager.SaveGroup(account.Id, user.Id, &group); err != nil { @@ -149,9 +150,10 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { peers = *req.Peers } group := server.Group{ - ID: xid.New().String(), - Name: req.Name, - Peers: peers, + ID: xid.New().String(), + Name: req.Name, + Peers: peers, + Issued: server.GroupIssuedAPI, } err = h.accountManager.SaveGroup(account.Id, user.Id, &group) @@ -237,6 +239,7 @@ func toGroupResponse(account *server.Account, group *server.Group) *api.Group { Id: group.ID, Name: group.Name, PeersCount: len(group.Peers), + Issued: &group.Issued, } for _, pid := range group.Peers { diff --git a/management/server/http/groups_handler_test.go b/management/server/http/groups_handler_test.go index 754909999..73ca8db6a 100644 --- a/management/server/http/groups_handler_test.go +++ b/management/server/http/groups_handler_test.go @@ -42,9 +42,17 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *GroupsHandle if groupID != "idofthegroup" { return nil, status.Errorf(status.NotFound, "not found") } + if groupID == "id-jwt-group" { + return &server.Group{ + ID: "id-jwt-group", + Name: "Default Group", + Issued: server.GroupIssuedJWT, + }, nil + } return &server.Group{ - ID: "idofthegroup", - Name: "Group", + ID: "idofthegroup", + Name: "Group", + Issued: server.GroupIssuedAPI, }, nil }, UpdateGroupFunc: func(_ string, groupID string, operations []server.GroupUpdateOperation) (*server.Group, error) { @@ -80,8 +88,9 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *GroupsHandle user.Id: user, }, Groups: map[string]*server.Group{ - "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}}, - "id-all": {ID: "id-all", Name: "All"}, + "id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: server.GroupIssuedJWT}, + "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: server.GroupIssuedAPI}, + "id-all": {ID: "id-all", Name: "All", Issued: server.GroupIssuedAPI}, }, }, user, nil }, @@ -169,6 +178,8 @@ func TestGetGroup(t *testing.T) { } func TestWriteGroup(t *testing.T) { + groupIssuedAPI := "api" + groupIssuedJWT := "jwt" tt := []struct { name string expectedStatus int @@ -187,8 +198,9 @@ func TestWriteGroup(t *testing.T) { expectedStatus: http.StatusOK, expectedBody: true, expectedGroup: &api.Group{ - Id: "id-was-set", - Name: "Default POSTed Group", + Id: "id-was-set", + Name: "Default POSTed Group", + Issued: &groupIssuedAPI, }, }, { @@ -208,8 +220,9 @@ func TestWriteGroup(t *testing.T) { []byte(`{"Name":"Default POSTed Group"}`)), expectedStatus: http.StatusOK, expectedGroup: &api.Group{ - Id: "id-existed", - Name: "Default POSTed Group", + Id: "id-existed", + Name: "Default POSTed Group", + Issued: &groupIssuedAPI, }, }, { @@ -230,6 +243,19 @@ func TestWriteGroup(t *testing.T) { expectedStatus: http.StatusUnprocessableEntity, expectedBody: false, }, + { + name: "Write Group PUT not not change Issue", + requestType: http.MethodPut, + requestPath: "/api/groups/id-jwt-group", + requestBody: bytes.NewBuffer( + []byte(`{"Name":"changed","Issued":"api"}`)), + expectedStatus: http.StatusOK, + expectedGroup: &api.Group{ + Id: "id-jwt-group", + Name: "changed", + Issued: &groupIssuedJWT, + }, + }, } adminUser := server.NewAdminUser("test_user") diff --git a/management/server/jwtclaims/claims.go b/management/server/jwtclaims/claims.go index 2d7dc499a..946c0b8be 100644 --- a/management/server/jwtclaims/claims.go +++ b/management/server/jwtclaims/claims.go @@ -1,9 +1,15 @@ package jwtclaims +import ( + "github.com/golang-jwt/jwt" +) + // AuthorizationClaims stores authorization information from JWTs type AuthorizationClaims struct { UserId string AccountId string Domain string DomainCategory string + + Raw jwt.MapClaims } diff --git a/management/server/jwtclaims/extractor.go b/management/server/jwtclaims/extractor.go index 9aa00a004..466856d77 100644 --- a/management/server/jwtclaims/extractor.go +++ b/management/server/jwtclaims/extractor.go @@ -73,7 +73,9 @@ func NewClaimsExtractor(options ...ClaimsExtractorOption) *ClaimsExtractor { // FromToken extracts claims from the token (after auth) func (c *ClaimsExtractor) FromToken(token *jwt.Token) AuthorizationClaims { claims := token.Claims.(jwt.MapClaims) - jwtClaims := AuthorizationClaims{} + jwtClaims := AuthorizationClaims{ + Raw: claims, + } userID, ok := claims[c.userIDClaim].(string) if !ok { return jwtClaims diff --git a/management/server/jwtclaims/extractor_test.go b/management/server/jwtclaims/extractor_test.go index 53f8818b1..9bececac6 100644 --- a/management/server/jwtclaims/extractor_test.go +++ b/management/server/jwtclaims/extractor_test.go @@ -48,6 +48,12 @@ func TestExtractClaimsFromRequestContext(t *testing.T) { Domain: "test.com", AccountId: "testAcc", DomainCategory: "public", + Raw: jwt.MapClaims{ + "https://login/wt_account_domain": "test.com", + "https://login/wt_account_domain_category": "public", + "https://login/wt_account_id": "testAcc", + "sub": "test", + }, }, testingFunc: require.EqualValues, expectedMSG: "extracted claims should match input claims", @@ -59,6 +65,10 @@ func TestExtractClaimsFromRequestContext(t *testing.T) { inputAuthorizationClaims: AuthorizationClaims{ UserId: "test", AccountId: "testAcc", + Raw: jwt.MapClaims{ + "https://login/wt_account_id": "testAcc", + "sub": "test", + }, }, testingFunc: require.EqualValues, expectedMSG: "extracted claims should match input claims", @@ -70,6 +80,10 @@ func TestExtractClaimsFromRequestContext(t *testing.T) { inputAuthorizationClaims: AuthorizationClaims{ UserId: "test", Domain: "test.com", + Raw: jwt.MapClaims{ + "https://login/wt_account_domain": "test.com", + "sub": "test", + }, }, testingFunc: require.EqualValues, expectedMSG: "extracted claims should match input claims", @@ -82,6 +96,11 @@ func TestExtractClaimsFromRequestContext(t *testing.T) { UserId: "test", Domain: "test.com", AccountId: "testAcc", + Raw: jwt.MapClaims{ + "https://login/wt_account_domain": "test.com", + "https://login/wt_account_id": "testAcc", + "sub": "test", + }, }, testingFunc: require.EqualValues, expectedMSG: "extracted claims should match input claims", @@ -92,6 +111,9 @@ func TestExtractClaimsFromRequestContext(t *testing.T) { inputAudiance: "https://login/", inputAuthorizationClaims: AuthorizationClaims{ UserId: "test", + Raw: jwt.MapClaims{ + "sub": "test", + }, }, testingFunc: require.EqualValues, expectedMSG: "extracted claims should match input claims",