Refactor auth middleware

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga
2024-11-01 14:02:09 +03:00
parent e73b5da42b
commit fed48de83f
9 changed files with 102 additions and 104 deletions

View File

@ -88,7 +88,7 @@ type AccountManager interface {
GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error)
GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error) GetAccountInfoFromPAT(ctx context.Context, token string) (*User, *PersonalAccessToken, string, string, error)
DeleteAccount(ctx context.Context, accountID, userID string) error DeleteAccount(ctx context.Context, accountID, userID string) error
MarkPATUsed(ctx context.Context, tokenID string) error MarkPATUsed(ctx context.Context, tokenID string) error
GetUserByID(ctx context.Context, id string) (*User, error) GetUserByID(ctx context.Context, id string) (*User, error)
@ -1869,52 +1869,59 @@ func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID strin
return am.Store.GetAccount(ctx, accountID) return am.Store.GetAccount(ctx, accountID)
} }
// GetAccountFromPAT returns Account and User associated with a personal access token // GetAccountInfoFromPAT retrieves user, personal access token, domain, and category details from a personal access token.
func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token string) (*Account, *User, *PersonalAccessToken, error) { func (am *DefaultAccountManager) GetAccountInfoFromPAT(ctx context.Context, token string) (user *User, pat *PersonalAccessToken, domain string, category string, err error) {
user, pat, err = am.extractPATFromToken(ctx, token)
if err != nil {
return nil, nil, "", "", err
}
domain, category, err = am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, user.AccountID)
if err != nil {
return nil, nil, "", "", err
}
return user, pat, domain, category, nil
}
// extractPATFromToken validates the token structure and retrieves associated User and PAT.
func (am *DefaultAccountManager) extractPATFromToken(ctx context.Context, token string) (*User, *PersonalAccessToken, error) {
if len(token) != PATLength { if len(token) != PATLength {
return nil, nil, nil, fmt.Errorf("token has wrong length") return nil, nil, fmt.Errorf("token has incorrect length")
} }
prefix := token[:len(PATPrefix)] prefix := token[:len(PATPrefix)]
if prefix != PATPrefix { if prefix != PATPrefix {
return nil, nil, nil, fmt.Errorf("token has wrong prefix") return nil, nil, fmt.Errorf("token has incorrect prefix")
} }
secret := token[len(PATPrefix) : len(PATPrefix)+PATSecretLength] secret := token[len(PATPrefix) : len(PATPrefix)+PATSecretLength]
encodedChecksum := token[len(PATPrefix)+PATSecretLength : len(PATPrefix)+PATSecretLength+PATChecksumLength] encodedChecksum := token[len(PATPrefix)+PATSecretLength : len(PATPrefix)+PATSecretLength+PATChecksumLength]
verificationChecksum, err := base62.Decode(encodedChecksum) verificationChecksum, err := base62.Decode(encodedChecksum)
if err != nil { if err != nil {
return nil, nil, nil, fmt.Errorf("token checksum decoding failed: %w", err) return nil, nil, fmt.Errorf("token checksum decoding failed: %w", err)
} }
secretChecksum := crc32.ChecksumIEEE([]byte(secret)) secretChecksum := crc32.ChecksumIEEE([]byte(secret))
if secretChecksum != verificationChecksum { if secretChecksum != verificationChecksum {
return nil, nil, nil, fmt.Errorf("token checksum does not match") return nil, nil, fmt.Errorf("token checksum does not match")
} }
hashedToken := sha256.Sum256([]byte(token)) hashedToken := sha256.Sum256([]byte(token))
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
tokenID, err := am.Store.GetTokenIDByHashedToken(ctx, encodedHashedToken)
pat, err := am.Store.GetPATByHashedToken(ctx, LockingStrengthShare, encodedHashedToken)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, err
} }
user, err := am.Store.GetUserByTokenID(ctx, tokenID) user, err := am.Store.GetUserByPATID(ctx, LockingStrengthShare, pat.ID)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, err
} }
account, err := am.Store.GetAccountByUser(ctx, user.Id) return user, pat, nil
if err != nil {
return nil, nil, nil, err
}
pat := user.PATs[tokenID]
if pat == nil {
return nil, nil, nil, fmt.Errorf("personal access token not found")
}
return account, user, pat, nil
} }
// GetAccountByID returns an account associated with this account ID. // GetAccountByID returns an account associated with this account ID.

