mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-15 01:32:56 +02:00
Refactor auth middleware
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
@ -88,7 +88,7 @@ type AccountManager interface {
|
|||||||
GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error)
|
GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error)
|
||||||
GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
|
GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
|
||||||
CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
|
CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
|
||||||
GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error)
|
GetAccountInfoFromPAT(ctx context.Context, token string) (*User, *PersonalAccessToken, string, string, error)
|
||||||
DeleteAccount(ctx context.Context, accountID, userID string) error
|
DeleteAccount(ctx context.Context, accountID, userID string) error
|
||||||
MarkPATUsed(ctx context.Context, tokenID string) error
|
MarkPATUsed(ctx context.Context, tokenID string) error
|
||||||
GetUserByID(ctx context.Context, id string) (*User, error)
|
GetUserByID(ctx context.Context, id string) (*User, error)
|
||||||
@ -1869,52 +1869,59 @@ func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID strin
|
|||||||
return am.Store.GetAccount(ctx, accountID)
|
return am.Store.GetAccount(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountFromPAT returns Account and User associated with a personal access token
|
// GetAccountInfoFromPAT retrieves user, personal access token, domain, and category details from a personal access token.
|
||||||
func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token string) (*Account, *User, *PersonalAccessToken, error) {
|
func (am *DefaultAccountManager) GetAccountInfoFromPAT(ctx context.Context, token string) (user *User, pat *PersonalAccessToken, domain string, category string, err error) {
|
||||||
|
user, pat, err = am.extractPATFromToken(ctx, token)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, "", "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
domain, category, err = am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, user.AccountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, "", "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return user, pat, domain, category, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractPATFromToken validates the token structure and retrieves associated User and PAT.
|
||||||
|
func (am *DefaultAccountManager) extractPATFromToken(ctx context.Context, token string) (*User, *PersonalAccessToken, error) {
|
||||||
if len(token) != PATLength {
|
if len(token) != PATLength {
|
||||||
return nil, nil, nil, fmt.Errorf("token has wrong length")
|
return nil, nil, fmt.Errorf("token has incorrect length")
|
||||||
}
|
}
|
||||||
|
|
||||||
prefix := token[:len(PATPrefix)]
|
prefix := token[:len(PATPrefix)]
|
||||||
if prefix != PATPrefix {
|
if prefix != PATPrefix {
|
||||||
return nil, nil, nil, fmt.Errorf("token has wrong prefix")
|
return nil, nil, fmt.Errorf("token has incorrect prefix")
|
||||||
}
|
}
|
||||||
|
|
||||||
secret := token[len(PATPrefix) : len(PATPrefix)+PATSecretLength]
|
secret := token[len(PATPrefix) : len(PATPrefix)+PATSecretLength]
|
||||||
encodedChecksum := token[len(PATPrefix)+PATSecretLength : len(PATPrefix)+PATSecretLength+PATChecksumLength]
|
encodedChecksum := token[len(PATPrefix)+PATSecretLength : len(PATPrefix)+PATSecretLength+PATChecksumLength]
|
||||||
|
|
||||||
verificationChecksum, err := base62.Decode(encodedChecksum)
|
verificationChecksum, err := base62.Decode(encodedChecksum)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, fmt.Errorf("token checksum decoding failed: %w", err)
|
return nil, nil, fmt.Errorf("token checksum decoding failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
secretChecksum := crc32.ChecksumIEEE([]byte(secret))
|
secretChecksum := crc32.ChecksumIEEE([]byte(secret))
|
||||||
if secretChecksum != verificationChecksum {
|
if secretChecksum != verificationChecksum {
|
||||||
return nil, nil, nil, fmt.Errorf("token checksum does not match")
|
return nil, nil, fmt.Errorf("token checksum does not match")
|
||||||
}
|
}
|
||||||
|
|
||||||
hashedToken := sha256.Sum256([]byte(token))
|
hashedToken := sha256.Sum256([]byte(token))
|
||||||
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
|
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
|
||||||
tokenID, err := am.Store.GetTokenIDByHashedToken(ctx, encodedHashedToken)
|
|
||||||
|
pat, err := am.Store.GetPATByHashedToken(ctx, LockingStrengthShare, encodedHashedToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := am.Store.GetUserByTokenID(ctx, tokenID)
|
user, err := am.Store.GetUserByPATID(ctx, LockingStrengthShare, pat.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := am.Store.GetAccountByUser(ctx, user.Id)
|
return user, pat, nil
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
pat := user.PATs[tokenID]
|
|
||||||
if pat == nil {
|
|
||||||
return nil, nil, nil, fmt.Errorf("personal access token not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
return account, user, pat, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountByID returns an account associated with this account ID.
|
// GetAccountByID returns an account associated with this account ID.
|
||||||
|
@ -47,7 +47,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
|
|||||||
)
|
)
|
||||||
|
|
||||||
authMiddleware := middleware.NewAuthMiddleware(
|
authMiddleware := middleware.NewAuthMiddleware(
|
||||||
accountManager.GetAccountFromPAT,
|
accountManager.GetAccountInfoFromPAT,
|
||||||
jwtValidator.ValidateAndParse,
|
jwtValidator.ValidateAndParse,
|
||||||
accountManager.MarkPATUsed,
|
accountManager.MarkPATUsed,
|
||||||
accountManager.CheckUserAccessByJWTGroups,
|
accountManager.CheckUserAccessByJWTGroups,
|
||||||
|
@ -9,9 +9,9 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt"
|
"github.com/golang-jwt/jwt"
|
||||||
|
"github.com/netbirdio/netbird/management/server"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
|
||||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||||
"github.com/netbirdio/netbird/management/server/http/util"
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
@ -19,8 +19,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetAccountFromPATFunc function
|
// GetAccountInfoFromPATFunc function
|
||||||
type GetAccountFromPATFunc func(ctx context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
|
type GetAccountInfoFromPATFunc func(ctx context.Context, token string) (user *server.User, pat *server.PersonalAccessToken, domain string, category string, err error)
|
||||||
|
|
||||||
// ValidateAndParseTokenFunc function
|
// ValidateAndParseTokenFunc function
|
||||||
type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error)
|
type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error)
|
||||||
@ -33,7 +33,7 @@ type CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.A
|
|||||||
|
|
||||||
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
|
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
|
||||||
type AuthMiddleware struct {
|
type AuthMiddleware struct {
|
||||||
getAccountFromPAT GetAccountFromPATFunc
|
getAccountInfoFromPAT GetAccountInfoFromPATFunc
|
||||||
validateAndParseToken ValidateAndParseTokenFunc
|
validateAndParseToken ValidateAndParseTokenFunc
|
||||||
markPATUsed MarkPATUsedFunc
|
markPATUsed MarkPATUsedFunc
|
||||||
checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc
|
checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc
|
||||||
@ -47,7 +47,7 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// NewAuthMiddleware instance constructor
|
// NewAuthMiddleware instance constructor
|
||||||
func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
|
func NewAuthMiddleware(getAccountInfoFromPAT GetAccountInfoFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
|
||||||
markPATUsed MarkPATUsedFunc, checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc, claimsExtractor *jwtclaims.ClaimsExtractor,
|
markPATUsed MarkPATUsedFunc, checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc, claimsExtractor *jwtclaims.ClaimsExtractor,
|
||||||
audience string, userIdClaim string) *AuthMiddleware {
|
audience string, userIdClaim string) *AuthMiddleware {
|
||||||
if userIdClaim == "" {
|
if userIdClaim == "" {
|
||||||
@ -55,7 +55,7 @@ func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParse
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &AuthMiddleware{
|
return &AuthMiddleware{
|
||||||
getAccountFromPAT: getAccountFromPAT,
|
getAccountInfoFromPAT: getAccountInfoFromPAT,
|
||||||
validateAndParseToken: validateAndParseToken,
|
validateAndParseToken: validateAndParseToken,
|
||||||
markPATUsed: markPATUsed,
|
markPATUsed: markPATUsed,
|
||||||
checkUserAccessByJWTGroups: checkUserAccessByJWTGroups,
|
checkUserAccessByJWTGroups: checkUserAccessByJWTGroups,
|
||||||
@ -116,7 +116,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
|
|||||||
|
|
||||||
// If an error occurs, call the error handler and return an error
|
// If an error occurs, call the error handler and return an error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Error extracting token: %w", err)
|
return fmt.Errorf("error extracting token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
validatedToken, err := m.validateAndParseToken(r.Context(), token)
|
validatedToken, err := m.validateAndParseToken(r.Context(), token)
|
||||||
@ -151,13 +151,11 @@ func (m *AuthMiddleware) verifyUserAccess(ctx context.Context, validatedToken *j
|
|||||||
// CheckPATFromRequest checks if the PAT is valid
|
// CheckPATFromRequest checks if the PAT is valid
|
||||||
func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error {
|
func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error {
|
||||||
token, err := getTokenFromPATRequest(auth)
|
token, err := getTokenFromPATRequest(auth)
|
||||||
|
|
||||||
// If an error occurs, call the error handler and return an error
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Error extracting token: %w", err)
|
return fmt.Errorf("error extracting token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
account, user, pat, err := m.getAccountFromPAT(r.Context(), token)
|
user, pat, accDomain, accCategory, err := m.getAccountInfoFromPAT(r.Context(), token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("invalid Token: %w", err)
|
return fmt.Errorf("invalid Token: %w", err)
|
||||||
}
|
}
|
||||||
@ -172,9 +170,9 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ
|
|||||||
|
|
||||||
claimMaps := jwt.MapClaims{}
|
claimMaps := jwt.MapClaims{}
|
||||||
claimMaps[m.userIDClaim] = user.Id
|
claimMaps[m.userIDClaim] = user.Id
|
||||||
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id
|
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = user.AccountID
|
||||||
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain
|
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = accDomain
|
||||||
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory
|
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = accCategory
|
||||||
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
|
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
|
||||||
newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint
|
newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint
|
||||||
// Update the current request with the new context information.
|
// Update the current request with the new context information.
|
||||||
|
@ -55,7 +55,7 @@ type MockAccountManager struct {
|
|||||||
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
|
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
|
||||||
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error)
|
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error)
|
||||||
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error)
|
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error)
|
||||||
GetAccountFromPATFunc func(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
|
GetAccountInfoFromPATFunc func(ctx context.Context, token string) (*server.User, *server.PersonalAccessToken, string, string, error)
|
||||||
MarkPATUsedFunc func(ctx context.Context, pat string) error
|
MarkPATUsedFunc func(ctx context.Context, pat string) error
|
||||||
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
|
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
|
||||||
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||||
@ -235,12 +235,12 @@ func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey str
|
|||||||
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountFromPAT mock implementation of GetAccountFromPAT from server.AccountManager interface
|
// GetAccountInfoFromPAT mock implementation of GetAccountInfoFromPAT from server.AccountManager interface
|
||||||
func (am *MockAccountManager) GetAccountFromPAT(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
|
func (am *MockAccountManager) GetAccountInfoFromPAT(ctx context.Context, token string) (*server.User, *server.PersonalAccessToken, string, string, error) {
|
||||||
if am.GetAccountFromPATFunc != nil {
|
if am.GetAccountInfoFromPATFunc != nil {
|
||||||
return am.GetAccountFromPATFunc(ctx, pat)
|
return am.GetAccountInfoFromPATFunc(ctx, token)
|
||||||
}
|
}
|
||||||
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromPAT is not implemented")
|
return nil, nil, "", "", status.Errorf(codes.Unimplemented, "method GetAccountInfoFromPAT is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteAccount mock implementation of DeleteAccount from server.AccountManager interface
|
// DeleteAccount mock implementation of DeleteAccount from server.AccountManager interface
|
||||||
|
@ -475,49 +475,6 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*
|
|||||||
return s.GetAccount(ctx, key.AccountID)
|
return s.GetAccount(ctx, key.AccountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken string) (string, error) {
|
|
||||||
var token PersonalAccessToken
|
|
||||||
result := s.db.First(&token, "hashed_token = ?", hashedToken)
|
|
||||||
if result.Error != nil {
|
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
||||||
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
|
||||||
}
|
|
||||||
log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error)
|
|
||||||
return "", status.NewGetAccountFromStoreError(result.Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
return token.ID, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) {
|
|
||||||
var token PersonalAccessToken
|
|
||||||
result := s.db.First(&token, idQueryCondition, tokenID)
|
|
||||||
if result.Error != nil {
|
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
||||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
|
||||||
}
|
|
||||||
log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error)
|
|
||||||
return nil, status.NewGetAccountFromStoreError(result.Error)
|
|
||||||
}
|
|
||||||
|
|
||||||
if token.UserID == "" {
|
|
||||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
var user User
|
|
||||||
result = s.db.Preload("PATsG").First(&user, idQueryCondition, token.UserID)
|
|
||||||
if result.Error != nil {
|
|
||||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
user.PATs = make(map[string]*PersonalAccessToken, len(user.PATsG))
|
|
||||||
for _, pat := range user.PATsG {
|
|
||||||
user.PATs[pat.ID] = pat.Copy()
|
|
||||||
}
|
|
||||||
|
|
||||||
return &user, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) {
|
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) {
|
||||||
var user User
|
var user User
|
||||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
@ -526,6 +483,23 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
|
|||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.NewUserNotFoundError(userID)
|
return nil, status.NewUserNotFoundError(userID)
|
||||||
}
|
}
|
||||||
|
log.WithContext(ctx).Errorf("failed to get user from the store: %s", result.Error)
|
||||||
|
return nil, status.NewGetUserFromStoreError()
|
||||||
|
}
|
||||||
|
|
||||||
|
return &user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*User, error) {
|
||||||
|
var user User
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
Joins("JOIN personal_access_tokens ON personal_access_tokens.user_id = users.id").
|
||||||
|
Where("personal_access_tokens.id = ?", patID).First(&user)
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.NewPATNotFoundError()
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).Errorf("failed to get token user from the store: %s", result.Error)
|
||||||
return nil, status.NewGetUserFromStoreError()
|
return nil, status.NewGetUserFromStoreError()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1635,6 +1609,21 @@ func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength Locki
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetPATByHashedToken returns a PersonalAccessToken by its hashed token.
|
||||||
|
func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*PersonalAccessToken, error) {
|
||||||
|
var pat PersonalAccessToken
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).First(&pat, "hashed_token = ?", hashedToken)
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.NewPATNotFoundError()
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).Errorf("failed to get pat from the store: %s", result.Error)
|
||||||
|
return nil, status.NewGetPATFromStoreError()
|
||||||
|
}
|
||||||
|
|
||||||
|
return &pat, nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetPATByID retrieves a personal access token by its ID and user ID.
|
// GetPATByID retrieves a personal access token by its ID and user ID.
|
||||||
func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, userID string, patID string) (*PersonalAccessToken, error) {
|
func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, userID string, patID string) (*PersonalAccessToken, error) {
|
||||||
var pat PersonalAccessToken
|
var pat PersonalAccessToken
|
||||||
@ -1642,10 +1631,10 @@ func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength,
|
|||||||
First(&pat, "id = ? AND user_id = ?", patID, userID)
|
First(&pat, "id = ? AND user_id = ?", patID, userID)
|
||||||
if err := result.Error; err != nil {
|
if err := result.Error; err != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.Errorf(status.NotFound, "PAT not found")
|
return nil, status.NewPATNotFoundError()
|
||||||
}
|
}
|
||||||
log.WithContext(ctx).Errorf("failed to get PAT from the store: %s", err)
|
log.WithContext(ctx).Errorf("failed to get PAT from the store: %s", err)
|
||||||
return nil, status.Errorf(status.Internal, "failed to get PAT from store")
|
return nil, status.NewGetPATFromStoreError()
|
||||||
}
|
}
|
||||||
|
|
||||||
return &pat, nil
|
return &pat, nil
|
||||||
|
@ -572,11 +572,11 @@ func TestSqlite_GetTokenIDByHashedToken(t *testing.T) {
|
|||||||
hashed := "SoMeHaShEdToKeN"
|
hashed := "SoMeHaShEdToKeN"
|
||||||
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
|
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
|
||||||
|
|
||||||
token, err := store.GetTokenIDByHashedToken(context.Background(), hashed)
|
pat, err := store.GetPATByHashedToken(context.Background(), LockingStrengthShare, hashed)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, id, token)
|
require.Equal(t, id, pat.ID)
|
||||||
|
|
||||||
_, err = store.GetTokenIDByHashedToken(context.Background(), "non-existing-hash")
|
_, err = store.GetPATByHashedToken(context.Background(), LockingStrengthShare, "non-existing-hash")
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
parsedErr, ok := status.FromError(err)
|
parsedErr, ok := status.FromError(err)
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
@ -595,11 +595,11 @@ func TestSqlite_GetUserByTokenID(t *testing.T) {
|
|||||||
|
|
||||||
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
|
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
|
||||||
|
|
||||||
user, err := store.GetUserByTokenID(context.Background(), id)
|
user, err := store.GetUserByPATID(context.Background(), LockingStrengthShare, id)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, id, user.PATs[id].ID)
|
require.Equal(t, id, user.PATs[id].ID)
|
||||||
|
|
||||||
_, err = store.GetUserByTokenID(context.Background(), "non-existing-id")
|
_, err = store.GetUserByPATID(context.Background(), LockingStrengthShare, "non-existing-id")
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
parsedErr, ok := status.FromError(err)
|
parsedErr, ok := status.FromError(err)
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
@ -967,9 +967,9 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) {
|
|||||||
hashed := "SoMeHaShEdToKeN"
|
hashed := "SoMeHaShEdToKeN"
|
||||||
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
|
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
|
||||||
|
|
||||||
token, err := store.GetTokenIDByHashedToken(context.Background(), hashed)
|
pat, err := store.GetPATByHashedToken(context.Background(), LockingStrengthShare, hashed)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, id, token)
|
require.Equal(t, id, pat.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPostgresql_GetUserByTokenID(t *testing.T) {
|
func TestPostgresql_GetUserByTokenID(t *testing.T) {
|
||||||
@ -984,7 +984,7 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) {
|
|||||||
|
|
||||||
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
|
id := "9dj38s35-63fb-11ec-90d6-0242ac120003"
|
||||||
|
|
||||||
user, err := store.GetUserByTokenID(context.Background(), id)
|
user, err := store.GetUserByPATID(context.Background(), LockingStrengthShare, id)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, id, user.PATs[id].ID)
|
require.Equal(t, id, user.PATs[id].ID)
|
||||||
}
|
}
|
||||||
|
@ -146,6 +146,10 @@ func NewPATNotFoundError() error {
|
|||||||
return Errorf(NotFound, "PAT not found")
|
return Errorf(NotFound, "PAT not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewGetPATFromStoreError() error {
|
||||||
|
return Errorf(Internal, "issue getting pat from store")
|
||||||
|
}
|
||||||
|
|
||||||
func NewUnauthorizedToViewPATsError() error {
|
func NewUnauthorizedToViewPATsError() error {
|
||||||
return Errorf(PermissionDenied, "only users with admin power can view PATs")
|
return Errorf(PermissionDenied, "only users with admin power can view PATs")
|
||||||
}
|
}
|
||||||
|
@ -63,13 +63,12 @@ type Store interface {
|
|||||||
DeleteAccount(ctx context.Context, account *Account) error
|
DeleteAccount(ctx context.Context, account *Account) error
|
||||||
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
|
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
|
||||||
|
|
||||||
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*User, error)
|
||||||
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
|
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
|
||||||
GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error)
|
GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error)
|
||||||
SaveUsers(accountID string, users map[string]*User) error
|
SaveUsers(accountID string, users map[string]*User) error
|
||||||
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
|
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
|
||||||
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
||||||
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
|
|
||||||
|
|
||||||
GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error)
|
GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error)
|
||||||
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error)
|
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error)
|
||||||
@ -125,6 +124,7 @@ type Store interface {
|
|||||||
|
|
||||||
GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*PersonalAccessToken, error)
|
GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*PersonalAccessToken, error)
|
||||||
GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*PersonalAccessToken, error)
|
GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*PersonalAccessToken, error)
|
||||||
|
GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*PersonalAccessToken, error)
|
||||||
MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error
|
MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error
|
||||||
SavePAT(ctx context.Context, strength LockingStrength, pat *PersonalAccessToken) error
|
SavePAT(ctx context.Context, strength LockingStrength, pat *PersonalAccessToken) error
|
||||||
DeletePAT(ctx context.Context, strength LockingStrength, userID, patID string) error
|
DeletePAT(ctx context.Context, strength LockingStrength, userID, patID string) error
|
||||||
|
@ -55,25 +55,25 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
|
|||||||
eventStore: &activity.InMemoryEventStore{},
|
eventStore: &activity.InMemoryEventStore{},
|
||||||
}
|
}
|
||||||
|
|
||||||
pat, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockExpiresIn)
|
newPAT, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockExpiresIn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Error when adding PAT to user: %s", err)
|
t.Fatalf("Error when adding PAT to user: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Equal(t, pat.CreatedBy, mockUserID)
|
assert.Equal(t, newPAT.CreatedBy, mockUserID)
|
||||||
|
|
||||||
tokenID, err := am.Store.GetTokenIDByHashedToken(context.Background(), pat.HashedToken)
|
pat, err := am.Store.GetPATByHashedToken(context.Background(), LockingStrengthShare, newPAT.HashedToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Error when getting token ID by hashed token: %s", err)
|
t.Fatalf("Error when getting token ID by hashed token: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if tokenID == "" {
|
if pat.ID == "" {
|
||||||
t.Fatal("GetTokenIDByHashedToken failed after adding PAT")
|
t.Fatal("GetTokenIDByHashedToken failed after adding PAT")
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Equal(t, pat.ID, tokenID)
|
assert.Equal(t, newPAT.ID, pat.ID)
|
||||||
|
|
||||||
user, err := am.Store.GetUserByTokenID(context.Background(), tokenID)
|
user, err := am.Store.GetUserByPATID(context.Background(), LockingStrengthShare, pat.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Error when getting user by token ID: %s", err)
|
t.Fatalf("Error when getting user by token ID: %s", err)
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user