From 79f94dd0bb38b738f4080d5b052bcb37f72a7101 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Thu, 2 Jan 2025 16:49:23 +0300 Subject: [PATCH] Refactor pat to support mysql Signed-off-by: bcmmbaga --- management/server/account.go | 2 +- .../server/http/handlers/users/pat_handler.go | 9 ++----- .../server/http/middleware/auth_middleware.go | 2 +- .../server/types/personal_access_token.go | 24 +++++++++++++++---- management/server/types/user.go | 14 ++--------- 5 files changed, 26 insertions(+), 25 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index a37a9d2fe..82dd0e9dd 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1134,7 +1134,7 @@ func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string return fmt.Errorf("token not found") } - pat.LastUsed = time.Now().UTC() + pat.LastUsed = util.ToPtr(time.Now().UTC()) return am.Store.SaveAccount(ctx, account) } diff --git a/management/server/http/handlers/users/pat_handler.go b/management/server/http/handlers/users/pat_handler.go index 197785b34..a73b83a37 100644 --- a/management/server/http/handlers/users/pat_handler.go +++ b/management/server/http/handlers/users/pat_handler.go @@ -3,7 +3,6 @@ package users import ( "encoding/json" "net/http" - "time" "github.com/gorilla/mux" @@ -166,17 +165,13 @@ func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request) { } func toPATResponse(pat *types.PersonalAccessToken) *api.PersonalAccessToken { - var lastUsed *time.Time - if !pat.LastUsed.IsZero() { - lastUsed = &pat.LastUsed - } return &api.PersonalAccessToken{ CreatedAt: pat.CreatedAt, CreatedBy: pat.CreatedBy, Name: pat.Name, - ExpirationDate: pat.ExpirationDate, + ExpirationDate: pat.ExpirationTime(), Id: pat.ID, - LastUsed: lastUsed, + LastUsed: pat.LastUsed, } } diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 0a54cbaed..953b55483 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -161,7 +161,7 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ if err != nil { return fmt.Errorf("invalid Token: %w", err) } - if time.Now().After(pat.ExpirationDate) { + if time.Now().After(pat.ExpirationTime()) { return fmt.Errorf("token expired") } diff --git a/management/server/types/personal_access_token.go b/management/server/types/personal_access_token.go index 1bf225856..d30735a77 100644 --- a/management/server/types/personal_access_token.go +++ b/management/server/types/personal_access_token.go @@ -8,6 +8,7 @@ import ( "time" b "github.com/hashicorp/go-secure-stdlib/base62" + "github.com/netbirdio/netbird/management/server/util" "github.com/rs/xid" "github.com/netbirdio/netbird/base62" @@ -31,11 +32,11 @@ type PersonalAccessToken struct { UserID string `gorm:"index"` Name string HashedToken string - ExpirationDate time.Time + ExpirationDate *time.Time // scope could be added in future CreatedBy string CreatedAt time.Time - LastUsed time.Time + LastUsed *time.Time } func (t *PersonalAccessToken) Copy() *PersonalAccessToken { @@ -50,6 +51,22 @@ func (t *PersonalAccessToken) Copy() *PersonalAccessToken { } } +// ExpirationTime returns the expiration time of the token. +func (t *PersonalAccessToken) ExpirationTime() time.Time { + if t.ExpirationDate != nil { + return *t.ExpirationDate + } + return time.Time{} +} + +// LastUsedTime returns the last time the token was used. +func (t *PersonalAccessToken) LastUsedTime() time.Time { + if t.LastUsed != nil { + return *t.LastUsed + } + return time.Time{} +} + // PersonalAccessTokenGenerated holds the new PersonalAccessToken and the plain text version of it type PersonalAccessTokenGenerated struct { PlainToken string @@ -69,10 +86,9 @@ func CreateNewPAT(name string, expirationInDays int, createdBy string) (*Persona ID: xid.New().String(), Name: name, HashedToken: hashedToken, - ExpirationDate: currentTime.AddDate(0, 0, expirationInDays), + ExpirationDate: util.ToPtr(currentTime.AddDate(0, 0, expirationInDays)), CreatedBy: createdBy, CreatedAt: currentTime, - LastUsed: time.Time{}, }, PlainToken: plainToken, }, nil diff --git a/management/server/types/user.go b/management/server/types/user.go index f304acf0e..20be53167 100644 --- a/management/server/types/user.go +++ b/management/server/types/user.go @@ -142,11 +142,6 @@ func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo } if userData == nil { - var lastLogin time.Time - if u.LastLogin != nil { - lastLogin = *u.LastLogin - } - return &UserInfo{ ID: u.Id, Email: "", @@ -156,7 +151,7 @@ func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo Status: string(UserStatusActive), IsServiceUser: u.IsServiceUser, IsBlocked: u.Blocked, - LastLogin: lastLogin, + LastLogin: u.LastLoginTime(), Issued: u.Issued, Permissions: UserPermissions{ DashboardView: dashboardViewPermissions, @@ -172,11 +167,6 @@ func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo userStatus = UserStatusInvited } - lastLogin := time.Time{} - if u.LastLogin != nil { - lastLogin = *u.LastLogin - } - return &UserInfo{ ID: u.Id, Email: userData.Email, @@ -186,7 +176,7 @@ func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo Status: string(userStatus), IsServiceUser: u.IsServiceUser, IsBlocked: u.Blocked, - LastLogin: lastLogin, + LastLogin: u.LastLoginTime(), Issued: u.Issued, Permissions: UserPermissions{ DashboardView: dashboardViewPermissions,