diff --git a/client/cmd/system.go b/client/cmd/system.go index 83ce8d215..f63432401 100644 --- a/client/cmd/system.go +++ b/client/cmd/system.go @@ -38,5 +38,5 @@ func init() { upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false, "Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+ - "This overrides any policies received from the management service.") + "This overrides any policies received from the management service.") } diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index 63bad689e..742294cdf 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -15,7 +15,7 @@ import ( // MockManager is the mock instance of a route manager type MockManager struct { ClassifyRoutesFunc func(routes []*route.Route) (map[route.ID]*route.Route, route.HAMap) - UpdateRoutesFunc func (updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error + UpdateRoutesFunc func(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error TriggerSelectionFunc func(haMap route.HAMap) GetRouteSelectorFunc func() *routeselector.RouteSelector GetClientRoutesFunc func() route.HAMap diff --git a/client/system/info.go b/client/system/info.go index a0a5fe8b3..aff10ece3 100644 --- a/client/system/info.go +++ b/client/system/info.go @@ -59,16 +59,16 @@ type Info struct { Environment Environment Files []File // for posture checks - RosenpassEnabled bool - RosenpassPermissive bool - ServerSSHAllowed bool + RosenpassEnabled bool + RosenpassPermissive bool + ServerSSHAllowed bool - DisableClientRoutes bool - DisableServerRoutes bool - DisableDNS bool - DisableFirewall bool - BlockLANAccess bool - BlockInbound bool + DisableClientRoutes bool + DisableServerRoutes bool + DisableDNS bool + DisableFirewall bool + BlockLANAccess bool + BlockInbound bool LazyConnectionEnabled bool } diff --git a/management/server/group.go b/management/server/group.go index c26a0cfc1..130a67145 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -664,15 +664,6 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac return false, nil } -func (am *DefaultAccountManager) anyGroupHasPeers(account *types.Account, groupIDs []string) bool { - for _, groupID := range groupIDs { - if group, exists := account.Groups[groupID]; exists && group.HasPeers() { - return true - } - } - return false -} - // anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources. func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) { groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupIDs) diff --git a/management/server/route.go b/management/server/route.go index 02755a708..32ff39977 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -4,19 +4,19 @@ import ( "context" "fmt" "net/netip" + "slices" "unicode/utf8" "github.com/rs/xid" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" - "github.com/netbirdio/netbird/management/server/store" - "github.com/netbirdio/netbird/management/server/types" - "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" ) @@ -30,13 +30,19 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, return nil, status.NewPermissionDeniedError() } - return am.Store.GetRouteByID(ctx, store.LockingStrengthShare, string(routeID), accountID) + return am.Store.GetRouteByID(ctx, store.LockingStrengthShare, accountID, string(routeID)) } // checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups. -func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *types.Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error { +func checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, transaction store.Store, accountID string, checkRoute *route.Route, groupsMap map[string]*types.Group) error { // routes can have both peer and peer_groups - routesWithPrefix := account.GetRoutesByPrefixOrDomains(prefix, domains) + prefix := checkRoute.Network + domains := checkRoute.Domains + + routesWithPrefix, err := getRoutesByPrefixOrDomains(ctx, transaction, accountID, prefix, domains) + if err != nil { + return err + } // lets remember all the peers and the peer groups from routesWithPrefix seenPeers := make(map[string]bool) @@ -45,18 +51,24 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account for _, prefixRoute := range routesWithPrefix { // we skip route(s) with the same network ID as we want to allow updating of the existing route // when creating a new route routeID is newly generated so nothing will be skipped - if routeID == prefixRoute.ID { + if checkRoute.ID == prefixRoute.ID { continue } if prefixRoute.Peer != "" { seenPeers[string(prefixRoute.ID)] = true } + + peerGroupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, prefixRoute.PeerGroups) + if err != nil { + return err + } + for _, groupID := range prefixRoute.PeerGroups { seenPeerGroups[groupID] = true - group := account.GetGroup(groupID) - if group == nil { + group, ok := peerGroupsMap[groupID] + if !ok || group == nil { return status.Errorf( status.InvalidArgument, "failed to add route with %s - peer group %s doesn't exist", getRouteDescriptor(prefix, domains), groupID, @@ -69,12 +81,13 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account } } - if peerID != "" { + if peerID := checkRoute.Peer; peerID != "" { // check that peerID exists and is not in any route as single peer or part of the group - peer := account.GetPeer(peerID) - if peer == nil { + _, err = transaction.GetPeerByID(context.Background(), store.LockingStrengthShare, accountID, peerID) + if err != nil { return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) } + if _, ok := seenPeers[peerID]; ok { return status.Errorf(status.AlreadyExists, "failed to add route with %s - peer %s already has this route", getRouteDescriptor(prefix, domains), peerID) @@ -82,9 +95,8 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account } // check that peerGroupIDs are not in any route peerGroups list - for _, groupID := range peerGroupIDs { - group := account.GetGroup(groupID) // we validated the group existence before entering this function, no need to check again. - + for _, groupID := range checkRoute.PeerGroups { + group := groupsMap[groupID] // we validated the group existence before entering this function, no need to check again. if _, ok := seenPeerGroups[groupID]; ok { return status.Errorf( status.AlreadyExists, "failed to add route with %s - peer group %s already has this route", @@ -92,12 +104,18 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account } // check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix + peersMap, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthShare, accountID, group.Peers) + if err != nil { + return err + } + for _, id := range group.Peers { if _, ok := seenPeers[id]; ok { - peer := account.GetPeer(id) - if peer == nil { - return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) + peer, ok := peersMap[id] + if !ok || peer == nil { + return status.Errorf(status.InvalidArgument, "peer with ID %s not found", id) } + return status.Errorf(status.AlreadyExists, "failed to add route with %s - peer %s from the group %s already has this route", getRouteDescriptor(prefix, domains), peer.Name, group.Name) @@ -128,97 +146,58 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri return nil, status.NewPermissionDeniedError() } - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return nil, err - } - if len(domains) > 0 && prefix.IsValid() { return nil, status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time") } - if len(domains) == 0 && !prefix.IsValid() { - return nil, status.Errorf(status.InvalidArgument, "invalid Prefix") - } + var newRoute *route.Route + var updateAccountPeers bool - if len(domains) > 0 { - prefix = getPlaceholderIP() - } - - if peerID != "" && len(peerGroupIDs) != 0 { - return nil, status.Errorf( - status.InvalidArgument, - "peer with ID %s and peers group %s should not be provided at the same time", - peerID, peerGroupIDs) - } - - var newRoute route.Route - newRoute.ID = route.ID(xid.New().String()) - - if len(peerGroupIDs) > 0 { - err = validateGroups(peerGroupIDs, account.Groups) - if err != nil { - return nil, err + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + newRoute = &route.Route{ + ID: route.ID(xid.New().String()), + AccountID: accountID, + Network: prefix, + Domains: domains, + KeepRoute: keepRoute, + NetID: netID, + Description: description, + Peer: peerID, + PeerGroups: peerGroupIDs, + NetworkType: networkType, + Masquerade: masquerade, + Metric: metric, + Enabled: enabled, + Groups: groups, + AccessControlGroups: accessControlGroupIDs, } - } - if len(accessControlGroupIDs) > 0 { - err = validateGroups(accessControlGroupIDs, account.Groups) - if err != nil { - return nil, err + if err = validateRoute(ctx, transaction, accountID, newRoute); err != nil { + return err } - } - err = am.checkRoutePrefixOrDomainsExistForPeers(account, peerID, newRoute.ID, peerGroupIDs, prefix, domains) + updateAccountPeers, err = areRouteChangesAffectPeers(ctx, transaction, newRoute) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.SaveRoute(ctx, store.LockingStrengthUpdate, newRoute) + }) if err != nil { return nil, err } - if metric < route.MinMetric || metric > route.MaxMetric { - return nil, status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric) - } - - if utf8.RuneCountInString(string(netID)) > route.MaxNetIDChar || netID == "" { - return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) - } - - err = validateGroups(groups, account.Groups) - if err != nil { - return nil, err - } - - newRoute.Peer = peerID - newRoute.PeerGroups = peerGroupIDs - newRoute.Network = prefix - newRoute.Domains = domains - newRoute.NetworkType = networkType - newRoute.Description = description - newRoute.NetID = netID - newRoute.Masquerade = masquerade - newRoute.Metric = metric - newRoute.Enabled = enabled - newRoute.Groups = groups - newRoute.KeepRoute = keepRoute - newRoute.AccessControlGroups = accessControlGroupIDs - - if account.Routes == nil { - account.Routes = make(map[route.ID]*route.Route) - } - - account.Routes[newRoute.ID] = &newRoute - - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { - return nil, err - } - - if am.isRouteChangeAffectPeers(account, &newRoute) { - am.UpdateAccountPeers(ctx, accountID) - } - am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) - return &newRoute, nil + if updateAccountPeers { + am.UpdateAccountPeers(ctx, accountID) + } + + return newRoute, nil } // SaveRoute saves route @@ -226,6 +205,115 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Update) + if err != nil { + return status.NewPermissionValidationError(err) + } + if !allowed { + return status.NewPermissionDeniedError() + } + + var oldRoute *route.Route + var oldRouteAffectsPeers bool + var newRouteAffectsPeers bool + + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + if err = validateRoute(ctx, transaction, accountID, routeToSave); err != nil { + return err + } + + oldRoute, err = transaction.GetRouteByID(ctx, store.LockingStrengthUpdate, accountID, string(routeToSave.ID)) + if err != nil { + return err + } + + oldRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, oldRoute) + if err != nil { + return err + } + + newRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, routeToSave) + if err != nil { + return err + } + routeToSave.AccountID = accountID + + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.SaveRoute(ctx, store.LockingStrengthUpdate, routeToSave) + }) + if err != nil { + return err + } + + am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) + + if oldRouteAffectsPeers || newRouteAffectsPeers { + am.UpdateAccountPeers(ctx, accountID) + } + + return nil +} + +// DeleteRoute deletes route with routeID +func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Delete) + if err != nil { + return status.NewPermissionValidationError(err) + } + if !allowed { + return status.NewPermissionDeniedError() + } + + var route *route.Route + var updateAccountPeers bool + + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + route, err = transaction.GetRouteByID(ctx, store.LockingStrengthUpdate, accountID, string(routeID)) + if err != nil { + return err + } + + updateAccountPeers, err = areRouteChangesAffectPeers(ctx, transaction, route) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.DeleteRoute(ctx, store.LockingStrengthUpdate, accountID, string(routeID)) + }) + + am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta()) + + if updateAccountPeers { + am.UpdateAccountPeers(ctx, accountID) + } + + return nil +} + +// ListRoutes returns a list of routes from account +func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) { + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !allowed { + return nil, status.NewPermissionDeniedError() + } + + return am.Store.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID) +} + +func validateRoute(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) error { if routeToSave == nil { return status.Errorf(status.InvalidArgument, "route provided is nil") } @@ -238,19 +326,6 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) } - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Update) - if err != nil { - return status.NewPermissionValidationError(err) - } - if !allowed { - return status.NewPermissionDeniedError() - } - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return err - } - if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() { return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time") } @@ -267,96 +342,39 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI return status.Errorf(status.InvalidArgument, "peer with ID and peer groups should not be provided at the same time") } + groupsMap, err := validateRouteGroups(ctx, transaction, accountID, routeToSave) + if err != nil { + return err + } + + return checkRoutePrefixOrDomainsExistForPeers(ctx, transaction, accountID, routeToSave, groupsMap) +} + +// validateRouteGroups validates the route groups and returns the validated groups map. +func validateRouteGroups(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) (map[string]*types.Group, error) { + groupsToValidate := slices.Concat(routeToSave.Groups, routeToSave.PeerGroups, routeToSave.AccessControlGroups) + groupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupsToValidate) + if err != nil { + return nil, err + } + if len(routeToSave.PeerGroups) > 0 { - err = validateGroups(routeToSave.PeerGroups, account.Groups) - if err != nil { - return err + if err = validateGroups(routeToSave.PeerGroups, groupsMap); err != nil { + return nil, err } } if len(routeToSave.AccessControlGroups) > 0 { - err = validateGroups(routeToSave.AccessControlGroups, account.Groups) - if err != nil { - return err + if err = validateGroups(routeToSave.AccessControlGroups, groupsMap); err != nil { + return nil, err } } - err = am.checkRoutePrefixOrDomainsExistForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains) - if err != nil { - return err + if err = validateGroups(routeToSave.Groups, groupsMap); err != nil { + return nil, err } - err = validateGroups(routeToSave.Groups, account.Groups) - if err != nil { - return err - } - - oldRoute := account.Routes[routeToSave.ID] - account.Routes[routeToSave.ID] = routeToSave - - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err - } - - if am.isRouteChangeAffectPeers(account, oldRoute) || am.isRouteChangeAffectPeers(account, routeToSave) { - am.UpdateAccountPeers(ctx, accountID) - } - - am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) - - return nil -} - -// DeleteRoute deletes route with routeID -func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Delete) - if err != nil { - return status.NewPermissionValidationError(err) - } - if !allowed { - return status.NewPermissionDeniedError() - } - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return err - } - - routy := account.Routes[routeID] - if routy == nil { - return status.Errorf(status.NotFound, "route with ID %s doesn't exist", routeID) - } - delete(account.Routes, routeID) - - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err - } - - am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta()) - - if am.isRouteChangeAffectPeers(account, routy) { - am.UpdateAccountPeers(ctx, accountID) - } - - return nil -} - -// ListRoutes returns a list of routes from account -func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) { - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Read) - if err != nil { - return nil, status.NewPermissionValidationError(err) - } - if !allowed { - return nil, status.NewPermissionDeniedError() - } - - return am.Store.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID) + return groupsMap, nil } func toProtocolRoute(route *route.Route) *proto.Route { @@ -455,8 +473,40 @@ func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo { return &portInfo } -// isRouteChangeAffectPeers checks if a given route affects peers by determining -// if it has a routing peer, distribution, or peer groups that include peers -func (am *DefaultAccountManager) isRouteChangeAffectPeers(account *types.Account, route *route.Route) bool { - return am.anyGroupHasPeers(account, route.Groups) || am.anyGroupHasPeers(account, route.PeerGroups) || route.Peer != "" +// areRouteChangesAffectPeers checks if a given route affects peers by determining +// if it has a routing peer, distribution, or peer groups that include peers. +func areRouteChangesAffectPeers(ctx context.Context, transaction store.Store, route *route.Route) (bool, error) { + if route.Peer != "" { + return true, nil + } + + hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, route.AccountID, route.Groups) + if err != nil { + return false, err + } + + if hasPeers { + return true, nil + } + + return anyGroupHasPeersOrResources(ctx, transaction, route.AccountID, route.PeerGroups) +} + +// GetRoutesByPrefixOrDomains return list of routes by account and route prefix +func getRoutesByPrefixOrDomains(ctx context.Context, transaction store.Store, accountID string, prefix netip.Prefix, domains domain.List) ([]*route.Route, error) { + accountRoutes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + routes := make([]*route.Route, 0) + for _, r := range accountRoutes { + dynamic := r.IsDynamic() + if dynamic && r.Domains.PunycodeString() == domains.PunycodeString() || + !dynamic && r.Network.String() == prefix.String() { + routes = append(routes, r) + } + } + + return routes, nil } diff --git a/management/server/status/error.go b/management/server/status/error.go index 8fbe0bad9..5a6f6d1a7 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -227,3 +227,7 @@ func NewUserRoleNotFoundError(role string) error { func NewOperationNotFoundError(operation operations.Operation) error { return Errorf(NotFound, "operation: %s not found", operation) } + +func NewRouteNotFoundError(routeID string) error { + return Errorf(NotFound, "route: %s not found", routeID) +} diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index d81890775..a6c4d56bf 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -23,8 +23,6 @@ import ( "gorm.io/gorm/clause" "gorm.io/gorm/logger" - "github.com/netbirdio/netbird/management/server/util" - nbdns "github.com/netbirdio/netbird/dns" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" @@ -34,6 +32,7 @@ import ( "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" ) @@ -1968,12 +1967,58 @@ func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength Locking // GetAccountRoutes retrieves network routes for an account. func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) { - return getRecords[*route.Route](s.db, lockStrength, accountID) + var routes []*route.Route + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Find(&routes, accountIDCondition, accountID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to get routes from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get routes from store") + } + + return routes, nil } // GetRouteByID retrieves a route by its ID and account ID. -func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) { - return getRecordByID[route.Route](s.db, lockStrength, routeID, accountID) +func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID string, routeID string) (*route.Route, error) { + var route *route.Route + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&route, accountAndIDQueryCondition, accountID, routeID) + if err := result.Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, status.NewRouteNotFoundError(routeID) + } + log.WithContext(ctx).Errorf("failed to get route from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get route from store") + } + + return route, nil +} + +// SaveRoute saves a route to the database. +func (s *SqlStore) SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(route) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to save route to the store: %s", err) + return status.Errorf(status.Internal, "failed to save route to store") + } + + return nil +} + +// DeleteRoute deletes a route from the database. +func (s *SqlStore) DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&route.Route{}, accountAndIDQueryCondition, accountID, routeID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to delete route from the store: %s", err) + return status.Errorf(status.Internal, "failed to delete route from store") + } + + if result.RowsAffected == 0 { + return status.NewRouteNotFoundError(routeID) + } + + return nil } // GetAccountSetupKeys retrieves setup keys for an account. @@ -2104,49 +2149,6 @@ func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength Locki return nil } -// getRecords retrieves records from the database based on the account ID. -func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) { - tx := db - if lockStrength != LockingStrengthNone { - tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) - } - - var record []T - - result := tx.Find(&record, accountIDCondition, accountID) - if err := result.Error; err != nil { - parts := strings.Split(fmt.Sprintf("%T", record), ".") - recordType := parts[len(parts)-1] - - return nil, status.Errorf(status.Internal, "failed to get account %ss from store: %v", recordType, err) - } - - return record, nil -} - -// getRecordByID retrieves a record by its ID and account ID from the database. -func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) (*T, error) { - tx := db - if lockStrength != LockingStrengthNone { - tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) - } - - var record T - - result := tx.Clauses(clause.Locking{Strength: string(lockStrength)}). - First(&record, accountAndIDQueryCondition, accountID, recordID) - if err := result.Error; err != nil { - parts := strings.Split(fmt.Sprintf("%T", record), ".") - recordType := parts[len(parts)-1] - - if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "%s not found", recordType) - } - return nil, status.Errorf(status.Internal, "failed to get %s from store: %v", recordType, err) - } - return &record, nil -} - // SaveDNSSettings saves the DNS settings to the store. func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}). diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 2c1f5f8e6..fab9048e5 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -19,21 +19,17 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/netbirdio/netbird/management/server/util" - nbdns "github.com/netbirdio/netbird/dns" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" - "github.com/netbirdio/netbird/management/server/posture" - "github.com/netbirdio/netbird/management/server/types" - - route2 "github.com/netbirdio/netbird/route" - - "github.com/netbirdio/netbird/management/server/status" - nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/util" nbroute "github.com/netbirdio/netbird/route" + route2 "github.com/netbirdio/netbird/route" ) func runTestForAllEngines(t *testing.T, testDataFile string, f func(t *testing.T, store Store)) { @@ -3247,6 +3243,132 @@ func TestSqlStore_SaveGroups_LargeBatch(t *testing.T) { require.NoError(t, err) require.Equal(t, 8003, len(accountGroups)) } +func TestSqlStore_GetAccountRoutes(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + tests := []struct { + name string + accountID string + expectedCount int + }{ + { + name: "retrieve routes by existing account ID", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + expectedCount: 1, + }, + { + name: "non-existing account ID", + accountID: "nonexistent", + expectedCount: 0, + }, + { + name: "empty account ID", + accountID: "", + expectedCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + routes, err := store.GetAccountRoutes(context.Background(), LockingStrengthShare, tt.accountID) + require.NoError(t, err) + require.Len(t, routes, tt.expectedCount) + }) + } +} + +func TestSqlStore_GetRouteByID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + routeID string + expectError bool + }{ + { + name: "retrieve existing route", + routeID: "ct03t427qv97vmtmglog", + expectError: false, + }, + { + name: "retrieve non-existing route", + routeID: "non-existing", + expectError: true, + }, + { + name: "retrieve with empty route ID", + routeID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + route, err := store.GetRouteByID(context.Background(), LockingStrengthShare, accountID, tt.routeID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, route) + } else { + require.NoError(t, err) + require.NotNil(t, route) + require.Equal(t, tt.routeID, string(route.ID)) + } + }) + } +} + +func TestSqlStore_SaveRoute(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + route := &route2.Route{ + ID: "route-id", + AccountID: accountID, + Network: netip.MustParsePrefix("10.10.0.0/16"), + NetID: "netID", + PeerGroups: []string{"routeA"}, + NetworkType: route2.IPv4Network, + Masquerade: true, + Metric: 9999, + Enabled: true, + Groups: []string{"groupA"}, + AccessControlGroups: []string{}, + } + err = store.SaveRoute(context.Background(), LockingStrengthUpdate, route) + require.NoError(t, err) + + saveRoute, err := store.GetRouteByID(context.Background(), LockingStrengthShare, accountID, string(route.ID)) + require.NoError(t, err) + require.Equal(t, route, saveRoute) + +} + +func TestSqlStore_DeleteRoute(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + routeID := "ct03t427qv97vmtmglog" + + err = store.DeleteRoute(context.Background(), LockingStrengthUpdate, accountID, routeID) + require.NoError(t, err) + + route, err := store.GetRouteByID(context.Background(), LockingStrengthShare, accountID, routeID) + require.Error(t, err) + require.Nil(t, route) +} func TestSqlStore_GetAccountMeta(t *testing.T) { store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) diff --git a/management/server/store/store.go b/management/server/store/store.go index c7b103454..d41379b1c 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -145,7 +145,9 @@ type Store interface { DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) - GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) + GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) (*route.Route, error) + SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error + DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error) diff --git a/management/server/testdata/extended-store.sql b/management/server/testdata/extended-store.sql index 324bf6293..0393d1ade 100644 --- a/management/server/testdata/extended-store.sql +++ b/management/server/testdata/extended-store.sql @@ -38,4 +38,5 @@ INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-3465 INSERT INTO posture_checks VALUES('csplshq7qv948l48f7t0','NetBird Version > 0.32.0','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"NBVersionCheck":{"MinVersion":"0.31.0"}}'); INSERT INTO posture_checks VALUES('cspnllq7qv95uq1r4k90','Allow Berlin and Deny local network 172.16.1.0/24','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"GeoLocationCheck":{"Locations":[{"CountryCode":"DE","CityName":"Berlin"}],"Action":"allow"},"PeerNetworkRangeCheck":{"Action":"deny","Ranges":["172.16.1.0/24"]}}'); INSERT INTO name_server_groups VALUES('csqdelq7qv97ncu7d9t0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Google DNS','Google DNS Servers','[{"IP":"8.8.8.8","NSType":1,"Port":53},{"IP":"8.8.4.4","NSType":1,"Port":53}]','["cfefqs706sqkneg59g2g"]',1,'[]',1,0); +INSERT INTO routes VALUES('ct03t427qv97vmtmglog','bf1c8084-ba50-4ce7-9439-34653001fc3b','"10.10.0.0/16"',NULL,0,'aws-eu-central-1-vpc','Production VPC in Frankfurt','ct03r5q7qv97vmtmglng',NULL,1,1,9999,1,'["cfefqs706sqkneg59g2g"]',NULL); INSERT INTO installations VALUES(1,'');