mirror of
https://github.com/netbirdio/netbird.git
synced 2025-03-04 01:41:17 +01:00
Refactor auth middleware
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
parent
e73b5da42b
commit
fed48de83f
@ -88,7 +88,7 @@ type AccountManager interface {
|
||||
GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error)
|
||||
GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
|
||||
CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
|
||||
GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error)
|
||||
GetAccountInfoFromPAT(ctx context.Context, token string) (*User, *PersonalAccessToken, string, string, error)
|
||||
DeleteAccount(ctx context.Context, accountID, userID string) error
|
||||
MarkPATUsed(ctx context.Context, tokenID string) error
|
||||
GetUserByID(ctx context.Context, id string) (*User, error)
|
||||
@ -1869,52 +1869,59 @@ func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID strin
|
||||
return am.Store.GetAccount(ctx, accountID)
|
||||
}
|
||||
|
||||
// GetAccountFromPAT returns Account and User associated with a personal access token
|
||||
func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token string) (*Account, *User, *PersonalAccessToken, error) {
|
||||
// GetAccountInfoFromPAT retrieves user, personal access token, domain, and category details from a personal access token.
|
||||
func (am *DefaultAccountManager) GetAccountInfoFromPAT(ctx context.Context, token string) (user *User, pat *PersonalAccessToken, domain string, category string, err error) {
|
||||
user, pat, err = am.extractPATFromToken(ctx, token)
|
||||
if err != nil {
|
||||
return nil, nil, "", "", err
|
||||
}
|
||||
|
||||
domain, category, err = am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, user.AccountID)
|
||||
if err != nil {
|
||||
return nil, nil, "", "", err
|
||||
}
|
||||
|
||||
return user, pat, domain, category, nil
|
||||
}
|
||||
|
||||
// extractPATFromToken validates the token structure and retrieves associated User and PAT.
|
||||
func (am *DefaultAccountManager) extractPATFromToken(ctx context.Context, token string) (*User, *PersonalAccessToken, error) {
|
||||
if len(token) != PATLength {
|
||||
return nil, nil, nil, fmt.Errorf("token has wrong length")
|
||||
return nil, nil, fmt.Errorf("token has incorrect length")
|
||||
}
|
||||
|
||||
prefix := token[:len(PATPrefix)]
|
||||
if prefix != PATPrefix {
|
||||
return nil, nil, nil, fmt.Errorf("token has wrong prefix")
|
||||
return nil, nil, fmt.Errorf("token has incorrect prefix")
|
||||
}
|
||||
|
||||
secret := token[len(PATPrefix) : len(PATPrefix)+PATSecretLength]
|
||||
encodedChecksum := token[len(PATPrefix)+PATSecretLength : len(PATPrefix)+PATSecretLength+PATChecksumLength]
|
||||
|
||||
verificationChecksum, err := base62.Decode(encodedChecksum)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("token checksum decoding failed: %w", err)
|
||||
return nil, nil, fmt.Errorf("token checksum decoding failed: %w", err)
|
||||
}
|
||||
|
||||
secretChecksum := crc32.ChecksumIEEE([]byte(secret))
|
||||
if secretChecksum != verificationChecksum {
|
||||
return nil, nil, nil, fmt.Errorf("token checksum does not match")
|
||||
return nil, nil, fmt.Errorf("token checksum does not match")
|
||||
}
|
||||
|
||||
hashedToken := sha256.Sum256([]byte(token))
|
||||
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
|
||||
tokenID, err := am.Store.GetTokenIDByHashedToken(ctx, encodedHashedToken)
|
||||
|
||||
pat, err := am.Store.GetPATByHashedToken(ctx, LockingStrengthShare, encodedHashedToken)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
user, err := am.Store.GetUserByTokenID(ctx, tokenID)
|
||||
user, err := am.Store.GetUserByPATID(ctx, LockingStrengthShare, pat.ID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccountByUser(ctx, user.Id)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
pat := user.PATs[tokenID]
|
||||
if pat == nil {
|
||||
return nil, nil, nil, fmt.Errorf("personal access token not found")
|
||||
}
|
||||
|
||||
return account, user, pat, nil
|
||||
return user, pat, nil
|
||||
}
|
||||
|
||||
// GetAccountByID returns an account associated with this account ID.
|
||||
|
@ -47,7 +47,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
|
||||
)
|
||||
|
||||
authMiddleware := middleware.NewAuthMiddleware(
|
||||
accountManager.GetAccountFromPAT,
|
||||
accountManager.GetAccountInfoFromPAT,
|
||||
jwtValidator.ValidateAndParse,
|
||||
accountManager.MarkPATUsed,
|
||||
accountManager.CheckUserAccessByJWTGroups,
|
||||
|
@ -9,9 +9,9 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
@ -19,8 +19,8 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
)
|
||||
|
||||
// GetAccountFromPATFunc function
|
||||
type GetAccountFromPATFunc func(ctx context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
|
||||
// GetAccountInfoFromPATFunc function
|
||||
type GetAccountInfoFromPATFunc func(ctx context.Context, token string) (user *server.User, pat *server.PersonalAccessToken, domain string, category string, err error)
|
||||
|
||||
// ValidateAndParseTokenFunc function
|
||||
type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error)
|
||||
@ -33,7 +33,7 @@ type CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.A
|
||||
|
||||
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
|
||||
type AuthMiddleware struct {
|
||||
getAccountFromPAT GetAccountFromPATFunc
|
||||
getAccountInfoFromPAT GetAccountInfoFromPATFunc
|
||||
validateAndParseToken ValidateAndParseTokenFunc
|
||||
markPATUsed MarkPATUsedFunc
|
||||
checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc
|
||||
@ -47,7 +47,7 @@ const (
|
||||
)
|
||||
|
||||
// NewAuthMiddleware instance constructor
|
||||
func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
|
||||
func NewAuthMiddleware(getAccountInfoFromPAT GetAccountInfoFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
|
||||
markPATUsed MarkPATUsedFunc, checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc, claimsExtractor *jwtclaims.ClaimsExtractor,
|
||||
audience string, userIdClaim string) *AuthMiddleware {
|
||||
if userIdClaim == "" {
|
||||
@ -55,7 +55,7 @@ func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParse
|
||||
}
|
||||
|
||||
return &AuthMiddleware{
|
||||
getAccountFromPAT: getAccountFromPAT,
|
||||
getAccountInfoFromPAT: getAccountInfoFromPAT,
|
||||
validateAndParseToken: validateAndParseToken,
|
||||
markPATUsed: markPATUsed,
|
||||
checkUserAccessByJWTGroups: checkUserAccessByJWTGroups,
|
||||
@ -116,7 +116,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
|
||||
|
||||
// If an error occurs, call the error handler and return an error
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error extracting token: %w", err)
|
||||
return fmt.Errorf("error extracting token: %w", err)
|
||||
}
|
||||
|
||||
validatedToken, err := m.validateAndParseToken(r.Context(), token)
|
||||
@ -151,13 +151,11 @@ func (m *AuthMiddleware) verifyUserAccess(ctx context.Context, validatedToken *j
|
||||
// CheckPATFromRequest checks if the PAT is valid
|
||||
func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error {
|
||||
token, err := getTokenFromPATRequest(auth)
|
||||
|
||||
// If an error occurs, call the error handler and return an error
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error extracting token: %w", err)
|
||||
return fmt.Errorf("error extracting token: %w", err)
|
||||
}
|
||||
|
||||
account, user, pat, err := m.getAccountFromPAT(r.Context(), token)
|
||||
user, pat, accDomain, accCategory, err := m.getAccountInfoFromPAT(r.Context(), token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid Token: %w", err)
|
||||
}
|
||||
@ -172,9 +170,9 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ
|
||||
|
||||
claimMaps := jwt.MapClaims{}
|
||||
claimMaps[m.userIDClaim] = user.Id
|
||||
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id
|
||||
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain
|
||||
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory
|
||||
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = user.AccountID
|
||||
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = accDomain
|
||||
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = accCategory
|
||||
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
|
||||
newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint
|
||||
// Update the current request with the new context information.
|
||||
|
@ -55,7 +55,7 @@ type MockAccountManager struct {
|
||||
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
|
||||
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error)
|
||||
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error)
|
||||
GetAccountFromPATFunc func(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
|
||||
GetAccountInfoFromPATFunc func(ctx context.Context, token string) (*server.User, *server.PersonalAccessToken, string, string, error)
|
||||
MarkPATUsedFunc func(ctx context.Context, pat string) error
|
||||
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
|
||||
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
@ -235,12 +235,12 @@ func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey str
|
||||
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||
}
|
||||
|
||||
// GetAccountFromPAT mock implementation of GetAccountFromPAT from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetAccountFromPAT(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
|
||||
if am.GetAccountFromPATFunc != nil {
|
||||
return am.GetAccountFromPATFunc(ctx, pat)
|
||||
// GetAccountInfoFromPAT mock implementation of GetAccountInfoFromPAT from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetAccountInfoFromPAT(ctx context.Context, token string) (*server.User, *server.PersonalAccessToken, string, string, error) {
|
||||
if am.GetAccountInfoFromPATFunc != nil {
|
||||
return am.GetAccountInfoFromPATFunc(ctx, token)
|
||||
}
|
||||
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromPAT is not implemented")
|
||||
return nil, nil, "", "", status.Errorf(codes.Unimplemented, "method GetAccountInfoFromPAT is not implemented")
|
||||
}
|
||||
|
||||
// DeleteAccount mock implementation of DeleteAccount from server.AccountManager interface
|
||||
|
@ -475,49 +475,6 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*
|
||||
return s.GetAccount(ctx, key.AccountID)
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken string) (string, error) {
|
||||
var token PersonalAccessToken
|
||||
result := s.db.First(&token, "hashed_token = ?", hashedToken)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error)
|
||||
return "", status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
|
||||
return token.ID, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) {
|
||||
var token PersonalAccessToken
|
||||
result := s.db.First(&token, idQueryCondition, tokenID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error)
|
||||
return nil, status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
|
||||
if token.UserID == "" {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
|
||||
var user User
|
||||
result = s.db.Preload("PATsG").First(&user, idQueryCondition, token.UserID)
|
||||
if result.Error != nil {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
|
||||
user.PATs = make(map[string]*PersonalAccessToken, len(user.PATsG))
|
||||
for _, pat := range user.PATsG {
|
||||
user.PATs[pat.ID] = pat.Copy()
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) {
|
||||
var user User
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
@ -526,6 +483,23 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewUserNotFoundError(userID)
|
||||
}
|
||||
log.WithContext(ctx).Errorf("failed to get user from the store: %s", result.Error)
|
||||
return nil, status.NewGetUserFromStoreError()
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*User, error) {
|
||||
var user User
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Joins("JOIN personal_access_tokens ON personal_access_tokens.user_id = users.id").
|
||||
Where("personal_access_tokens.id = ?", patID).First(&user)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewPATNotFoundError()
|
||||
}
|
||||
log.WithContext(ctx).Errorf("failed to get token user from the store: %s", result.Error)
|
||||
return nil, status.NewGetUserFromStoreError()
|
||||
}
|
||||
|
||||
@ -1635,6 +1609,21 @@ func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength Locki
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPATByHashedToken returns a PersonalAccessToken by its hashed token.
|
||||
func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*PersonalAccessToken, error) {
|
||||
var pat PersonalAccessToken
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).First(&pat, "hashed_token = ?", hashedToken)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewPATNotFoundError()
|
||||
}
|
||||
log.WithContext(ctx).Errorf("failed to get pat from the store: %s", result.Error)
|
||||
return nil, status.NewGetPATFromStoreError()
|
||||
}
|
||||
|
||||
return &pat, nil
|
||||
}
|
||||
|
||||
// GetPATByID retrieves a personal access token by its ID and user ID.
|
||||
func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, userID string, patID string) (*PersonalAccessToken, error) {
|
||||
var pat PersonalAccessToken
|
||||
@ -1642,10 +1631,10 @@ func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength,
|
||||
First(&pat, "id = ? AND user_id = ?", patID, userID)
|
||||
if err := result.Error; err != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "PAT not found")
|
||||
return nil, status.NewPATNotFoundError()
|
||||
}
|
||||
log.WithContext(ctx).Errorf("failed to get PAT from the store: %s", err)
|
||||
return nil, status.Errorf(status.Internal, "failed to get PAT from store")
|
||||
return nil, status.NewGetPATFromStoreError()
|
||||
}
|
||||
|
||||
return &pat, nil
|
||||
|
@ -572,11 +572,11 @@ func TestSqlite_GetTokenIDByHashedToken(t *testing.T) {
|
||||
hashed := "SoMeHaShEdToKeN"
|
||||
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
|
||||
|
||||
token, err := store.GetTokenIDByHashedToken(context.Background(), hashed)
|
||||
pat, err := store.GetPATByHashedToken(context.Background(), LockingStrengthShare, hashed)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, id, token)
|
||||
require.Equal(t, id, pat.ID)
|
||||
|
||||
_, err = store.GetTokenIDByHashedToken(context.Background(), "non-existing-hash")
|
||||
_, err = store.GetPATByHashedToken(context.Background(), LockingStrengthShare, "non-existing-hash")
|
||||
require.Error(t, err)
|
||||
parsedErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
@ -595,11 +595,11 @@ func TestSqlite_GetUserByTokenID(t *testing.T) {
|
||||
|
||||
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
|
||||
|
||||
user, err := store.GetUserByTokenID(context.Background(), id)
|
||||
user, err := store.GetUserByPATID(context.Background(), LockingStrengthShare, id)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, id, user.PATs[id].ID)
|
||||
|
||||
_, err = store.GetUserByTokenID(context.Background(), "non-existing-id")
|
||||
_, err = store.GetUserByPATID(context.Background(), LockingStrengthShare, "non-existing-id")
|
||||
require.Error(t, err)
|
||||
parsedErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
@ -967,9 +967,9 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) {
|
||||
hashed := "SoMeHaShEdToKeN"
|
||||
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
|
||||
|
||||
token, err := store.GetTokenIDByHashedToken(context.Background(), hashed)
|
||||
pat, err := store.GetPATByHashedToken(context.Background(), LockingStrengthShare, hashed)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, id, token)
|
||||
require.Equal(t, id, pat.ID)
|
||||
}
|
||||
|
||||
func TestPostgresql_GetUserByTokenID(t *testing.T) {
|
||||
@ -984,7 +984,7 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) {
|
||||
|
||||
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
|
||||
|
||||
user, err := store.GetUserByTokenID(context.Background(), id)
|
||||
user, err := store.GetUserByPATID(context.Background(), LockingStrengthShare, id)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, id, user.PATs[id].ID)
|
||||
}
|
||||
|
@ -146,6 +146,10 @@ func NewPATNotFoundError() error {
|
||||
return Errorf(NotFound, "PAT not found")
|
||||
}
|
||||
|
||||
func NewGetPATFromStoreError() error {
|
||||
return Errorf(Internal, "issue getting pat from store")
|
||||
}
|
||||
|
||||
func NewUnauthorizedToViewPATsError() error {
|
||||
return Errorf(PermissionDenied, "only users with admin power can view PATs")
|
||||
}
|
||||
|
@ -63,13 +63,12 @@ type Store interface {
|
||||
DeleteAccount(ctx context.Context, account *Account) error
|
||||
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
|
||||
|
||||
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
||||
GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*User, error)
|
||||
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
|
||||
GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error)
|
||||
SaveUsers(accountID string, users map[string]*User) error
|
||||
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
|
||||
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
||||
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
|
||||
|
||||
GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error)
|
||||
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error)
|
||||
@ -125,6 +124,7 @@ type Store interface {
|
||||
|
||||
GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*PersonalAccessToken, error)
|
||||
GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*PersonalAccessToken, error)
|
||||
GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*PersonalAccessToken, error)
|
||||
MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error
|
||||
SavePAT(ctx context.Context, strength LockingStrength, pat *PersonalAccessToken) error
|
||||
DeletePAT(ctx context.Context, strength LockingStrength, userID, patID string) error
|
||||
|
@ -55,25 +55,25 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
}
|
||||
|
||||
pat, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockExpiresIn)
|
||||
newPAT, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockExpiresIn)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when adding PAT to user: %s", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, pat.CreatedBy, mockUserID)
|
||||
assert.Equal(t, newPAT.CreatedBy, mockUserID)
|
||||
|
||||
tokenID, err := am.Store.GetTokenIDByHashedToken(context.Background(), pat.HashedToken)
|
||||
pat, err := am.Store.GetPATByHashedToken(context.Background(), LockingStrengthShare, newPAT.HashedToken)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when getting token ID by hashed token: %s", err)
|
||||
}
|
||||
|
||||
if tokenID == "" {
|
||||
if pat.ID == "" {
|
||||
t.Fatal("GetTokenIDByHashedToken failed after adding PAT")
|
||||
}
|
||||
|
||||
assert.Equal(t, pat.ID, tokenID)
|
||||
assert.Equal(t, newPAT.ID, pat.ID)
|
||||
|
||||
user, err := am.Store.GetUserByTokenID(context.Background(), tokenID)
|
||||
user, err := am.Store.GetUserByPATID(context.Background(), LockingStrengthShare, pat.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when getting user by token ID: %s", err)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user