View File

@ -47,7 +47,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
) )
authMiddleware := middleware.NewAuthMiddleware( authMiddleware := middleware.NewAuthMiddleware(
accountManager.GetAccountFromPAT, accountManager.GetAccountInfoFromPAT,
jwtValidator.ValidateAndParse, jwtValidator.ValidateAndParse,
accountManager.MarkPATUsed, accountManager.MarkPATUsed,
accountManager.CheckUserAccessByJWTGroups, accountManager.CheckUserAccessByJWTGroups,

View File

@ -9,9 +9,9 @@ import (
"time" "time"
"github.com/golang-jwt/jwt" "github.com/golang-jwt/jwt"
"github.com/netbirdio/netbird/management/server"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server"
nbContext "github.com/netbirdio/netbird/management/server/context" nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/http/middleware/bypass"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/http/util"
@ -19,8 +19,8 @@ import (
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
) )
// GetAccountFromPATFunc function // GetAccountInfoFromPATFunc function
type GetAccountFromPATFunc func(ctx context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) type GetAccountInfoFromPATFunc func(ctx context.Context, token string) (user *server.User, pat *server.PersonalAccessToken, domain string, category string, err error)
// ValidateAndParseTokenFunc function // ValidateAndParseTokenFunc function
type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error) type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error)
@ -33,7 +33,7 @@ type CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.A
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens // AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
type AuthMiddleware struct { type AuthMiddleware struct {
getAccountFromPAT GetAccountFromPATFunc getAccountInfoFromPAT GetAccountInfoFromPATFunc
validateAndParseToken ValidateAndParseTokenFunc validateAndParseToken ValidateAndParseTokenFunc
markPATUsed MarkPATUsedFunc markPATUsed MarkPATUsedFunc
checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc
@ -47,7 +47,7 @@ const (
) )
// NewAuthMiddleware instance constructor // NewAuthMiddleware instance constructor
func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc, func NewAuthMiddleware(getAccountInfoFromPAT GetAccountInfoFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
markPATUsed MarkPATUsedFunc, checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc, claimsExtractor *jwtclaims.ClaimsExtractor, markPATUsed MarkPATUsedFunc, checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc, claimsExtractor *jwtclaims.ClaimsExtractor,
audience string, userIdClaim string) *AuthMiddleware { audience string, userIdClaim string) *AuthMiddleware {
if userIdClaim == "" { if userIdClaim == "" {
@ -55,7 +55,7 @@ func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParse
} }
return &AuthMiddleware{ return &AuthMiddleware{
getAccountFromPAT: getAccountFromPAT, getAccountInfoFromPAT: getAccountInfoFromPAT,
validateAndParseToken: validateAndParseToken, validateAndParseToken: validateAndParseToken,
markPATUsed: markPATUsed, markPATUsed: markPATUsed,
checkUserAccessByJWTGroups: checkUserAccessByJWTGroups, checkUserAccessByJWTGroups: checkUserAccessByJWTGroups,
@ -116,7 +116,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
// If an error occurs, call the error handler and return an error // If an error occurs, call the error handler and return an error
if err != nil { if err != nil {
return fmt.Errorf("Error extracting token: %w", err) return fmt.Errorf("error extracting token: %w", err)
} }
validatedToken, err := m.validateAndParseToken(r.Context(), token) validatedToken, err := m.validateAndParseToken(r.Context(), token)
@ -151,13 +151,11 @@ func (m *AuthMiddleware) verifyUserAccess(ctx context.Context, validatedToken *j
// CheckPATFromRequest checks if the PAT is valid // CheckPATFromRequest checks if the PAT is valid
func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error { func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error {
token, err := getTokenFromPATRequest(auth) token, err := getTokenFromPATRequest(auth)
// If an error occurs, call the error handler and return an error
if err != nil { if err != nil {
return fmt.Errorf("Error extracting token: %w", err) return fmt.Errorf("error extracting token: %w", err)
} }
account, user, pat, err := m.getAccountFromPAT(r.Context(), token) user, pat, accDomain, accCategory, err := m.getAccountInfoFromPAT(r.Context(), token)
if err != nil { if err != nil {
return fmt.Errorf("invalid Token: %w", err) return fmt.Errorf("invalid Token: %w", err)
} }
@ -172,9 +170,9 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ
claimMaps := jwt.MapClaims{} claimMaps := jwt.MapClaims{}
claimMaps[m.userIDClaim] = user.Id claimMaps[m.userIDClaim] = user.Id
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id claimMaps[m.audience+jwtclaims.AccountIDSuffix] = user.AccountID
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain claimMaps[m.audience+jwtclaims.DomainIDSuffix] = accDomain
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = accCategory
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint
// Update the current request with the new context information. // Update the current request with the new context information.

View File

@ -55,7 +55,7 @@ type MockAccountManager struct {
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error) ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error)
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error) GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error)
GetAccountFromPATFunc func(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) GetAccountInfoFromPATFunc func(ctx context.Context, token string) (*server.User, *server.PersonalAccessToken, string, string, error)
MarkPATUsedFunc func(ctx context.Context, pat string) error MarkPATUsedFunc func(ctx context.Context, pat string) error
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
@ -235,12 +235,12 @@ func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey str
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
} }
// GetAccountFromPAT mock implementation of GetAccountFromPAT from server.AccountManager interface // GetAccountInfoFromPAT mock implementation of GetAccountInfoFromPAT from server.AccountManager interface
func (am *MockAccountManager) GetAccountFromPAT(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) { func (am *MockAccountManager) GetAccountInfoFromPAT(ctx context.Context, token string) (*server.User, *server.PersonalAccessToken, string, string, error) {
if am.GetAccountFromPATFunc != nil { if am.GetAccountInfoFromPATFunc != nil {
return am.GetAccountFromPATFunc(ctx, pat) return am.GetAccountInfoFromPATFunc(ctx, token)
} }
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromPAT is not implemented") return nil, nil, "", "", status.Errorf(codes.Unimplemented, "method GetAccountInfoFromPAT is not implemented")
} }
// DeleteAccount mock implementation of DeleteAccount from server.AccountManager interface // DeleteAccount mock implementation of DeleteAccount from server.AccountManager interface

View File

@ -475,49 +475,6 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*
return s.GetAccount(ctx, key.AccountID) return s.GetAccount(ctx, key.AccountID)
} }
func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken string) (string, error) {
var token PersonalAccessToken
result := s.db.First(&token, "hashed_token = ?", hashedToken)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error)
return "", status.NewGetAccountFromStoreError(result.Error)
}
return token.ID, nil
}
func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) {
var token PersonalAccessToken
result := s.db.First(&token, idQueryCondition, tokenID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error)
return nil, status.NewGetAccountFromStoreError(result.Error)
}
if token.UserID == "" {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
var user User
result = s.db.Preload("PATsG").First(&user, idQueryCondition, token.UserID)
if result.Error != nil {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
user.PATs = make(map[string]*PersonalAccessToken, len(user.PATsG))
for _, pat := range user.PATsG {
user.PATs[pat.ID] = pat.Copy()
}
return &user, nil
}
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) { func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) {
var user User var user User
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
@ -526,6 +483,23 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewUserNotFoundError(userID) return nil, status.NewUserNotFoundError(userID)
} }
log.WithContext(ctx).Errorf("failed to get user from the store: %s", result.Error)
return nil, status.NewGetUserFromStoreError()
}
return &user, nil
}
func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*User, error) {
var user User
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Joins("JOIN personal_access_tokens ON personal_access_tokens.user_id = users.id").
Where("personal_access_tokens.id = ?", patID).First(&user)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewPATNotFoundError()
}
log.WithContext(ctx).Errorf("failed to get token user from the store: %s", result.Error)
return nil, status.NewGetUserFromStoreError() return nil, status.NewGetUserFromStoreError()
} }
@ -1635,6 +1609,21 @@ func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength Locki
return nil return nil
} }
// GetPATByHashedToken returns a PersonalAccessToken by its hashed token.
func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*PersonalAccessToken, error) {
var pat PersonalAccessToken
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).First(&pat, "hashed_token = ?", hashedToken)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewPATNotFoundError()
}
log.WithContext(ctx).Errorf("failed to get pat from the store: %s", result.Error)
return nil, status.NewGetPATFromStoreError()
}
return &pat, nil
}
// GetPATByID retrieves a personal access token by its ID and user ID. // GetPATByID retrieves a personal access token by its ID and user ID.
func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, userID string, patID string) (*PersonalAccessToken, error) { func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, userID string, patID string) (*PersonalAccessToken, error) {
var pat PersonalAccessToken var pat PersonalAccessToken
@ -1642,10 +1631,10 @@ func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength,
First(&pat, "id = ? AND user_id = ?", patID, userID) First(&pat, "id = ? AND user_id = ?", patID, userID)
if err := result.Error; err != nil { if err := result.Error; err != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "PAT not found") return nil, status.NewPATNotFoundError()
} }
log.WithContext(ctx).Errorf("failed to get PAT from the store: %s", err) log.WithContext(ctx).Errorf("failed to get PAT from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get PAT from store") return nil, status.NewGetPATFromStoreError()
} }
return &pat, nil return &pat, nil

