diff --git a/management/server/route.go b/management/server/route.go index 0d5b5cb1d..0d2438ca3 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -4,9 +4,9 @@ import ( "context" "fmt" "net/netip" + "slices" "unicode/utf8" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/rs/xid" "github.com/netbirdio/netbird/management/server/store" @@ -238,12 +238,12 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI var oldRouteAffectsPeers bool var newRouteAffectsPeers bool - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validateRoute(ctx, transaction, accountID, routeToSave); err != nil { return err } - oldRoute, err = transaction.GetRouteByID(ctx, LockingStrengthUpdate, accountID, string(routeToSave.ID)) + oldRoute, err = transaction.GetRouteByID(ctx, store.LockingStrengthUpdate, accountID, string(routeToSave.ID)) if err != nil { return err } @@ -272,7 +272,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) if oldRouteAffectsPeers || newRouteAffectsPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return nil @@ -283,7 +283,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err } @@ -306,7 +306,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } @@ -350,7 +350,7 @@ func validateRoute(ctx context.Context, transaction store.Store, accountID strin } if routeToSave.Peer != "" { - peer, err := transaction.GetPeerByID(ctx, LockingStrengthShare, accountID, routeToSave.Peer) + peer, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, routeToSave.Peer) if err != nil { return err } @@ -398,7 +398,7 @@ func validateRouteProperties(routeToSave *route.Route) error { } // validateRouteGroups validates the route groups and returns the validated groups map. -func validateRouteGroups(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) (map[string]*nbgroup.Group, error) { +func validateRouteGroups(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) (map[string]*types.Group, error) { groupsToValidate := slices.Concat(routeToSave.Groups, routeToSave.PeerGroups, routeToSave.AccessControlGroups) groupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupsToValidate) if err != nil { diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 4893decb8..7efc68745 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -3128,7 +3128,7 @@ func TestSqlStore_DeletePAT(t *testing.T) { } func TestSqlStore_GetAccountRoutes(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -3164,7 +3164,7 @@ func TestSqlStore_GetAccountRoutes(t *testing.T) { } func TestSqlStore_GetRouteByID(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -3210,7 +3210,7 @@ func TestSqlStore_GetRouteByID(t *testing.T) { } func TestSqlStore_SaveRoute(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -3239,7 +3239,7 @@ func TestSqlStore_SaveRoute(t *testing.T) { } func TestSqlStore_DeleteRoute(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err)