Refactor auth middleware

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga 2024-11-01 14:02:09 +03:00
parent e73b5da42b
commit fed48de83f
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547
9 changed files with 102 additions and 104 deletions

View File

@ -88,7 +88,7 @@ type AccountManager interface {
GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error)
GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error)
GetAccountInfoFromPAT(ctx context.Context, token string) (*User, *PersonalAccessToken, string, string, error)
DeleteAccount(ctx context.Context, accountID, userID string) error
MarkPATUsed(ctx context.Context, tokenID string) error
GetUserByID(ctx context.Context, id string) (*User, error)
@ -1869,52 +1869,59 @@ func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID strin
return am.Store.GetAccount(ctx, accountID)
}
// GetAccountFromPAT returns Account and User associated with a personal access token
func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token string) (*Account, *User, *PersonalAccessToken, error) {
// GetAccountInfoFromPAT retrieves user, personal access token, domain, and category details from a personal access token.
func (am *DefaultAccountManager) GetAccountInfoFromPAT(ctx context.Context, token string) (user *User, pat *PersonalAccessToken, domain string, category string, err error) {
user, pat, err = am.extractPATFromToken(ctx, token)
if err != nil {
return nil, nil, "", "", err
}
domain, category, err = am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, user.AccountID)
if err != nil {
return nil, nil, "", "", err
}
return user, pat, domain, category, nil
}
// extractPATFromToken validates the token structure and retrieves associated User and PAT.
func (am *DefaultAccountManager) extractPATFromToken(ctx context.Context, token string) (*User, *PersonalAccessToken, error) {
if len(token) != PATLength {
return nil, nil, nil, fmt.Errorf("token has wrong length")
return nil, nil, fmt.Errorf("token has incorrect length")
}
prefix := token[:len(PATPrefix)]
if prefix != PATPrefix {
return nil, nil, nil, fmt.Errorf("token has wrong prefix")
return nil, nil, fmt.Errorf("token has incorrect prefix")
}
secret := token[len(PATPrefix) : len(PATPrefix)+PATSecretLength]
encodedChecksum := token[len(PATPrefix)+PATSecretLength : len(PATPrefix)+PATSecretLength+PATChecksumLength]
verificationChecksum, err := base62.Decode(encodedChecksum)
if err != nil {
return nil, nil, nil, fmt.Errorf("token checksum decoding failed: %w", err)
return nil, nil, fmt.Errorf("token checksum decoding failed: %w", err)
}
secretChecksum := crc32.ChecksumIEEE([]byte(secret))
if secretChecksum != verificationChecksum {
return nil, nil, nil, fmt.Errorf("token checksum does not match")
return nil, nil, fmt.Errorf("token checksum does not match")
}
hashedToken := sha256.Sum256([]byte(token))
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
tokenID, err := am.Store.GetTokenIDByHashedToken(ctx, encodedHashedToken)
pat, err := am.Store.GetPATByHashedToken(ctx, LockingStrengthShare, encodedHashedToken)
if err != nil {
return nil, nil, nil, err
return nil, nil, err
}
user, err := am.Store.GetUserByTokenID(ctx, tokenID)
user, err := am.Store.GetUserByPATID(ctx, LockingStrengthShare, pat.ID)
if err != nil {
return nil, nil, nil, err
return nil, nil, err
}
account, err := am.Store.GetAccountByUser(ctx, user.Id)
if err != nil {
return nil, nil, nil, err
}
pat := user.PATs[tokenID]
if pat == nil {
return nil, nil, nil, fmt.Errorf("personal access token not found")
}
return account, user, pat, nil
return user, pat, nil
}
// GetAccountByID returns an account associated with this account ID.

View File