View File

@ -572,11 +572,11 @@ func TestSqlite_GetTokenIDByHashedToken(t *testing.T) {
hashed := "SoMeHaShEdToKeN" hashed := "SoMeHaShEdToKeN"
id := "9dj38s35-63fb-11ec-90d6-0242ac120003" id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
token, err := store.GetTokenIDByHashedToken(context.Background(), hashed) pat, err := store.GetPATByHashedToken(context.Background(), LockingStrengthShare, hashed)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, id, token) require.Equal(t, id, pat.ID)
_, err = store.GetTokenIDByHashedToken(context.Background(), "non-existing-hash") _, err = store.GetPATByHashedToken(context.Background(), LockingStrengthShare, "non-existing-hash")
require.Error(t, err) require.Error(t, err)
parsedErr, ok := status.FromError(err) parsedErr, ok := status.FromError(err)
require.True(t, ok) require.True(t, ok)
@ -595,11 +595,11 @@ func TestSqlite_GetUserByTokenID(t *testing.T) {
id := "9dj38s35-63fb-11ec-90d6-0242ac120003" id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
user, err := store.GetUserByTokenID(context.Background(), id) user, err := store.GetUserByPATID(context.Background(), LockingStrengthShare, id)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, id, user.PATs[id].ID) require.Equal(t, id, user.PATs[id].ID)
_, err = store.GetUserByTokenID(context.Background(), "non-existing-id") _, err = store.GetUserByPATID(context.Background(), LockingStrengthShare, "non-existing-id")
require.Error(t, err) require.Error(t, err)
parsedErr, ok := status.FromError(err) parsedErr, ok := status.FromError(err)
require.True(t, ok) require.True(t, ok)
@ -967,9 +967,9 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) {
hashed := "SoMeHaShEdToKeN" hashed := "SoMeHaShEdToKeN"
id := "9dj38s35-63fb-11ec-90d6-0242ac120003" id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
token, err := store.GetTokenIDByHashedToken(context.Background(), hashed) pat, err := store.GetPATByHashedToken(context.Background(), LockingStrengthShare, hashed)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, id, token) require.Equal(t, id, pat.ID)
} }
func TestPostgresql_GetUserByTokenID(t *testing.T) { func TestPostgresql_GetUserByTokenID(t *testing.T) {
@ -984,7 +984,7 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) {
id := "9dj38s35-63fb-11ec-90d6-0242ac120003" id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
user, err := store.GetUserByTokenID(context.Background(), id) user, err := store.GetUserByPATID(context.Background(), LockingStrengthShare, id)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, id, user.PATs[id].ID) require.Equal(t, id, user.PATs[id].ID)
} }

