mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-28 21:51:40 +02:00
Merge branch 'groups-get-account-refactoring' into policy-get-account-refactoring
This commit is contained in:
commit
147971fdfe
@ -2294,12 +2294,12 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID
|
|||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
account, err := am.Store.GetAccount(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, status.NewGetAccountError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
peer, netMap, postureChecks, err := am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, account)
|
peer, netMap, postureChecks, err := am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, fmt.Errorf("error syncing peer: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, account)
|
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, account)
|
||||||
@ -2318,7 +2318,7 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account
|
|||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
account, err := am.Store.GetAccount(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return status.NewGetAccountError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, account)
|
err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, account)
|
||||||
|
@ -180,6 +180,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
|||||||
|
|
||||||
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP)
|
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
|
||||||
return mapError(ctx, err)
|
return mapError(ctx, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -207,6 +208,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
|||||||
|
|
||||||
// handleUpdates sends updates to the connected peer until the updates channel is closed.
|
// handleUpdates sends updates to the connected peer until the updates channel is closed.
|
||||||
func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||||
|
log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String())
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
// condition when there are some updates
|
// condition when there are some updates
|
||||||
@ -260,10 +262,15 @@ func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, p
|
|||||||
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
|
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
_ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
|
err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err)
|
||||||
|
}
|
||||||
s.peersUpdateManager.CloseChannel(ctx, peer.ID)
|
s.peersUpdateManager.CloseChannel(ctx, peer.ID)
|
||||||
s.secretsManager.CancelRefresh(peer.ID)
|
s.secretsManager.CancelRefresh(peer.ID)
|
||||||
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
|
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
|
||||||
|
|
||||||
|
log.WithContext(ctx).Tracef("peer %s has been disconnected", peer.Key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) {
|
func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) {
|
||||||
|
@ -110,14 +110,16 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
|
|||||||
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, account *Account) error {
|
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, account *Account) error {
|
||||||
peer, err := account.FindPeerByPubKey(peerPubKey)
|
peer, err := account.FindPeerByPubKey(peerPubKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to find peer by pub key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
expired, err := am.updatePeerStatusAndLocation(ctx, peer, connected, realIP, account)
|
expired, err := am.updatePeerStatusAndLocation(ctx, peer, connected, realIP, account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to update peer status and location: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.WithContext(ctx).Debugf("mark peer %s connected: %t", peer.ID, connected)
|
||||||
|
|
||||||
if peer.AddedWithSSOLogin() {
|
if peer.AddedWithSSOLogin() {
|
||||||
if peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled {
|
if peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled {
|
||||||
am.checkAndSchedulePeerLoginExpiration(ctx, account)
|
am.checkAndSchedulePeerLoginExpiration(ctx, account)
|
||||||
@ -168,7 +170,7 @@ func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context
|
|||||||
|
|
||||||
err := am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus)
|
err := am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, fmt.Errorf("failed to save peer status: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return oldStatus.LoginExpired, nil
|
return oldStatus.LoginExpired, nil
|
||||||
@ -590,7 +592,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
|
|
||||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, fmt.Errorf("error getting account: %w", err)
|
return nil, nil, nil, status.NewGetAccountError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
allGroup, err := account.GetGroupAll()
|
allGroup, err := account.GetGroupAll()
|
||||||
@ -652,7 +654,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
|
|||||||
if peer.UserID != "" {
|
if peer.UserID != "" {
|
||||||
user, err := account.FindUser(peer.UserID)
|
user, err := account.FindUser(peer.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, fmt.Errorf("failed to get user: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = checkIfPeerOwnerIsBlocked(peer, user)
|
err = checkIfPeerOwnerIsBlocked(peer, user)
|
||||||
@ -669,7 +671,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
|
|||||||
if updated {
|
if updated {
|
||||||
err = am.Store.SavePeer(ctx, account.Id, peer)
|
err = am.Store.SavePeer(ctx, account.Id, peer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, fmt.Errorf("failed to save peer: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if sync.UpdateAccountPeers {
|
if sync.UpdateAccountPeers {
|
||||||
@ -679,7 +681,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
|
|||||||
|
|
||||||
peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
|
peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, fmt.Errorf("failed to validate peer: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var postureChecks []*posture.Checks
|
var postureChecks []*posture.Checks
|
||||||
@ -697,7 +699,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
|
|||||||
|
|
||||||
validPeersMap, err := am.GetValidatedPeers(account)
|
validPeersMap, err := am.GetValidatedPeers(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, fmt.Errorf("failed to get validated peers: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
postureChecks, err = am.getPeerPostureChecks(ctx, account.Id, peer.ID)
|
postureChecks, err = am.getPeerPostureChecks(ctx, account.Id, peer.ID)
|
||||||
|
@ -14,11 +14,10 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
route2 "github.com/netbirdio/netbird/route"
|
route2 "github.com/netbirdio/netbird/route"
|
||||||
|
|
||||||
@ -1293,3 +1292,275 @@ func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) {
|
|||||||
err = store.DeleteSetupKey(context.Background(), LockingStrengthUpdate, accountID, nonExistingKeyID)
|
err = store.DeleteSetupKey(context.Background(), LockingStrengthUpdate, accountID, nonExistingKeyID)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_GetGroupsByIDs(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
groupIDs []string
|
||||||
|
expectedCount int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "retrieve existing groups by existing IDs",
|
||||||
|
groupIDs: []string{"cfefqs706sqkneg59g4g", "cfefqs706sqkneg59g3g"},
|
||||||
|
expectedCount: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty group IDs list",
|
||||||
|
groupIDs: []string{},
|
||||||
|
expectedCount: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-existing group IDs",
|
||||||
|
groupIDs: []string{"nonexistent1", "nonexistent2"},
|
||||||
|
expectedCount: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed existing and non-existing group IDs",
|
||||||
|
groupIDs: []string{"cfefqs706sqkneg59g4g", "nonexistent"},
|
||||||
|
expectedCount: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
groups, err := store.GetGroupsByIDs(context.Background(), LockingStrengthShare, accountID, tt.groupIDs)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, groups, tt.expectedCount)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_SaveGroup(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
group := &nbgroup.Group{
|
||||||
|
ID: "group-id",
|
||||||
|
AccountID: accountID,
|
||||||
|
Issued: "api",
|
||||||
|
Peers: []string{"peer1", "peer2"},
|
||||||
|
}
|
||||||
|
err = store.SaveGroup(context.Background(), LockingStrengthUpdate, group)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
savedGroup, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, "group-id")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, savedGroup, group)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_SaveGroups(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
groups := []*nbgroup.Group{
|
||||||
|
{
|
||||||
|
ID: "group-1",
|
||||||
|
AccountID: accountID,
|
||||||
|
Issued: "api",
|
||||||
|
Peers: []string{"peer1", "peer2"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "group-2",
|
||||||
|
AccountID: accountID,
|
||||||
|
Issued: "integration",
|
||||||
|
Peers: []string{"peer3", "peer4"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err = store.SaveGroups(context.Background(), LockingStrengthUpdate, groups)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_DeleteGroup(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
groupID string
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "delete existing group",
|
||||||
|
groupID: "cfefqs706sqkneg59g4g",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "delete non-existing group",
|
||||||
|
groupID: "non-existing-group-id",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "delete with empty group ID",
|
||||||
|
groupID: "",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := store.DeleteGroup(context.Background(), LockingStrengthUpdate, accountID, tt.groupID)
|
||||||
|
if tt.expectError {
|
||||||
|
require.Error(t, err)
|
||||||
|
sErr, ok := status.FromError(err)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, sErr.Type(), status.NotFound)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, tt.groupID)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Nil(t, group)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_DeleteGroups(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
groupIDs []string
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "delete multiple existing groups",
|
||||||
|
groupIDs: []string{"cfefqs706sqkneg59g4g", "cfefqs706sqkneg59g3g"},
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "delete non-existing groups",
|
||||||
|
groupIDs: []string{"non-existing-id-1", "non-existing-id-2"},
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "delete with empty group IDs list",
|
||||||
|
groupIDs: []string{},
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := store.DeleteGroups(context.Background(), LockingStrengthUpdate, accountID, tt.groupIDs)
|
||||||
|
if tt.expectError {
|
||||||
|
require.Error(t, err)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
for _, groupID := range tt.groupIDs {
|
||||||
|
group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Nil(t, group)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_GetPeerByID(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
peerID string
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "retrieve existing peer",
|
||||||
|
peerID: "cfefqs706sqkneg59g4g",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "retrieve non-existing peer",
|
||||||
|
peerID: "non-existing",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "retrieve with empty peer ID",
|
||||||
|
peerID: "",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
peer, err := store.GetPeerByID(context.Background(), LockingStrengthShare, accountID, tt.peerID)
|
||||||
|
if tt.expectError {
|
||||||
|
require.Error(t, err)
|
||||||
|
sErr, ok := status.FromError(err)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, sErr.Type(), status.NotFound)
|
||||||
|
require.Nil(t, peer)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, peer)
|
||||||
|
require.Equal(t, tt.peerID, peer.ID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_GetPeersByIDs(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
peerIDs []string
|
||||||
|
expectedCount int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "retrieve existing peers by existing IDs",
|
||||||
|
peerIDs: []string{"cfefqs706sqkneg59g4g", "cfeg6sf06sqkneg59g50"},
|
||||||
|
expectedCount: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty peer IDs list",
|
||||||
|
peerIDs: []string{},
|
||||||
|
expectedCount: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-existing peer IDs",
|
||||||
|
peerIDs: []string{"nonexistent1", "nonexistent2"},
|
||||||
|
expectedCount: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed existing and non-existing peer IDs",
|
||||||
|
peerIDs: []string{"cfeg6sf06sqkneg59g50", "nonexistent"},
|
||||||
|
expectedCount: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
peers, err := store.GetPeersByIDs(context.Background(), LockingStrengthShare, accountID, tt.peerIDs)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, peers, tt.expectedCount)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -130,6 +130,11 @@ func NewInvalidKeyIDError() error {
|
|||||||
return Errorf(InvalidArgument, "invalid key ID")
|
return Errorf(InvalidArgument, "invalid key ID")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewGetAccountError creates a new Error with Internal type for an issue getting account
|
||||||
|
func NewGetAccountError(err error) error {
|
||||||
|
return Errorf(Internal, "error getting account: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
// NewGroupNotFoundError creates a new Error with NotFound type for a missing group
|
// NewGroupNotFoundError creates a new Error with NotFound type for a missing group
|
||||||
func NewGroupNotFoundError(groupID string) error {
|
func NewGroupNotFoundError(groupID string) error {
|
||||||
return Errorf(NotFound, "group: %s not found", groupID)
|
return Errorf(NotFound, "group: %s not found", groupID)
|
||||||
|
@ -96,9 +96,12 @@ func (p *PeersUpdateManager) closeChannel(ctx context.Context, peerID string) {
|
|||||||
if channel, ok := p.peerChannels[peerID]; ok {
|
if channel, ok := p.peerChannels[peerID]; ok {
|
||||||
delete(p.peerChannels, peerID)
|
delete(p.peerChannels, peerID)
|
||||||
close(channel)
|
close(channel)
|
||||||
|
|
||||||
|
log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID)
|
log.WithContext(ctx).Debugf("closing updates channel: peer %s has no channel", peerID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CloseChannels closes updates channel for each given peer
|
// CloseChannels closes updates channel for each given peer
|
||||||
|
@ -9,14 +9,16 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
"github.com/netbirdio/netbird/management/server/idp"
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
"github.com/netbirdio/netbird/management/server/integration_reference"
|
"github.com/netbirdio/netbird/management/server/integration_reference"
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -1105,6 +1107,9 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
|
|||||||
func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, account *Account, peers []*nbpeer.Peer) error {
|
func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, account *Account, peers []*nbpeer.Peer) error {
|
||||||
var peerIDs []string
|
var peerIDs []string
|
||||||
for _, peer := range peers {
|
for _, peer := range peers {
|
||||||
|
// nolint:staticcheck
|
||||||
|
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peer.Key)
|
||||||
|
|
||||||
if peer.Status.LoginExpired {
|
if peer.Status.LoginExpired {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -1112,8 +1117,11 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
|
|||||||
peer.MarkLoginExpired(true)
|
peer.MarkLoginExpired(true)
|
||||||
account.UpdatePeer(peer)
|
account.UpdatePeer(peer)
|
||||||
if err := am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status); err != nil {
|
if err := am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status); err != nil {
|
||||||
return err
|
return fmt.Errorf("failed saving peer status for peer %s: %s", peer.ID, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.WithContext(ctx).Tracef("mark peer %s login expired", peer.ID)
|
||||||
|
|
||||||
am.StoreEvent(
|
am.StoreEvent(
|
||||||
ctx,
|
ctx,
|
||||||
peer.UserID, peer.ID, account.Id,
|
peer.UserID, peer.ID, account.Id,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user