@ -47,7 +47,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
)
authMiddleware := middleware.NewAuthMiddleware(
accountManager.GetAccountFromPAT,
accountManager.GetAccountInfoFromPAT,
jwtValidator.ValidateAndParse,
accountManager.MarkPATUsed,
accountManager.CheckUserAccessByJWTGroups,

View File

@ -9,9 +9,9 @@ import (
"time"
"github.com/golang-jwt/jwt"
"github.com/netbirdio/netbird/management/server"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server"
nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
"github.com/netbirdio/netbird/management/server/http/util"
@ -19,8 +19,8 @@ import (
"github.com/netbirdio/netbird/management/server/status"
)
// GetAccountFromPATFunc function
type GetAccountFromPATFunc func(ctx context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
// GetAccountInfoFromPATFunc function
type GetAccountInfoFromPATFunc func(ctx context.Context, token string) (user *server.User, pat *server.PersonalAccessToken, domain string, category string, err error)
// ValidateAndParseTokenFunc function
type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error)
@ -33,7 +33,7 @@ type CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.A
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
type AuthMiddleware struct {
getAccountFromPAT GetAccountFromPATFunc
getAccountInfoFromPAT GetAccountInfoFromPATFunc
validateAndParseToken ValidateAndParseTokenFunc
markPATUsed MarkPATUsedFunc
checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc
@ -47,7 +47,7 @@ const (
)
// NewAuthMiddleware instance constructor
func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
func NewAuthMiddleware(getAccountInfoFromPAT GetAccountInfoFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
markPATUsed MarkPATUsedFunc, checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc, claimsExtractor *jwtclaims.ClaimsExtractor,
audience string, userIdClaim string) *AuthMiddleware {
if userIdClaim == "" {
@ -55,7 +55,7 @@ func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParse
}
return &AuthMiddleware{
getAccountFromPAT: getAccountFromPAT,
getAccountInfoFromPAT: getAccountInfoFromPAT,
validateAndParseToken: validateAndParseToken,
markPATUsed: markPATUsed,
checkUserAccessByJWTGroups: checkUserAccessByJWTGroups,
@ -116,7 +116,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
// If an error occurs, call the error handler and return an error
if err != nil {
return fmt.Errorf("Error extracting token: %w", err)
return fmt.Errorf("error extracting token: %w", err)
}
validatedToken, err := m.validateAndParseToken(r.Context(), token)
@ -151,13 +151,11 @@ func (m *AuthMiddleware) verifyUserAccess(ctx context.Context, validatedToken *j
// CheckPATFromRequest checks if the PAT is valid
func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error {
token, err := getTokenFromPATRequest(auth)
// If an error occurs, call the error handler and return an error
if err != nil {
return fmt.Errorf("Error extracting token: %w", err)
return fmt.Errorf("error extracting token: %w", err)
}
account, user, pat, err := m.getAccountFromPAT(r.Context(), token)
user, pat, accDomain, accCategory, err := m.getAccountInfoFromPAT(r.Context(), token)
if err != nil {
return fmt.Errorf("invalid Token: %w", err)
}
@ -172,9 +170,9 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ
claimMaps := jwt.MapClaims{}
claimMaps[m.userIDClaim] = user.Id
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = user.AccountID
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = accDomain
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = accCategory
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint
// Update the current request with the new context information.

View File

@ -55,7 +55,7 @@ type MockAccountManager struct {
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error)
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error)
GetAccountFromPATFunc func(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
GetAccountInfoFromPATFunc func(ctx context.Context, token string) (*server.User, *server.PersonalAccessToken, string, string, error)
MarkPATUsedFunc func(ctx context.Context, pat string) error
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
@ -235,12 +235,12 @@ func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey str
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
}
// GetAccountFromPAT mock implementation of GetAccountFromPAT from server.AccountManager interface
func (am *MockAccountManager) GetAccountFromPAT(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
if am.GetAccountFromPATFunc != nil {
return am.GetAccountFromPATFunc(ctx, pat)
// GetAccountInfoFromPAT mock implementation of GetAccountInfoFromPAT from server.AccountManager interface
func (am *MockAccountManager) GetAccountInfoFromPAT(ctx context.Context, token string) (*server.User, *server.PersonalAccessToken, string, string, error) {
if am.GetAccountInfoFromPATFunc != nil {
return am.GetAccountInfoFromPATFunc(ctx, token)
}
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromPAT is not implemented")
return nil, nil, "", "", status.Errorf(codes.Unimplemented, "method GetAccountInfoFromPAT is not implemented")
}
// DeleteAccount mock implementation of DeleteAccount from server.AccountManager interface

View File

@ -475,49 +475,6 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*
return s.GetAccount(ctx, key.AccountID)
}
func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken string) (string, error) {
var token PersonalAccessToken
result := s.db.First(&token, "hashed_token = ?", hashedToken)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error)
return "", status.NewGetAccountFromStoreError(result.Error)
}
return token.ID, nil
}
func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) {
var token PersonalAccessToken
result := s.db.First(&token, idQueryCondition, tokenID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error)
return nil, status.NewGetAccountFromStoreError(result.Error)
}
if token.UserID == "" {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
var user User
result = s.db.Preload("PATsG").First(&user, idQueryCondition, token.UserID)
if result.Error != nil {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
user.PATs = make(map[string]*PersonalAccessToken, len(user.PATsG))
for _, pat := range user.PATsG {
user.PATs[pat.ID] = pat.Copy()
}
return &user, nil
}
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) {
var user User
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
@ -526,6 +483,23 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewUserNotFoundError(userID)
}
log.WithContext(ctx).Errorf("failed to get user from the store: %s", result.Error)
return nil, status.NewGetUserFromStoreError()
}
return &user, nil
}
func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*User, error) {
var user User
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Joins("JOIN personal_access_tokens ON personal_access_tokens.user_id = users.id").
Where("personal_access_tokens.id = ?", patID).First(&user)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewPATNotFoundError()
}
log.WithContext(ctx).Errorf("failed to get token user from the store: %s", result.Error)
return nil, status.NewGetUserFromStoreError()
}
@ -1635,6 +1609,21 @@ func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength Locki
return nil
}
// GetPATByHashedToken returns a PersonalAccessToken by its hashed token.
func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*PersonalAccessToken, error) {
var pat PersonalAccessToken
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).First(&pat, "hashed_token = ?", hashedToken)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewPATNotFoundError()
}
log.WithContext(ctx).Errorf("failed to get pat from the store: %s", result.Error)
return nil, status.NewGetPATFromStoreError()
}
return &pat, nil
}
// GetPATByID retrieves a personal access token by its ID and user ID.
func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, userID string, patID string) (*PersonalAccessToken, error) {
var pat PersonalAccessToken
@ -1642,10 +1631,10 @@ func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength,
First(&pat, "id = ? AND user_id = ?", patID, userID)
if err := result.Error; err != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "PAT not found")
return nil, status.NewPATNotFoundError()
}
log.WithContext(ctx).Errorf("failed to get PAT from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get PAT from store")
return nil, status.NewGetPATFromStoreError()
}
return &pat, nil