View File

@ -146,6 +146,10 @@ func NewPATNotFoundError() error {
return Errorf(NotFound, "PAT not found") return Errorf(NotFound, "PAT not found")
} }
func NewGetPATFromStoreError() error {
return Errorf(Internal, "issue getting pat from store")
}
func NewUnauthorizedToViewPATsError() error { func NewUnauthorizedToViewPATsError() error {
return Errorf(PermissionDenied, "only users with admin power can view PATs") return Errorf(PermissionDenied, "only users with admin power can view PATs")
} }

View File

@ -63,13 +63,12 @@ type Store interface {
DeleteAccount(ctx context.Context, account *Account) error DeleteAccount(ctx context.Context, account *Account) error
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error)
SaveUsers(accountID string, users map[string]*User) error SaveUsers(accountID string, users map[string]*User) error
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error)
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error)
@ -125,6 +124,7 @@ type Store interface {
GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*PersonalAccessToken, error) GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*PersonalAccessToken, error)
GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*PersonalAccessToken, error) GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*PersonalAccessToken, error)
GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*PersonalAccessToken, error)
MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error
SavePAT(ctx context.Context, strength LockingStrength, pat *PersonalAccessToken) error SavePAT(ctx context.Context, strength LockingStrength, pat *PersonalAccessToken) error
DeletePAT(ctx context.Context, strength LockingStrength, userID, patID string) error DeletePAT(ctx context.Context, strength LockingStrength, userID, patID string) error

