Refactor route

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga
2024-10-25 13:18:24 +03:00
parent 0bdcb41e20
commit 313e158e20

View File

@ -52,16 +52,20 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string,
return nil, err return nil, err
} }
if !user.IsAdminOrServiceUser() || user.AccountID != accountID { if !user.IsAdminOrServiceUser() {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes") return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
} }
if user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg)
}
return am.Store.GetRouteByID(ctx, LockingStrengthShare, accountID, string(routeID)) return am.Store.GetRouteByID(ctx, LockingStrengthShare, accountID, string(routeID))
} }
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix // GetRoutesByPrefixOrDomains return list of routes by account and route prefix
func (am *DefaultAccountManager) GetRoutesByPrefixOrDomains(accountID string, prefix netip.Prefix, domains domain.List) ([]*route.Route, error) { func (am *DefaultAccountManager) GetRoutesByPrefixOrDomains(ctx context.Context, accountID string, prefix netip.Prefix, domains domain.List) ([]*route.Route, error) {
accountRoutes, err := am.Store.GetAccountRoutes(context.Background(), LockingStrengthShare, accountID) accountRoutes, err := am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -79,9 +83,9 @@ func (am *DefaultAccountManager) GetRoutesByPrefixOrDomains(accountID string, pr
} }
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups. // checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error { func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, accountID, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error {
// routes can have both peer and peer_groups // routes can have both peer and peer_groups
routesWithPrefix, err := am.GetRoutesByPrefixOrDomains(account.Id, prefix, domains) routesWithPrefix, err := am.GetRoutesByPrefixOrDomains(ctx, accountID, prefix, domains)
if err != nil { if err != nil {
return err return err
} }
@ -103,7 +107,7 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
for _, groupID := range prefixRoute.PeerGroups { for _, groupID := range prefixRoute.PeerGroups {
seenPeerGroups[groupID] = true seenPeerGroups[groupID] = true
group, err := am.Store.GetGroupByID(context.Background(), LockingStrengthShare, account.Id, groupID) group, err := am.Store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
if err != nil || group == nil { if err != nil || group == nil {
return status.Errorf( return status.Errorf(
status.InvalidArgument, "failed to add route with %s - peer group %s doesn't exist", status.InvalidArgument, "failed to add route with %s - peer group %s doesn't exist",
@ -119,7 +123,7 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
if peerID != "" { if peerID != "" {
// check that peerID exists and is not in any route as single peer or part of the group // check that peerID exists and is not in any route as single peer or part of the group
peer, err := am.Store.GetPeerByID(context.Background(), LockingStrengthShare, account.Id, peerID) peer, err := am.Store.GetPeerByID(context.Background(), LockingStrengthShare, accountID, peerID)
if err != nil || peer == nil { if err != nil || peer == nil {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
} }
@ -133,7 +137,7 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
// check that peerGroupIDs are not in any route peerGroups list // check that peerGroupIDs are not in any route peerGroups list
for _, groupID := range peerGroupIDs { for _, groupID := range peerGroupIDs {
// we validated the group existence before entering this function, no need to check again. // we validated the group existence before entering this function, no need to check again.
group, err := am.Store.GetGroupByID(context.Background(), LockingStrengthShare, groupID, account.Id) group, err := am.Store.GetGroupByID(context.Background(), LockingStrengthShare, groupID, accountID)
if err != nil || group == nil { if err != nil || group == nil {
return status.Errorf(status.InvalidArgument, "group with ID %s not found", peerID) return status.Errorf(status.InvalidArgument, "group with ID %s not found", peerID)
} }
@ -147,7 +151,7 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
// check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix // check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix
for _, id := range group.Peers { for _, id := range group.Peers {
if _, ok := seenPeers[id]; ok { if _, ok := seenPeers[id]; ok {
peer, err := am.Store.GetPeerByID(context.Background(), LockingStrengthShare, peerID, account.Id) peer, err := am.Store.GetPeerByID(context.Background(), LockingStrengthShare, peerID, accountID)
if err != nil { if err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
} }
@ -171,16 +175,13 @@ func getRouteDescriptor(prefix netip.Prefix, domains domain.List) string {
// CreateRoute creates and saves a new route // CreateRoute creates and saves a new route
func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) { func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if user.AccountID != accountID { if user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data") return nil, status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg)
} }
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
@ -189,7 +190,12 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
} }
// Do not allow non-Linux peers // Do not allow non-Linux peers
if peer := account.GetPeer(peerID); peer != nil { if peerID != "" {
peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
if err != nil {
return nil, err
}
if peer.Meta.GoOS != "linux" { if peer.Meta.GoOS != "linux" {
return nil, status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes") return nil, status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
} }
@ -217,21 +223,26 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
var newRoute route.Route var newRoute route.Route
newRoute.ID = route.ID(xid.New().String()) newRoute.ID = route.ID(xid.New().String())
accountGroups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
if len(peerGroupIDs) > 0 { if len(peerGroupIDs) > 0 {
//err = validateGroups(peerGroupIDs, account.Groups) err = validateGroups(peerGroupIDs, accountGroups)
//if err != nil { if err != nil {
// return nil, err return nil, err
//} }
} }
if len(accessControlGroupIDs) > 0 { if len(accessControlGroupIDs) > 0 {
//err = validateGroups(accessControlGroupIDs, account.Groups) err = validateGroups(accessControlGroupIDs, accountGroups)
//if err != nil { if err != nil {
// return nil, err return nil, err
//} }
} }
err = am.checkRoutePrefixOrDomainsExistForPeers(account, peerID, newRoute.ID, peerGroupIDs, prefix, domains) err = am.checkRoutePrefixOrDomainsExistForPeers(ctx, accountID, peerID, newRoute.ID, peerGroupIDs, prefix, domains)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -244,10 +255,10 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
} }
//err = validateGroups(groups, account.Groups) err = validateGroups(groups, accountGroups)
//if err != nil { if err != nil {
// return nil, err return nil, err
//} }
newRoute.Peer = peerID newRoute.Peer = peerID
newRoute.PeerGroups = peerGroupIDs newRoute.PeerGroups = peerGroupIDs
@ -263,28 +274,43 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
newRoute.KeepRoute = keepRoute newRoute.KeepRoute = keepRoute
newRoute.AccessControlGroups = accessControlGroupIDs newRoute.AccessControlGroups = accessControlGroupIDs
if account.Routes == nil { err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
account.Routes = make(map[route.ID]*route.Route) if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
} return fmt.Errorf(errNetworkSerialIncrementFmt, err)
}
account.Routes[newRoute.ID] = &newRoute err = transaction.SaveRoute(ctx, LockingStrengthUpdate, &newRoute)
if err != nil {
return fmt.Errorf("failed to create route: %w", err)
}
account.Network.IncSerial() return nil
if err = am.Store.SaveAccount(ctx, account); err != nil { })
if err != nil {
return nil, err return nil, err
} }
am.updateAccountPeers(ctx, account)
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
account, err = am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, fmt.Errorf(errGetAccountFmt, err)
}
am.updateAccountPeers(ctx, account)
return &newRoute, nil return &newRoute, nil
} }
// SaveRoute saves route // SaveRoute saves route
func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userID string, routeToSave *route.Route) error { func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userID string, routeToSave *route.Route) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
defer unlock() if err != nil {
return err
}
if user.AccountID != accountID {
return status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg)
}
if routeToSave == nil { if routeToSave == nil {
return status.Errorf(status.InvalidArgument, "route provided is nil") return status.Errorf(status.InvalidArgument, "route provided is nil")
@ -298,18 +324,11 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
} }
account, err := am.Store.GetAccount(ctx, accountID) _, err = am.Store.GetRouteByID(ctx, LockingStrengthShare, accountID, string(routeToSave.ID))
if err != nil { if err != nil {
return err return err
} }
// Do not allow non-Linux peers
if peer := account.GetPeer(routeToSave.Peer); peer != nil {
if peer.Meta.GoOS != "linux" {
return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
}
}
if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() { if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() {
return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time") return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
} }
@ -326,67 +345,107 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return status.Errorf(status.InvalidArgument, "peer with ID and peer groups should not be provided at the same time") return status.Errorf(status.InvalidArgument, "peer with ID and peer groups should not be provided at the same time")
} }
if len(routeToSave.PeerGroups) > 0 { // Do not allow non-Linux peers
//err = validateGroups(routeToSave.PeerGroups, account.Groups) if routeToSave.Peer != "" {
//if err != nil { peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, routeToSave.Peer)
// return err if err != nil {
//} return err
}
if peer.Meta.GoOS != "linux" {
return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
}
} }
if len(routeToSave.AccessControlGroups) > 0 { groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
//err = validateGroups(routeToSave.AccessControlGroups, account.Groups)
//if err != nil {
// return err
//}
}
err = am.checkRoutePrefixOrDomainsExistForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains)
if err != nil { if err != nil {
return err return err
} }
//err = validateGroups(routeToSave.Groups, account.Groups) if len(routeToSave.PeerGroups) > 0 {
//if err != nil { err = validateGroups(routeToSave.PeerGroups, groups)
// return err if err != nil {
//} return err
}
}
account.Routes[routeToSave.ID] = routeToSave if len(routeToSave.AccessControlGroups) > 0 {
err = validateGroups(routeToSave.AccessControlGroups, groups)
if err != nil {
return err
}
}
account.Network.IncSerial() err = am.checkRoutePrefixOrDomainsExistForPeers(ctx, accountID, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains)
if err = am.Store.SaveAccount(ctx, account); err != nil { if err != nil {
return err return err
} }
am.updateAccountPeers(ctx, account) err = validateGroups(routeToSave.Groups, groups)
if err != nil {
return err
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf(errNetworkSerialIncrementFmt, err)
}
err = transaction.SaveRoute(ctx, LockingStrengthUpdate, routeToSave)
if err != nil {
return fmt.Errorf("failed to save route: %w", err)
}
return nil
})
if err != nil {
return err
}
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf(errGetAccountFmt, err)
}
am.updateAccountPeers(ctx, account)
return nil return nil
} }
// DeleteRoute deletes route with routeID // DeleteRoute deletes route with routeID
func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error { func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return err return err
} }
routy := account.Routes[routeID] if user.AccountID != accountID {
if routy == nil { return status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg)
return status.Errorf(status.NotFound, "route with ID %s doesn't exist", routeID)
} }
delete(account.Routes, routeID)
account.Network.IncSerial() route, err := am.Store.GetRouteByID(ctx, LockingStrengthShare, accountID, string(routeID))
if err = am.Store.SaveAccount(ctx, account); err != nil { if err != nil {
return err return err
} }
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta()) err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf(errNetworkSerialIncrementFmt, err)
}
if err = transaction.DeleteRoute(ctx, LockingStrengthUpdate, accountID, string(routeID)); err != nil {
return fmt.Errorf("failed to delete route: %w", err)
}
return nil
})
am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta())
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf(errGetAccountFmt, err)
}
am.updateAccountPeers(ctx, account) am.updateAccountPeers(ctx, account)
return nil return nil
@ -399,10 +458,14 @@ func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, user
return nil, err return nil, err
} }
if !user.IsAdminOrServiceUser() || user.AccountID != accountID { if !user.IsAdminOrServiceUser() {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes") return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
} }
if user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg)
}
return am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID) return am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
} }