View File

@ -572,11 +572,11 @@ func TestSqlite_GetTokenIDByHashedToken(t *testing.T) {
hashed := "SoMeHaShEdToKeN"
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
token, err := store.GetTokenIDByHashedToken(context.Background(), hashed)
pat, err := store.GetPATByHashedToken(context.Background(), LockingStrengthShare, hashed)
require.NoError(t, err)
require.Equal(t, id, token)
require.Equal(t, id, pat.ID)
_, err = store.GetTokenIDByHashedToken(context.Background(), "non-existing-hash")
_, err = store.GetPATByHashedToken(context.Background(), LockingStrengthShare, "non-existing-hash")
require.Error(t, err)
parsedErr, ok := status.FromError(err)
require.True(t, ok)
@ -595,11 +595,11 @@ func TestSqlite_GetUserByTokenID(t *testing.T) {
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
user, err := store.GetUserByTokenID(context.Background(), id)
user, err := store.GetUserByPATID(context.Background(), LockingStrengthShare, id)
require.NoError(t, err)
require.Equal(t, id, user.PATs[id].ID)
_, err = store.GetUserByTokenID(context.Background(), "non-existing-id")
_, err = store.GetUserByPATID(context.Background(), LockingStrengthShare, "non-existing-id")
require.Error(t, err)
parsedErr, ok := status.FromError(err)
require.True(t, ok)
@ -967,9 +967,9 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) {
hashed := "SoMeHaShEdToKeN"
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
token, err := store.GetTokenIDByHashedToken(context.Background(), hashed)
pat, err := store.GetPATByHashedToken(context.Background(), LockingStrengthShare, hashed)
require.NoError(t, err)
require.Equal(t, id, token)
require.Equal(t, id, pat.ID)
}
func TestPostgresql_GetUserByTokenID(t *testing.T) {
@ -984,7 +984,7 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) {
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
user, err := store.GetUserByTokenID(context.Background(), id)
user, err := store.GetUserByPATID(context.Background(), LockingStrengthShare, id)
require.NoError(t, err)
require.Equal(t, id, user.PATs[id].ID)
}

View File

@ -146,6 +146,10 @@ func NewPATNotFoundError() error {
return Errorf(NotFound, "PAT not found")
}
func NewGetPATFromStoreError() error {
return Errorf(Internal, "issue getting pat from store")
}
func NewUnauthorizedToViewPATsError() error {
return Errorf(PermissionDenied, "only users with admin power can view PATs")
}

View File

@ -63,13 +63,12 @@ type Store interface {
DeleteAccount(ctx context.Context, account *Account) error
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error)
SaveUsers(accountID string, users map[string]*User) error
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error)
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error)
@ -125,6 +124,7 @@ type Store interface {
GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*PersonalAccessToken, error)
GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*PersonalAccessToken, error)
GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*PersonalAccessToken, error)
MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error
SavePAT(ctx context.Context, strength LockingStrength, pat *PersonalAccessToken) error
DeletePAT(ctx context.Context, strength LockingStrength, userID, patID string) error

View File

@ -55,25 +55,25 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
eventStore: &activity.InMemoryEventStore{},
}
pat, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockExpiresIn)
newPAT, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockExpiresIn)
if err != nil {
t.Fatalf("Error when adding PAT to user: %s", err)
}
assert.Equal(t, pat.CreatedBy, mockUserID)
assert.Equal(t, newPAT.CreatedBy, mockUserID)
tokenID, err := am.Store.GetTokenIDByHashedToken(context.Background(), pat.HashedToken)
pat, err := am.Store.GetPATByHashedToken(context.Background(), LockingStrengthShare, newPAT.HashedToken)
if err != nil {
t.Fatalf("Error when getting token ID by hashed token: %s", err)
}
if tokenID == "" {
if pat.ID == "" {
t.Fatal("GetTokenIDByHashedToken failed after adding PAT")
}
assert.Equal(t, pat.ID, tokenID)
assert.Equal(t, newPAT.ID, pat.ID)
user, err := am.Store.GetUserByTokenID(context.Background(), tokenID)
user, err := am.Store.GetUserByPATID(context.Background(), LockingStrengthShare, pat.ID)
if err != nil {
t.Fatalf("Error when getting user by token ID: %s", err)
}