View File

@ -55,25 +55,25 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
eventStore: &activity.InMemoryEventStore{}, eventStore: &activity.InMemoryEventStore{},
} }
pat, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockExpiresIn) newPAT, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockExpiresIn)
if err != nil { if err != nil {
t.Fatalf("Error when adding PAT to user: %s", err) t.Fatalf("Error when adding PAT to user: %s", err)
} }
assert.Equal(t, pat.CreatedBy, mockUserID) assert.Equal(t, newPAT.CreatedBy, mockUserID)
tokenID, err := am.Store.GetTokenIDByHashedToken(context.Background(), pat.HashedToken) pat, err := am.Store.GetPATByHashedToken(context.Background(), LockingStrengthShare, newPAT.HashedToken)
if err != nil { if err != nil {
t.Fatalf("Error when getting token ID by hashed token: %s", err) t.Fatalf("Error when getting token ID by hashed token: %s", err)
} }
if tokenID == "" { if pat.ID == "" {
t.Fatal("GetTokenIDByHashedToken failed after adding PAT") t.Fatal("GetTokenIDByHashedToken failed after adding PAT")
} }
assert.Equal(t, pat.ID, tokenID) assert.Equal(t, newPAT.ID, pat.ID)
user, err := am.Store.GetUserByTokenID(context.Background(), tokenID) user, err := am.Store.GetUserByPATID(context.Background(), LockingStrengthShare, pat.ID)
if err != nil { if err != nil {
t.Fatalf("Error when getting user by token ID: %s", err) t.Fatalf("Error when getting user by token ID: %s", err)
} }