// Mirror of https://github.com/netbirdio/netbird.git
// Synced 2025-03-04 01:41:17 +01:00
// 408 lines, 13 KiB, Go

package auth_test
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/netbirdio/netbird/management/server/auth"
|
|
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
|
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
|
"github.com/netbirdio/netbird/management/server/store"
|
|
"github.com/netbirdio/netbird/management/server/types"
|
|
)
|
|
|
|
func TestAuthManager_GetAccountInfoFromPAT(t *testing.T) {
|
|
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
|
if err != nil {
|
|
t.Fatalf("Error when creating store: %s", err)
|
|
}
|
|
t.Cleanup(cleanup)
|
|
|
|
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
|
|
hashedToken := sha256.Sum256([]byte(token))
|
|
encodedHashedToken := base64.StdEncoding.EncodeToString(hashedToken[:])
|
|
account := &types.Account{
|
|
Id: "account_id",
|
|
Users: map[string]*types.User{"someUser": {
|
|
Id: "someUser",
|
|
PATs: map[string]*types.PersonalAccessToken{
|
|
"tokenId": {
|
|
ID: "tokenId",
|
|
UserID: "someUser",
|
|
HashedToken: encodedHashedToken,
|
|
},
|
|
},
|
|
}},
|
|
}
|
|
|
|
err = store.SaveAccount(context.Background(), account)
|
|
if err != nil {
|
|
t.Fatalf("Error when saving account: %s", err)
|
|
}
|
|
|
|
manager := auth.NewManager(store, "", "", "", "", []string{}, false)
|
|
|
|
user, pat, _, _, err := manager.GetPATInfo(context.Background(), token)
|
|
if err != nil {
|
|
t.Fatalf("Error when getting Account from PAT: %s", err)
|
|
}
|
|
|
|
assert.Equal(t, "account_id", user.AccountID)
|
|
assert.Equal(t, "someUser", user.Id)
|
|
assert.Equal(t, account.Users["someUser"].PATs["tokenId"].ID, pat.ID)
|
|
}
|
|
|
|
func TestAuthManager_MarkPATUsed(t *testing.T) {
|
|
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
|
if err != nil {
|
|
t.Fatalf("Error when creating store: %s", err)
|
|
}
|
|
t.Cleanup(cleanup)
|
|
|
|
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
|
|
hashedToken := sha256.Sum256([]byte(token))
|
|
encodedHashedToken := base64.StdEncoding.EncodeToString(hashedToken[:])
|
|
account := &types.Account{
|
|
Id: "account_id",
|
|
Users: map[string]*types.User{"someUser": {
|
|
Id: "someUser",
|
|
PATs: map[string]*types.PersonalAccessToken{
|
|
"tokenId": {
|
|
ID: "tokenId",
|
|
HashedToken: encodedHashedToken,
|
|
},
|
|
},
|
|
}},
|
|
}
|
|
|
|
err = store.SaveAccount(context.Background(), account)
|
|
if err != nil {
|
|
t.Fatalf("Error when saving account: %s", err)
|
|
}
|
|
|
|
manager := auth.NewManager(store, "", "", "", "", []string{}, false)
|
|
|
|
err = manager.MarkPATUsed(context.Background(), "tokenId")
|
|
if err != nil {
|
|
t.Fatalf("Error when marking PAT used: %s", err)
|
|
}
|
|
|
|
account, err = store.GetAccount(context.Background(), "account_id")
|
|
if err != nil {
|
|
t.Fatalf("Error when getting account: %s", err)
|
|
}
|
|
assert.True(t, !account.Users["someUser"].PATs["tokenId"].GetLastUsed().IsZero())
|
|
}
|
|
|
|
func TestAuthManager_EnsureUserAccessByJWTGroups(t *testing.T) {
|
|
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
|
if err != nil {
|
|
t.Fatalf("Error when creating store: %s", err)
|
|
}
|
|
t.Cleanup(cleanup)
|
|
|
|
userId := "user-id"
|
|
domain := "test.domain"
|
|
|
|
account := &types.Account{
|
|
Id: "account_id",
|
|
Domain: domain,
|
|
Users: map[string]*types.User{"someUser": {
|
|
Id: "someUser",
|
|
}},
|
|
Settings: &types.Settings{},
|
|
}
|
|
|
|
err = store.SaveAccount(context.Background(), account)
|
|
if err != nil {
|
|
t.Fatalf("Error when saving account: %s", err)
|
|
}
|
|
|
|
// this has been validated and parsed by ValidateAndParseToken
|
|
userAuth := nbcontext.UserAuth{
|
|
AccountId: account.Id,
|
|
Domain: domain,
|
|
UserId: userId,
|
|
DomainCategory: "test-category",
|
|
// Groups: []string{"group1", "group2"},
|
|
}
|
|
|
|
// these tests only assert groups are parsed from token as per account settings
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{"idp-groups": []interface{}{"group1", "group2"}})
|
|
|
|
manager := auth.NewManager(store, "", "", "", "", []string{}, false)
|
|
|
|
t.Run("JWT groups disabled", func(t *testing.T) {
|
|
userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token)
|
|
require.NoError(t, err, "ensure user access by JWT groups failed")
|
|
require.Len(t, userAuth.Groups, 0, "account not enabled to ensure access by groups")
|
|
})
|
|
|
|
t.Run("User impersonated", func(t *testing.T) {
|
|
userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token)
|
|
require.NoError(t, err, "ensure user access by JWT groups failed")
|
|
require.Len(t, userAuth.Groups, 0, "account not enabled to ensure access by groups")
|
|
})
|
|
|
|
t.Run("User PAT", func(t *testing.T) {
|
|
userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token)
|
|
require.NoError(t, err, "ensure user access by JWT groups failed")
|
|
require.Len(t, userAuth.Groups, 0, "account not enabled to ensure access by groups")
|
|
})
|
|
|
|
t.Run("JWT groups enabled without claim name", func(t *testing.T) {
|
|
account.Settings.JWTGroupsEnabled = true
|
|
err := store.SaveAccount(context.Background(), account)
|
|
require.NoError(t, err, "save account failed")
|
|
|
|
userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token)
|
|
require.NoError(t, err, "ensure user access by JWT groups failed")
|
|
require.Len(t, userAuth.Groups, 0, "account missing groups claim name")
|
|
})
|
|
|
|
t.Run("JWT groups enabled without allowed groups", func(t *testing.T) {
|
|
account.Settings.JWTGroupsEnabled = true
|
|
account.Settings.JWTGroupsClaimName = "idp-groups"
|
|
err := store.SaveAccount(context.Background(), account)
|
|
require.NoError(t, err, "save account failed")
|
|
|
|
userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token)
|
|
require.NoError(t, err, "ensure user access by JWT groups failed")
|
|
require.Equal(t, []string{"group1", "group2"}, userAuth.Groups, "group parsed do not match")
|
|
})
|
|
|
|
t.Run("User in allowed JWT groups", func(t *testing.T) {
|
|
account.Settings.JWTGroupsEnabled = true
|
|
account.Settings.JWTGroupsClaimName = "idp-groups"
|
|
account.Settings.JWTAllowGroups = []string{"group1"}
|
|
err := store.SaveAccount(context.Background(), account)
|
|
require.NoError(t, err, "save account failed")
|
|
|
|
userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token)
|
|
require.NoError(t, err, "ensure user access by JWT groups failed")
|
|
|
|
require.Equal(t, []string{"group1", "group2"}, userAuth.Groups, "group parsed do not match")
|
|
})
|
|
|
|
t.Run("User not in allowed JWT groups", func(t *testing.T) {
|
|
account.Settings.JWTGroupsEnabled = true
|
|
account.Settings.JWTGroupsClaimName = "idp-groups"
|
|
account.Settings.JWTAllowGroups = []string{"not-a-group"}
|
|
err := store.SaveAccount(context.Background(), account)
|
|
require.NoError(t, err, "save account failed")
|
|
|
|
_, err = manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token)
|
|
require.Error(t, err, "ensure user access is not in allowed groups")
|
|
})
|
|
}
|
|
|
|
func TestAuthManager_ValidateAndParseToken(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Add("Cache-Control", "max-age=30") // set a 30s expiry to these keys
|
|
http.ServeFile(w, r, "test_data/jwks.json")
|
|
}))
|
|
defer server.Close()
|
|
|
|
issuer := "http://issuer.local"
|
|
audience := "http://audience.local"
|
|
userIdClaim := "" // defaults to "sub"
|
|
|
|
// we're only testing with RSA256
|
|
keyData, _ := os.ReadFile("test_data/sample_key")
|
|
key, _ := jwt.ParseRSAPrivateKeyFromPEM(keyData)
|
|
keyId := "test-key"
|
|
|
|
// note, we can use a nil store because ValidateAndParseToken does not use it in it's flow
|
|
manager := auth.NewManager(nil, issuer, audience, server.URL, userIdClaim, []string{audience}, false)
|
|
|
|
customClaim := func(name string) string {
|
|
return fmt.Sprintf("%s/%s", audience, name)
|
|
}
|
|
|
|
lastLogin := time.Date(2025, 2, 12, 14, 25, 26, 0, time.UTC) //"2025-02-12T14:25:26.186Z"
|
|
|
|
tests := []struct {
|
|
name string
|
|
tokenFunc func() string
|
|
expected *nbcontext.UserAuth // nil indicates expected error
|
|
}{
|
|
{
|
|
name: "Valid with custom claims",
|
|
tokenFunc: func() string {
|
|
token := jwt.New(jwt.SigningMethodRS256)
|
|
token.Header["kid"] = keyId
|
|
token.Claims = jwt.MapClaims{
|
|
"iss": issuer,
|
|
"aud": []string{audience},
|
|
"iat": time.Now().Unix(),
|
|
"exp": time.Now().Add(time.Hour * 1).Unix(),
|
|
"sub": "user-id|123",
|
|
customClaim(nbjwt.AccountIDSuffix): "account-id|567",
|
|
customClaim(nbjwt.DomainIDSuffix): "http://localhost",
|
|
customClaim(nbjwt.DomainCategorySuffix): "private",
|
|
customClaim(nbjwt.LastLoginSuffix): lastLogin.Format(time.RFC3339),
|
|
customClaim(nbjwt.Invited): false,
|
|
}
|
|
tokenString, _ := token.SignedString(key)
|
|
return tokenString
|
|
},
|
|
expected: &nbcontext.UserAuth{
|
|
UserId: "user-id|123",
|
|
AccountId: "account-id|567",
|
|
Domain: "http://localhost",
|
|
DomainCategory: "private",
|
|
LastLogin: lastLogin,
|
|
Invited: false,
|
|
},
|
|
},
|
|
{
|
|
name: "Valid without custom claims",
|
|
tokenFunc: func() string {
|
|
token := jwt.New(jwt.SigningMethodRS256)
|
|
token.Header["kid"] = keyId
|
|
token.Claims = jwt.MapClaims{
|
|
"iss": issuer,
|
|
"aud": []string{audience},
|
|
"iat": time.Now().Unix(),
|
|
"exp": time.Now().Add(time.Hour).Unix(),
|
|
"sub": "user-id|123",
|
|
}
|
|
tokenString, _ := token.SignedString(key)
|
|
return tokenString
|
|
},
|
|
expected: &nbcontext.UserAuth{
|
|
UserId: "user-id|123",
|
|
},
|
|
},
|
|
{
|
|
name: "Expired token",
|
|
tokenFunc: func() string {
|
|
token := jwt.New(jwt.SigningMethodRS256)
|
|
token.Header["kid"] = keyId
|
|
token.Claims = jwt.MapClaims{
|
|
"iss": issuer,
|
|
"aud": []string{audience},
|
|
"iat": time.Now().Add(time.Hour * -2).Unix(),
|
|
"exp": time.Now().Add(time.Hour * -1).Unix(),
|
|
"sub": "user-id|123",
|
|
}
|
|
tokenString, _ := token.SignedString(key)
|
|
return tokenString
|
|
},
|
|
},
|
|
{
|
|
name: "Not yet valid",
|
|
tokenFunc: func() string {
|
|
token := jwt.New(jwt.SigningMethodRS256)
|
|
token.Header["kid"] = keyId
|
|
token.Claims = jwt.MapClaims{
|
|
"iss": issuer,
|
|
"aud": []string{audience},
|
|
"iat": time.Now().Add(time.Hour).Unix(),
|
|
"exp": time.Now().Add(time.Hour * 2).Unix(),
|
|
"sub": "user-id|123",
|
|
}
|
|
tokenString, _ := token.SignedString(key)
|
|
return tokenString
|
|
},
|
|
},
|
|
{
|
|
name: "Invalid signature",
|
|
tokenFunc: func() string {
|
|
token := jwt.New(jwt.SigningMethodRS256)
|
|
token.Header["kid"] = keyId
|
|
token.Claims = jwt.MapClaims{
|
|
"iss": issuer,
|
|
"aud": []string{audience},
|
|
"iat": time.Now().Unix(),
|
|
"exp": time.Now().Add(time.Hour).Unix(),
|
|
"sub": "user-id|123",
|
|
}
|
|
tokenString, _ := token.SignedString(key)
|
|
parts := strings.Split(tokenString, ".")
|
|
parts[2] = "invalid-signature"
|
|
return strings.Join(parts, ".")
|
|
},
|
|
},
|
|
{
|
|
name: "Invalid issuer",
|
|
tokenFunc: func() string {
|
|
token := jwt.New(jwt.SigningMethodRS256)
|
|
token.Header["kid"] = keyId
|
|
token.Claims = jwt.MapClaims{
|
|
"iss": "not-the-issuer",
|
|
"aud": []string{audience},
|
|
"iat": time.Now().Unix(),
|
|
"exp": time.Now().Add(time.Hour).Unix(),
|
|
"sub": "user-id|123",
|
|
}
|
|
tokenString, _ := token.SignedString(key)
|
|
return tokenString
|
|
},
|
|
},
|
|
{
|
|
name: "Invalid audience",
|
|
tokenFunc: func() string {
|
|
token := jwt.New(jwt.SigningMethodRS256)
|
|
token.Header["kid"] = keyId
|
|
token.Claims = jwt.MapClaims{
|
|
"iss": issuer,
|
|
"aud": []string{"not-the-audience"},
|
|
"iat": time.Now().Unix(),
|
|
"exp": time.Now().Add(time.Hour).Unix(),
|
|
"sub": "user-id|123",
|
|
}
|
|
tokenString, _ := token.SignedString(key)
|
|
return tokenString
|
|
},
|
|
},
|
|
{
|
|
name: "Invalid user claim",
|
|
tokenFunc: func() string {
|
|
token := jwt.New(jwt.SigningMethodRS256)
|
|
token.Header["kid"] = keyId
|
|
token.Claims = jwt.MapClaims{
|
|
"iss": issuer,
|
|
"aud": []string{audience},
|
|
"iat": time.Now().Unix(),
|
|
"exp": time.Now().Add(time.Hour).Unix(),
|
|
"not-sub": "user-id|123",
|
|
}
|
|
tokenString, _ := token.SignedString(key)
|
|
return tokenString
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
tokenString := tt.tokenFunc()
|
|
|
|
userAuth, token, err := manager.ValidateAndParseToken(context.Background(), tokenString)
|
|
|
|
if tt.expected != nil {
|
|
assert.NoError(t, err)
|
|
assert.True(t, token.Valid)
|
|
assert.Equal(t, *tt.expected, userAuth)
|
|
} else {
|
|
assert.Error(t, err)
|
|
assert.Nil(t, token)
|
|
assert.Empty(t, userAuth)
|
|
}
|
|
})
|
|
}
|
|
|
|
}
|