mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-23 23:29:15 +01:00
Add peer login expiration (#682)
This PR adds a peer login expiration logic that requires peers created by a user to re-authenticate (re-login) after a certain threshold of time (24h by default). The Account object now has a PeerLoginExpiration property that indicates the duration after which a peer's login will expire and a login will be required. Defaults to 24h. There are two new properties added to the Peer object: LastLogin that indicates the last time peer successfully used the Login gRPC endpoint and LoginExpirationEnabled that enables/disables peer login expiration. The login expiration logic applies only to peers that were created by a user and not those that were added with a setup key.
This commit is contained in:
parent
aecee361d0
commit
3fc89749c1
@ -26,11 +26,12 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
PublicCategory = "public"
|
||||
PrivateCategory = "private"
|
||||
UnknownCategory = "unknown"
|
||||
CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days
|
||||
CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days
|
||||
PublicCategory = "public"
|
||||
PrivateCategory = "private"
|
||||
UnknownCategory = "unknown"
|
||||
CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days
|
||||
CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days
|
||||
DefaultPeerLoginExpiration = 24 * time.Hour
|
||||
)
|
||||
|
||||
func cacheEntryExpiration() time.Duration {
|
||||
@ -48,6 +49,7 @@ type AccountManager interface {
|
||||
SaveUser(accountID, userID string, update *User) (*UserInfo, error)
|
||||
GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
|
||||
GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error)
|
||||
GetAccountByPeerID(peerID string) (*Account, error)
|
||||
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
|
||||
IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error)
|
||||
AccountExists(accountId string) (*bool, error)
|
||||
@ -93,6 +95,7 @@ type AccountManager interface {
|
||||
GetDNSSettings(accountID string, userID string) (*DNSSettings, error)
|
||||
SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error
|
||||
GetPeer(accountID, peerID, userID string) (*Peer, error)
|
||||
UpdatePeerLastLogin(peerID string) error
|
||||
}
|
||||
|
||||
type DefaultAccountManager struct {
|
||||
@ -134,6 +137,9 @@ type Account struct {
|
||||
Routes map[string]*route.Route
|
||||
NameServerGroups map[string]*nbdns.NameServerGroup
|
||||
DNSSettings *DNSSettings
|
||||
// PeerLoginExpiration is a setting that indicates when peer login expires.
|
||||
// Applies to all peers that have Peer.LoginExpirationEnabled set to true.
|
||||
PeerLoginExpiration time.Duration
|
||||
}
|
||||
|
||||
type UserInfo struct {
|
||||
@ -484,6 +490,7 @@ func (a *Account) Copy() *Account {
|
||||
Routes: routes,
|
||||
NameServerGroups: nsGroups,
|
||||
DNSSettings: dnsSettings,
|
||||
PeerLoginExpiration: a.PeerLoginExpiration,
|
||||
}
|
||||
}
|
||||
|
||||
@ -606,6 +613,11 @@ func (am *DefaultAccountManager) warmupIDPCache() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAccountByPeerID returns account from the store by a provided peer ID
|
||||
func (am *DefaultAccountManager) GetAccountByPeerID(peerID string) (*Account, error) {
|
||||
return am.Store.GetAccountByPeerID(peerID)
|
||||
}
|
||||
|
||||
// GetAccountByUserOrAccountID looks for an account by user or accountID, if no account is provided and
|
||||
// userID doesn't have an account associated with it, one account is created
|
||||
func (am *DefaultAccountManager) GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error) {
|
||||
@ -1085,16 +1097,17 @@ func newAccountWithId(accountId, userId, domain string) *Account {
|
||||
log.Debugf("created new account %s with setup key %s", accountId, defaultKey.Key)
|
||||
|
||||
acc := &Account{
|
||||
Id: accountId,
|
||||
SetupKeys: setupKeys,
|
||||
Network: network,
|
||||
Peers: peers,
|
||||
Users: users,
|
||||
CreatedBy: userId,
|
||||
Domain: domain,
|
||||
Routes: routes,
|
||||
NameServerGroups: nameServersGroups,
|
||||
DNSSettings: dnsSettings,
|
||||
Id: accountId,
|
||||
SetupKeys: setupKeys,
|
||||
Network: network,
|
||||
Peers: peers,
|
||||
Users: users,
|
||||
CreatedBy: userId,
|
||||
Domain: domain,
|
||||
Routes: routes,
|
||||
NameServerGroups: nameServersGroups,
|
||||
DNSSettings: dnsSettings,
|
||||
PeerLoginExpiration: DefaultPeerLoginExpiration,
|
||||
}
|
||||
|
||||
addAllGroup(acc)
|
||||
|
@ -81,6 +81,11 @@ func restore(file string) (*FileStore, error) {
|
||||
store.PeerID2AccountID = make(map[string]string)
|
||||
|
||||
for accountID, account := range store.Accounts {
|
||||
|
||||
if account.PeerLoginExpiration.Seconds() == 0 {
|
||||
account.PeerLoginExpiration = DefaultPeerLoginExpiration
|
||||
}
|
||||
|
||||
for setupKeyId := range account.SetupKeys {
|
||||
store.SetupKeyID2AccountID[strings.ToUpper(setupKeyId)] = accountID
|
||||
}
|
||||
|
@ -132,6 +132,15 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
return msg
|
||||
}
|
||||
|
||||
account, err := s.accountManager.GetAccountByPeerID(peer.ID)
|
||||
if err != nil {
|
||||
return status.Error(codes.Internal, "internal server error")
|
||||
}
|
||||
expired, left := peer.LoginExpired(account.PeerLoginExpiration)
|
||||
if peer.UserID != "" && expired {
|
||||
return status.Errorf(codes.PermissionDenied, "peer login has expired %v ago. Please log in once more", left)
|
||||
}
|
||||
|
||||
syncReq := &proto.SyncRequest{}
|
||||
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, syncReq)
|
||||
if err != nil {
|
||||
@ -196,29 +205,37 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GRPCServer) validateToken(jwtToken string) (string, error) {
|
||||
if s.jwtMiddleware == nil {
|
||||
return "", status.Error(codes.Internal, "no jwt middleware set")
|
||||
}
|
||||
|
||||
token, err := s.jwtMiddleware.ValidateAndParse(jwtToken)
|
||||
if err != nil {
|
||||
return "", status.Errorf(codes.Internal, "invalid jwt token, err: %v", err)
|
||||
}
|
||||
claims := s.jwtClaimsExtractor.FromToken(token)
|
||||
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
|
||||
_, _, err = s.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err)
|
||||
}
|
||||
|
||||
return claims.UserId, nil
|
||||
}
|
||||
|
||||
func (s *GRPCServer) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest) (*Peer, error) {
|
||||
var (
|
||||
reqSetupKey string
|
||||
userID string
|
||||
err error
|
||||
)
|
||||
|
||||
if req.GetJwtToken() != "" {
|
||||
log.Debugln("using jwt token to register peer")
|
||||
|
||||
if s.jwtMiddleware == nil {
|
||||
return nil, status.Error(codes.Internal, "no jwt middleware set")
|
||||
}
|
||||
|
||||
token, err := s.jwtMiddleware.ValidateAndParse(req.GetJwtToken())
|
||||
userID, err = s.validateToken(req.JwtToken)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "invalid jwt token, err: %v", err)
|
||||
}
|
||||
claims := s.jwtClaimsExtractor.FromToken(token)
|
||||
userID = claims.UserId
|
||||
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
|
||||
_, _, err = s.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
log.Debugln("using setup key to register peer")
|
||||
@ -354,6 +371,29 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
|
||||
}
|
||||
}
|
||||
|
||||
// check if peer login has expired
|
||||
account, err := s.accountManager.GetAccountByPeerID(peer.ID)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Internal, "internal server error")
|
||||
}
|
||||
expired, left := peer.LoginExpired(account.PeerLoginExpiration)
|
||||
if peer.UserID != "" && expired {
|
||||
// it might be that peer expired but user has logged in already, check token then
|
||||
if loginReq.GetJwtToken() == "" {
|
||||
return nil, status.Errorf(codes.PermissionDenied,
|
||||
"peer login has expired %v ago. Please log in once more", left)
|
||||
}
|
||||
_, err = s.validateToken(loginReq.GetJwtToken())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = s.accountManager.UpdatePeerLastLogin(peer.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var sshKey []byte
|
||||
if loginReq.GetPeerKeys() != nil {
|
||||
sshKey = loginReq.GetPeerKeys().GetSshPubKey()
|
||||
|
@ -67,6 +67,8 @@ type MockAccountManager struct {
|
||||
GetDNSSettingsFunc func(accountID, userID string) (*server.DNSSettings, error)
|
||||
SaveDNSSettingsFunc func(accountID, userID string, dnsSettingsToSave *server.DNSSettings) error
|
||||
GetPeerFunc func(accountID, peerID, userID string) (*server.Peer, error)
|
||||
GetAccountByPeerIDFunc func(peerID string) (*server.Account, error)
|
||||
UpdatePeerLastLoginFunc func(peerID string) error
|
||||
}
|
||||
|
||||
// GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface
|
||||
@ -526,3 +528,19 @@ func (am *MockAccountManager) GetPeer(accountID, peerID, userID string) (*server
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetPeer is not implemented")
|
||||
}
|
||||
|
||||
// GetAccountByPeerID mocks GetAccountByPeerID of the AccountManager interface
|
||||
func (am *MockAccountManager) GetAccountByPeerID(peerID string) (*server.Account, error) {
|
||||
if am.GetAccountByPeerIDFunc != nil {
|
||||
return am.GetAccountByPeerIDFunc(peerID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetAccountByPeerID is not implemented")
|
||||
}
|
||||
|
||||
// UpdatePeerLastLogin mocks UpdatePeerLastLogin of the AccountManager interface
|
||||
func (am *MockAccountManager) UpdatePeerLastLogin(peerID string) error {
|
||||
if am.UpdatePeerLastLoginFunc != nil {
|
||||
return am.UpdatePeerLastLoginFunc(peerID)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method UpdatePeerLastLogin is not implemented")
|
||||
}
|
||||
|
@ -58,27 +58,44 @@ type Peer struct {
|
||||
UserID string
|
||||
// SSHKey is a public SSH key of the peer
|
||||
SSHKey string
|
||||
// SSHEnabled indicated whether SSH server is enabled on the peer
|
||||
// SSHEnabled indicates whether SSH server is enabled on the peer
|
||||
SSHEnabled bool
|
||||
// LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login.
|
||||
// Works with LastLogin
|
||||
LoginExpirationEnabled bool
|
||||
// LastLogin the time when peer performed last login operation
|
||||
LastLogin time.Time
|
||||
}
|
||||
|
||||
// Copy copies Peer object
|
||||
func (p *Peer) Copy() *Peer {
|
||||
return &Peer{
|
||||
ID: p.ID,
|
||||
Key: p.Key,
|
||||
SetupKey: p.SetupKey,
|
||||
IP: p.IP,
|
||||
Meta: p.Meta,
|
||||
Name: p.Name,
|
||||
Status: p.Status,
|
||||
UserID: p.UserID,
|
||||
SSHKey: p.SSHKey,
|
||||
SSHEnabled: p.SSHEnabled,
|
||||
DNSLabel: p.DNSLabel,
|
||||
ID: p.ID,
|
||||
Key: p.Key,
|
||||
SetupKey: p.SetupKey,
|
||||
IP: p.IP,
|
||||
Meta: p.Meta,
|
||||
Name: p.Name,
|
||||
Status: p.Status,
|
||||
UserID: p.UserID,
|
||||
SSHKey: p.SSHKey,
|
||||
SSHEnabled: p.SSHEnabled,
|
||||
DNSLabel: p.DNSLabel,
|
||||
LoginExpirationEnabled: p.LoginExpirationEnabled,
|
||||
LastLogin: p.LastLogin,
|
||||
}
|
||||
}
|
||||
|
||||
// LoginExpired indicates whether peer's login has expired or not.
|
||||
// If Peer.LastLogin plus the expiresIn duration has happened already then login has expired.
|
||||
// Return true if login has expired, false otherwise and time left to expiration (negative when expired).
|
||||
func (p *Peer) LoginExpired(expiresIn time.Duration) (bool, time.Duration) {
|
||||
expiresAt := p.LastLogin.Add(expiresIn)
|
||||
now := time.Now()
|
||||
left := expiresAt.Sub(now)
|
||||
return p.LoginExpirationEnabled && (left <= 0), left
|
||||
}
|
||||
|
||||
// FQDN returns peers FQDN combined of the peer's DNS label and the system's DNS domain
|
||||
func (p *Peer) FQDN(dnsDomain string) string {
|
||||
if dnsDomain == "" {
|
||||
@ -100,7 +117,7 @@ func (p *PeerStatus) Copy() *PeerStatus {
|
||||
}
|
||||
}
|
||||
|
||||
// GetPeer looks up peer by its public WireGuard key
|
||||
// GetPeerByKey looks up peer by its public WireGuard key
|
||||
func (am *DefaultAccountManager) GetPeerByKey(peerPubKey string) (*Peer, error) {
|
||||
|
||||
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
|
||||
@ -436,17 +453,19 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (*
|
||||
}
|
||||
|
||||
newPeer := &Peer{
|
||||
ID: xid.New().String(),
|
||||
Key: peer.Key,
|
||||
SetupKey: upperKey,
|
||||
IP: nextIp,
|
||||
Meta: peer.Meta,
|
||||
Name: peer.Name,
|
||||
DNSLabel: newLabel,
|
||||
UserID: userID,
|
||||
Status: &PeerStatus{Connected: false, LastSeen: time.Now()},
|
||||
SSHEnabled: false,
|
||||
SSHKey: peer.SSHKey,
|
||||
ID: xid.New().String(),
|
||||
Key: peer.Key,
|
||||
SetupKey: upperKey,
|
||||
IP: nextIp,
|
||||
Meta: peer.Meta,
|
||||
Name: peer.Name,
|
||||
DNSLabel: newLabel,
|
||||
UserID: userID,
|
||||
Status: &PeerStatus{Connected: false, LastSeen: time.Now()},
|
||||
SSHEnabled: false,
|
||||
SSHKey: peer.SSHKey,
|
||||
LastLogin: time.Now(),
|
||||
LoginExpirationEnabled: false,
|
||||
}
|
||||
|
||||
// add peer to 'All' group
|
||||
@ -491,6 +510,38 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (*
|
||||
return newPeer, nil
|
||||
}
|
||||
|
||||
// UpdatePeerLastLogin sets Peer.LastLogin to the current timestamp.
|
||||
func (am *DefaultAccountManager) UpdatePeerLastLogin(peerID string) error {
|
||||
account, err := am.Store.GetAccountByPeerID(peerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountLock(account.Id)
|
||||
defer unlock()
|
||||
|
||||
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
|
||||
account, err = am.Store.GetAccount(account.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
peer := account.GetPeer(peerID)
|
||||
if peer == nil {
|
||||
return status.Errorf(status.NotFound, "peer with ID %s not found", peerID)
|
||||
}
|
||||
|
||||
peer.LastLogin = time.Now()
|
||||
account.UpdatePeer(peer)
|
||||
|
||||
err = am.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdatePeerSSHKey updates peer's public SSH key
|
||||
func (am *DefaultAccountManager) UpdatePeerSSHKey(peerID string, sshKey string) error {
|
||||
|
||||
|
@ -3,11 +3,57 @@ package server
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
func TestPeer_LoginExpired(t *testing.T) {
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
expirationEnbaled bool
|
||||
lastLogin time.Time
|
||||
expiresIn time.Duration
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Peer Login Expiration Disabled. Peer Login Should Not Expire",
|
||||
expirationEnbaled: false,
|
||||
lastLogin: time.Now().Add(-25 * time.Hour),
|
||||
expiresIn: time.Hour,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Peer Login Should Expire",
|
||||
expirationEnbaled: true,
|
||||
lastLogin: time.Now().Add(-25 * time.Hour),
|
||||
expiresIn: time.Hour,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Peer Login Should Not Expire",
|
||||
expirationEnbaled: true,
|
||||
lastLogin: time.Now(),
|
||||
expiresIn: time.Hour,
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range tt {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
peer := &Peer{
|
||||
LoginExpirationEnabled: c.expirationEnbaled,
|
||||
LastLogin: c.lastLogin,
|
||||
}
|
||||
|
||||
expired, _ := peer.LoginExpired(c.expiresIn)
|
||||
assert.Equal(t, expired, c.expected)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountManager_GetNetworkMap(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
|
Loading…
Reference in New Issue
Block a user