diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 265b6d6e6..a1be5bbb4 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1599,3 +1599,95 @@ func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStre return nil } + +func (s *SqlStore) GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networks.Network, error) { + var networks []*networks.Network + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&networks, accountIDCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get networks from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get networks from store") + } + + return networks, nil +} + +func (s *SqlStore) GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networks.Network, error) { + var network *networks.Network + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&network, accountAndIDQueryCondition, accountID, networkID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewNetworkNotFoundError(networkID) + } + + log.WithContext(ctx).Errorf("failed to get network from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get network from store") + } + + return network, nil +} + +func (s *SqlStore) SaveNetwork(ctx context.Context, lockStrength LockingStrength, network *networks.Network) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(network) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save network to store: %v", result.Error) + return status.Errorf(status.Internal, "failed to save network to store") + } + + return nil +} + +func (s *SqlStore) DeleteNetwork(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&networks.Network{}, accountAndIDQueryCondition, accountID, networkID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to delete network from store: %v", result.Error) + return status.Errorf(status.Internal, "failed to delete network from store") + } + + if result.RowsAffected == 0 { + return status.NewNetworkNotFoundError(networkID) + } + + return nil +} + +func (s *SqlStore) GetNetworkRoutersByNetID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) ([]*networks.NetworkRouter, error) { + //TODO implement me + panic("implement me") +} + +func (s *SqlStore) GetNetworkRouterByID(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) (*networks.NetworkRouter, error) { + //TODO implement me + panic("implement me") +} + +func (s *SqlStore) SaveNetworkRouter(ctx context.Context, lockStrength LockingStrength, router *networks.NetworkRouter) error { + //TODO implement me + panic("implement me") +} + +func (s *SqlStore) DeleteNetworkRouter(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) error { + //TODO implement me + panic("implement me") +} + +func (s *SqlStore) GetAccountNetworkResourceByNetID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networks.NetworkResource, error) { + //TODO implement me + panic("implement me") +} + +func (s *SqlStore) GetAccountNetworkResourceByID(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) (*networks.NetworkResource, error) { + //TODO implement me + panic("implement me") +} + +func (s *SqlStore) SaveNetworkResource(ctx context.Context, lockStrength LockingStrength, resource *networks.NetworkResource) error { + //TODO implement me + panic("implement me") +} + +func (s *SqlStore) DeleteNetworkResource(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) error { + //TODO implement me + panic("implement me") +} diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 6064b019f..4e5940cfd 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -16,6 +16,7 @@ import ( "github.com/google/uuid" nbdns "github.com/netbirdio/netbird/dns" nbgroup "github.com/netbirdio/netbird/management/server/group" + "github.com/netbirdio/netbird/management/server/networks" "github.com/netbirdio/netbird/management/server/posture" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -2051,3 +2052,120 @@ func TestSqlStore_DeleteNameServerGroup(t *testing.T) { require.Error(t, err) require.Nil(t, nsGroup) } + +func TestSqlStore_GetAccountNetworks(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + tests := []struct { + name string + accountID string + expectedCount int + }{ + { + name: "retrieve networks by existing account ID", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + expectedCount: 1, + }, + + { + name: "retrieve networks by non-existing account ID", + accountID: "non-existent", + expectedCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + networks, err := store.GetAccountNetworks(context.Background(), LockingStrengthShare, tt.accountID) + require.NoError(t, err) + require.Len(t, networks, tt.expectedCount) + }) + } +} + +func TestSqlStore_GetNetworkByID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + networkID string + expectError bool + }{ + { + name: "retrieve existing network ID", + networkID: "ct286bi7qv930dsrrug0", + expectError: false, + }, + { + name: "retrieve non-existing network ID", + networkID: "non-existing", + expectError: true, + }, + { + name: "retrieve network with empty ID", + networkID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + network, err := store.GetNetworkByID(context.Background(), LockingStrengthShare, accountID, tt.networkID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, network) + } else { + require.NoError(t, err) + require.NotNil(t, network) + require.Equal(t, tt.networkID, network.ID) + } + }) + } +} + +func TestSqlStore_SaveNetwork(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + network := &networks.Network{ + ID: "net-id", + AccountID: accountID, + Name: "net", + } + + err = store.SaveNetwork(context.Background(), LockingStrengthUpdate, network) + require.NoError(t, err) + + savedNet, err := store.GetNetworkByID(context.Background(), LockingStrengthShare, accountID, network.ID) + require.NoError(t, err) + require.Equal(t, network, savedNet) +} + +func TestSqlStore_DeleteNetwork(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + networkID := "ct286bi7qv930dsrrug0" + + err = store.DeleteNetwork(context.Background(), LockingStrengthUpdate, accountID, networkID) + require.NoError(t, err) + + network, err := store.GetNetworkByID(context.Background(), LockingStrengthShare, accountID, networkID) + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, status.NotFound, sErr.Type()) + require.Nil(t, network) +} diff --git a/management/server/status/error.go b/management/server/status/error.go index 59f436f5b..33df4ed33 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -154,3 +154,8 @@ func NewPolicyNotFoundError(policyID string) error { func NewNameServerGroupNotFoundError(nsGroupID string) error { return Errorf(NotFound, "nameserver group: %s not found", nsGroupID) } + +// NewNetworkNotFoundError creates a new Error with NotFound type for a missing network. +func NewNetworkNotFoundError(networkID string) error { + return Errorf(NotFound, "network: %s not found", networkID) +} diff --git a/management/server/store.go b/management/server/store.go index b16ad8a1a..1a8168de7 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -13,6 +13,7 @@ import ( "strings" "time" + "github.com/netbirdio/netbird/management/server/networks" log "github.com/sirupsen/logrus" "gorm.io/driver/sqlite" "gorm.io/gorm" @@ -140,6 +141,21 @@ type Store interface { // This is also a method of metrics.DataSource interface. GetStoreEngine() StoreEngine ExecuteInTransaction(ctx context.Context, f func(store Store) error) error + + GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networks.Network, error) + GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networks.Network, error) + SaveNetwork(ctx context.Context, lockStrength LockingStrength, network *networks.Network) error + DeleteNetwork(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) error + + GetNetworkRoutersByNetID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) ([]*networks.NetworkRouter, error) + GetNetworkRouterByID(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) (*networks.NetworkRouter, error) + SaveNetworkRouter(ctx context.Context, lockStrength LockingStrength, router *networks.NetworkRouter) error + DeleteNetworkRouter(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) error + + GetAccountNetworkResourceByNetID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networks.NetworkResource, error) + GetAccountNetworkResourceByID(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) (*networks.NetworkResource, error) + SaveNetworkResource(ctx context.Context, lockStrength LockingStrength, resource *networks.NetworkResource) error + DeleteNetworkResource(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) error } type StoreEngine string diff --git a/management/server/testdata/store.sql b/management/server/testdata/store.sql index 168973cad..f1111da19 100644 --- a/management/server/testdata/store.sql +++ b/management/server/testdata/store.sql @@ -12,6 +12,7 @@ CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text); CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `network_addresses` (`net_ip` text,`mac` text); +CREATE TABLE `networks` (`id` text,`account_id` text,`name` text,`description` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_networks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`); CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`); CREATE INDEX `idx_peers_key` ON `peers`(`key`); @@ -24,6 +25,8 @@ CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`); CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`); CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`); CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); +CREATE INDEX `idx_networks_id` ON `networks`(`id`); +CREATE INDEX `idx_networks_account_id` ON `networks`(`account_id`); INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,''); @@ -34,3 +37,4 @@ INSERT INTO personal_access_tokens VALUES('9dj38s35-63fb-11ec-90d6-0242ac120003' INSERT INTO installations VALUES(1,''); INSERT INTO policies VALUES('cs1tnh0hhcjnqoiuebf0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Default','This is a default rule that allows connections between all the resources',1,'[]'); INSERT INTO policy_rules VALUES('cs387mkv2d4bgq41b6n0','cs1tnh0hhcjnqoiuebf0','Default','This is a default rule that allows connections between all the resources',1,'accept','["cs1tnh0hhcjnqoiuebeg"]','["cs1tnh0hhcjnqoiuebeg"]',1,'all',NULL,NULL); +INSERT INTO networks VALUES('ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Test Network','Test Network');