mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-28 13:42:31 +02:00
[client] Keep new routes selected unless all are deselected (#3692)
This commit is contained in:
parent
1a6d6b3109
commit
3b7b9d25bc
@ -14,23 +14,15 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type RouteSelector struct {
|
type RouteSelector struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
selectedRoutes map[route.NetID]struct{}
|
deselectedRoutes map[route.NetID]struct{}
|
||||||
selectAll bool
|
deselectAll bool
|
||||||
|
|
||||||
// Indicates if new routes should be automatically selected
|
|
||||||
includeNewRoutes bool
|
|
||||||
|
|
||||||
// All known routes at the time of deselection
|
|
||||||
knownRoutes []route.NetID
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRouteSelector() *RouteSelector {
|
func NewRouteSelector() *RouteSelector {
|
||||||
return &RouteSelector{
|
return &RouteSelector{
|
||||||
selectedRoutes: map[route.NetID]struct{}{},
|
deselectedRoutes: map[route.NetID]struct{}{},
|
||||||
selectAll: true,
|
deselectAll: false,
|
||||||
includeNewRoutes: false,
|
|
||||||
knownRoutes: []route.NetID{},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -39,8 +31,11 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al
|
|||||||
rs.mu.Lock()
|
rs.mu.Lock()
|
||||||
defer rs.mu.Unlock()
|
defer rs.mu.Unlock()
|
||||||
|
|
||||||
if !appendRoute {
|
if !appendRoute || rs.deselectAll {
|
||||||
rs.selectedRoutes = map[route.NetID]struct{}{}
|
maps.Clear(rs.deselectedRoutes)
|
||||||
|
for _, r := range allRoutes {
|
||||||
|
rs.deselectedRoutes[r] = struct{}{}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var err *multierror.Error
|
var err *multierror.Error
|
||||||
@ -49,11 +44,10 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al
|
|||||||
err = multierror.Append(err, fmt.Errorf("route '%s' is not available", route))
|
err = multierror.Append(err, fmt.Errorf("route '%s' is not available", route))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
delete(rs.deselectedRoutes, route)
|
||||||
rs.selectedRoutes[route] = struct{}{}
|
|
||||||
}
|
}
|
||||||
rs.selectAll = false
|
|
||||||
rs.includeNewRoutes = false
|
rs.deselectAll = false
|
||||||
|
|
||||||
return errors.FormatErrorOrNil(err)
|
return errors.FormatErrorOrNil(err)
|
||||||
}
|
}
|
||||||
@ -63,38 +57,26 @@ func (rs *RouteSelector) SelectAllRoutes() {
|
|||||||
rs.mu.Lock()
|
rs.mu.Lock()
|
||||||
defer rs.mu.Unlock()
|
defer rs.mu.Unlock()
|
||||||
|
|
||||||
rs.selectAll = true
|
rs.deselectAll = false
|
||||||
rs.selectedRoutes = map[route.NetID]struct{}{}
|
maps.Clear(rs.deselectedRoutes)
|
||||||
rs.includeNewRoutes = false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeselectRoutes removes specific routes from the selection.
|
// DeselectRoutes removes specific routes from the selection.
|
||||||
// If the selector is in "select all" mode, it will transition to "select specific" mode
|
|
||||||
// but will keep new routes selected.
|
|
||||||
func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.NetID) error {
|
func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.NetID) error {
|
||||||
rs.mu.Lock()
|
rs.mu.Lock()
|
||||||
defer rs.mu.Unlock()
|
defer rs.mu.Unlock()
|
||||||
|
|
||||||
if rs.selectAll {
|
if rs.deselectAll {
|
||||||
rs.selectAll = false
|
return nil
|
||||||
rs.includeNewRoutes = true
|
|
||||||
rs.knownRoutes = make([]route.NetID, len(allRoutes))
|
|
||||||
copy(rs.knownRoutes, allRoutes)
|
|
||||||
|
|
||||||
rs.selectedRoutes = map[route.NetID]struct{}{}
|
|
||||||
for _, route := range allRoutes {
|
|
||||||
rs.selectedRoutes[route] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var err *multierror.Error
|
var err *multierror.Error
|
||||||
|
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
if !slices.Contains(allRoutes, route) {
|
if !slices.Contains(allRoutes, route) {
|
||||||
err = multierror.Append(err, fmt.Errorf("route '%s' is not available", route))
|
err = multierror.Append(err, fmt.Errorf("route '%s' is not available", route))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
delete(rs.selectedRoutes, route)
|
rs.deselectedRoutes[route] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
return errors.FormatErrorOrNil(err)
|
return errors.FormatErrorOrNil(err)
|
||||||
@ -105,9 +87,8 @@ func (rs *RouteSelector) DeselectAllRoutes() {
|
|||||||
rs.mu.Lock()
|
rs.mu.Lock()
|
||||||
defer rs.mu.Unlock()
|
defer rs.mu.Unlock()
|
||||||
|
|
||||||
rs.selectAll = false
|
rs.deselectAll = true
|
||||||
rs.includeNewRoutes = false
|
maps.Clear(rs.deselectedRoutes)
|
||||||
rs.selectedRoutes = map[route.NetID]struct{}{}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsSelected checks if a specific route is selected.
|
// IsSelected checks if a specific route is selected.
|
||||||
@ -115,23 +96,12 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
|
|||||||
rs.mu.RLock()
|
rs.mu.RLock()
|
||||||
defer rs.mu.RUnlock()
|
defer rs.mu.RUnlock()
|
||||||
|
|
||||||
if rs.selectAll {
|
if rs.deselectAll {
|
||||||
return true
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if the route exists in selectedRoutes
|
_, deselected := rs.deselectedRoutes[routeID]
|
||||||
_, selected := rs.selectedRoutes[routeID]
|
return !deselected
|
||||||
if selected {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// If includeNewRoutes is true and this is a new route (not in knownRoutes),
|
|
||||||
// then it should be selected
|
|
||||||
if rs.includeNewRoutes && !slices.Contains(rs.knownRoutes, routeID) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// FilterSelected removes unselected routes from the provided map.
|
// FilterSelected removes unselected routes from the provided map.
|
||||||
@ -139,17 +109,15 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
|
|||||||
rs.mu.RLock()
|
rs.mu.RLock()
|
||||||
defer rs.mu.RUnlock()
|
defer rs.mu.RUnlock()
|
||||||
|
|
||||||
if rs.selectAll {
|
if rs.deselectAll {
|
||||||
return maps.Clone(routes)
|
return route.HAMap{}
|
||||||
}
|
}
|
||||||
|
|
||||||
filtered := route.HAMap{}
|
filtered := route.HAMap{}
|
||||||
for id, rt := range routes {
|
for id, rt := range routes {
|
||||||
netID := id.NetID()
|
netID := id.NetID()
|
||||||
_, selected := rs.selectedRoutes[netID]
|
_, deselected := rs.deselectedRoutes[netID]
|
||||||
|
if !deselected {
|
||||||
// Include if directly selected or if it's a new route and includeNewRoutes is true
|
|
||||||
if selected || (rs.includeNewRoutes && !slices.Contains(rs.knownRoutes, netID)) {
|
|
||||||
filtered[id] = rt
|
filtered[id] = rt
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -162,15 +130,11 @@ func (rs *RouteSelector) MarshalJSON() ([]byte, error) {
|
|||||||
defer rs.mu.RUnlock()
|
defer rs.mu.RUnlock()
|
||||||
|
|
||||||
return json.Marshal(struct {
|
return json.Marshal(struct {
|
||||||
SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"`
|
DeselectedRoutes map[route.NetID]struct{} `json:"deselected_routes"`
|
||||||
SelectAll bool `json:"select_all"`
|
DeselectAll bool `json:"deselect_all"`
|
||||||
IncludeNewRoutes bool `json:"include_new_routes"`
|
|
||||||
KnownRoutes []route.NetID `json:"known_routes"`
|
|
||||||
}{
|
}{
|
||||||
SelectAll: rs.selectAll,
|
DeselectedRoutes: rs.deselectedRoutes,
|
||||||
SelectedRoutes: rs.selectedRoutes,
|
DeselectAll: rs.deselectAll,
|
||||||
IncludeNewRoutes: rs.includeNewRoutes,
|
|
||||||
KnownRoutes: rs.knownRoutes,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -182,34 +146,25 @@ func (rs *RouteSelector) UnmarshalJSON(data []byte) error {
|
|||||||
|
|
||||||
// Check for null or empty JSON
|
// Check for null or empty JSON
|
||||||
if len(data) == 0 || string(data) == "null" {
|
if len(data) == 0 || string(data) == "null" {
|
||||||
rs.selectedRoutes = map[route.NetID]struct{}{}
|
rs.deselectedRoutes = map[route.NetID]struct{}{}
|
||||||
rs.selectAll = true
|
rs.deselectAll = false
|
||||||
rs.includeNewRoutes = false
|
|
||||||
rs.knownRoutes = []route.NetID{}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var temp struct {
|
var temp struct {
|
||||||
SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"`
|
DeselectedRoutes map[route.NetID]struct{} `json:"deselected_routes"`
|
||||||
SelectAll bool `json:"select_all"`
|
DeselectAll bool `json:"deselect_all"`
|
||||||
IncludeNewRoutes bool `json:"include_new_routes"`
|
|
||||||
KnownRoutes []route.NetID `json:"known_routes"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := json.Unmarshal(data, &temp); err != nil {
|
if err := json.Unmarshal(data, &temp); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
rs.selectedRoutes = temp.SelectedRoutes
|
rs.deselectedRoutes = temp.DeselectedRoutes
|
||||||
rs.selectAll = temp.SelectAll
|
rs.deselectAll = temp.DeselectAll
|
||||||
rs.includeNewRoutes = temp.IncludeNewRoutes
|
|
||||||
rs.knownRoutes = temp.KnownRoutes
|
|
||||||
|
|
||||||
if rs.selectedRoutes == nil {
|
if rs.deselectedRoutes == nil {
|
||||||
rs.selectedRoutes = map[route.NetID]struct{}{}
|
rs.deselectedRoutes = map[route.NetID]struct{}{}
|
||||||
}
|
|
||||||
if rs.knownRoutes == nil {
|
|
||||||
rs.knownRoutes = []route.NetID{}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -66,12 +66,10 @@ func TestRouteSelector_SelectRoutes(t *testing.T) {
|
|||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
rs := routeselector.NewRouteSelector()
|
rs := routeselector.NewRouteSelector()
|
||||||
|
|
||||||
if tt.initialSelected != nil {
|
err := rs.SelectRoutes(tt.initialSelected, false, allRoutes)
|
||||||
err := rs.SelectRoutes(tt.initialSelected, false, allRoutes)
|
require.NoError(t, err)
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err := rs.SelectRoutes(tt.selectRoutes, tt.append, allRoutes)
|
err = rs.SelectRoutes(tt.selectRoutes, tt.append, allRoutes)
|
||||||
if tt.wantError {
|
if tt.wantError {
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
} else {
|
} else {
|
||||||
@ -251,7 +249,8 @@ func TestRouteSelector_IsSelected(t *testing.T) {
|
|||||||
assert.True(t, rs.IsSelected("route1"))
|
assert.True(t, rs.IsSelected("route1"))
|
||||||
assert.True(t, rs.IsSelected("route2"))
|
assert.True(t, rs.IsSelected("route2"))
|
||||||
assert.False(t, rs.IsSelected("route3"))
|
assert.False(t, rs.IsSelected("route3"))
|
||||||
assert.False(t, rs.IsSelected("route4"))
|
// Unknown route is selected by default
|
||||||
|
assert.True(t, rs.IsSelected("route4"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRouteSelector_FilterSelected(t *testing.T) {
|
func TestRouteSelector_FilterSelected(t *testing.T) {
|
||||||
@ -297,8 +296,8 @@ func TestRouteSelector_NewRoutesBehavior(t *testing.T) {
|
|||||||
initialState: func(rs *routeselector.RouteSelector) error {
|
initialState: func(rs *routeselector.RouteSelector) error {
|
||||||
return rs.SelectRoutes([]route.NetID{"route1", "route2"}, false, initialRoutes)
|
return rs.SelectRoutes([]route.NetID{"route1", "route2"}, false, initialRoutes)
|
||||||
},
|
},
|
||||||
// When specific routes were selected, new routes should remain unselected
|
// When specific routes were selected, new routes should be selected
|
||||||
wantNewSelected: []route.NetID{"route1", "route2"},
|
wantNewSelected: []route.NetID{"route1", "route2", "route4", "route5"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "New routes after deselect all",
|
name: "New routes after deselect all",
|
||||||
@ -315,7 +314,7 @@ func TestRouteSelector_NewRoutesBehavior(t *testing.T) {
|
|||||||
rs.SelectAllRoutes()
|
rs.SelectAllRoutes()
|
||||||
return rs.DeselectRoutes([]route.NetID{"route1"}, initialRoutes)
|
return rs.DeselectRoutes([]route.NetID{"route1"}, initialRoutes)
|
||||||
},
|
},
|
||||||
// After deselecting specific routes, new routes should remain unselected
|
// After deselecting specific routes, new routes should be selected
|
||||||
wantNewSelected: []route.NetID{"route2", "route3", "route4", "route5"},
|
wantNewSelected: []route.NetID{"route2", "route3", "route4", "route5"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -323,8 +322,8 @@ func TestRouteSelector_NewRoutesBehavior(t *testing.T) {
|
|||||||
initialState: func(rs *routeselector.RouteSelector) error {
|
initialState: func(rs *routeselector.RouteSelector) error {
|
||||||
return rs.SelectRoutes([]route.NetID{"route1"}, true, initialRoutes)
|
return rs.SelectRoutes([]route.NetID{"route1"}, true, initialRoutes)
|
||||||
},
|
},
|
||||||
// When routes were appended, new routes should remain unselected
|
// When routes were appended, new routes should be selected
|
||||||
wantNewSelected: []route.NetID{"route1"},
|
wantNewSelected: []route.NetID{"route1", "route2", "route3", "route4", "route5"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -428,3 +427,213 @@ func TestRouteSelector_MixedSelectionDeselection(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRouteSelector_AfterDeselectAll(t *testing.T) {
|
||||||
|
allRoutes := []route.NetID{"route1", "route2", "route3"}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
initialAction func(rs *routeselector.RouteSelector) error
|
||||||
|
secondAction func(rs *routeselector.RouteSelector) error
|
||||||
|
wantSelected []route.NetID
|
||||||
|
wantError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Deselect all -> select specific routes",
|
||||||
|
initialAction: func(rs *routeselector.RouteSelector) error {
|
||||||
|
rs.DeselectAllRoutes()
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
secondAction: func(rs *routeselector.RouteSelector) error {
|
||||||
|
return rs.SelectRoutes([]route.NetID{"route1", "route2"}, false, allRoutes)
|
||||||
|
},
|
||||||
|
wantSelected: []route.NetID{"route1", "route2"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Deselect all -> select with append",
|
||||||
|
initialAction: func(rs *routeselector.RouteSelector) error {
|
||||||
|
rs.DeselectAllRoutes()
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
secondAction: func(rs *routeselector.RouteSelector) error {
|
||||||
|
return rs.SelectRoutes([]route.NetID{"route1"}, true, allRoutes)
|
||||||
|
},
|
||||||
|
wantSelected: []route.NetID{"route1"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Deselect all -> deselect specific",
|
||||||
|
initialAction: func(rs *routeselector.RouteSelector) error {
|
||||||
|
rs.DeselectAllRoutes()
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
secondAction: func(rs *routeselector.RouteSelector) error {
|
||||||
|
return rs.DeselectRoutes([]route.NetID{"route1"}, allRoutes)
|
||||||
|
},
|
||||||
|
wantSelected: []route.NetID{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Deselect all -> select all",
|
||||||
|
initialAction: func(rs *routeselector.RouteSelector) error {
|
||||||
|
rs.DeselectAllRoutes()
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
secondAction: func(rs *routeselector.RouteSelector) error {
|
||||||
|
rs.SelectAllRoutes()
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
wantSelected: []route.NetID{"route1", "route2", "route3"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Deselect all -> deselect non-existent route",
|
||||||
|
initialAction: func(rs *routeselector.RouteSelector) error {
|
||||||
|
rs.DeselectAllRoutes()
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
secondAction: func(rs *routeselector.RouteSelector) error {
|
||||||
|
return rs.DeselectRoutes([]route.NetID{"route4"}, allRoutes)
|
||||||
|
},
|
||||||
|
wantSelected: []route.NetID{},
|
||||||
|
wantError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Select specific -> deselect all -> select different",
|
||||||
|
initialAction: func(rs *routeselector.RouteSelector) error {
|
||||||
|
err := rs.SelectRoutes([]route.NetID{"route1"}, false, allRoutes)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
rs.DeselectAllRoutes()
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
secondAction: func(rs *routeselector.RouteSelector) error {
|
||||||
|
return rs.SelectRoutes([]route.NetID{"route2", "route3"}, false, allRoutes)
|
||||||
|
},
|
||||||
|
wantSelected: []route.NetID{"route2", "route3"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
rs := routeselector.NewRouteSelector()
|
||||||
|
|
||||||
|
err := tt.initialAction(rs)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = tt.secondAction(rs)
|
||||||
|
if tt.wantError {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, id := range allRoutes {
|
||||||
|
expected := slices.Contains(tt.wantSelected, id)
|
||||||
|
assert.Equal(t, expected, rs.IsSelected(id),
|
||||||
|
"Route %s selection state incorrect, expected %v", id, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
routes := route.HAMap{
|
||||||
|
"route1|10.0.0.0/8": {},
|
||||||
|
"route2|192.168.0.0/16": {},
|
||||||
|
"route3|172.16.0.0/12": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
filtered := rs.FilterSelected(routes)
|
||||||
|
assert.Equal(t, len(tt.wantSelected), len(filtered),
|
||||||
|
"FilterSelected returned wrong number of routes")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouteSelector_ComplexScenarios(t *testing.T) {
|
||||||
|
allRoutes := []route.NetID{"route1", "route2", "route3", "route4"}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
actions []func(rs *routeselector.RouteSelector) error
|
||||||
|
wantSelected []route.NetID
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Select all -> deselect specific -> select different with append",
|
||||||
|
actions: []func(rs *routeselector.RouteSelector) error{
|
||||||
|
func(rs *routeselector.RouteSelector) error {
|
||||||
|
rs.SelectAllRoutes()
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
func(rs *routeselector.RouteSelector) error {
|
||||||
|
return rs.DeselectRoutes([]route.NetID{"route1", "route2"}, allRoutes)
|
||||||
|
},
|
||||||
|
func(rs *routeselector.RouteSelector) error {
|
||||||
|
return rs.SelectRoutes([]route.NetID{"route1"}, true, allRoutes)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantSelected: []route.NetID{"route1", "route3", "route4"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Deselect all -> select specific -> deselect one -> select different with append",
|
||||||
|
actions: []func(rs *routeselector.RouteSelector) error{
|
||||||
|
func(rs *routeselector.RouteSelector) error {
|
||||||
|
rs.DeselectAllRoutes()
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
func(rs *routeselector.RouteSelector) error {
|
||||||
|
return rs.SelectRoutes([]route.NetID{"route1", "route2"}, false, allRoutes)
|
||||||
|
},
|
||||||
|
func(rs *routeselector.RouteSelector) error {
|
||||||
|
return rs.DeselectRoutes([]route.NetID{"route2"}, allRoutes)
|
||||||
|
},
|
||||||
|
func(rs *routeselector.RouteSelector) error {
|
||||||
|
return rs.SelectRoutes([]route.NetID{"route3"}, true, allRoutes)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantSelected: []route.NetID{"route1", "route3"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Select specific -> deselect specific -> select all -> deselect different",
|
||||||
|
actions: []func(rs *routeselector.RouteSelector) error{
|
||||||
|
func(rs *routeselector.RouteSelector) error {
|
||||||
|
return rs.SelectRoutes([]route.NetID{"route1", "route2"}, false, allRoutes)
|
||||||
|
},
|
||||||
|
func(rs *routeselector.RouteSelector) error {
|
||||||
|
return rs.DeselectRoutes([]route.NetID{"route2"}, allRoutes)
|
||||||
|
},
|
||||||
|
func(rs *routeselector.RouteSelector) error {
|
||||||
|
rs.SelectAllRoutes()
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
func(rs *routeselector.RouteSelector) error {
|
||||||
|
return rs.DeselectRoutes([]route.NetID{"route3", "route4"}, allRoutes)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantSelected: []route.NetID{"route1", "route2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
rs := routeselector.NewRouteSelector()
|
||||||
|
|
||||||
|
for i, action := range tt.actions {
|
||||||
|
err := action(rs)
|
||||||
|
require.NoError(t, err, "Action %d failed", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, id := range allRoutes {
|
||||||
|
expected := slices.Contains(tt.wantSelected, id)
|
||||||
|
assert.Equal(t, expected, rs.IsSelected(id),
|
||||||
|
"Route %s selection state incorrect", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
routes := route.HAMap{
|
||||||
|
"route1|10.0.0.0/8": {},
|
||||||
|
"route2|192.168.0.0/16": {},
|
||||||
|
"route3|172.16.0.0/12": {},
|
||||||
|
"route4|10.10.0.0/16": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
filtered := rs.FilterSelected(routes)
|
||||||
|
assert.Equal(t, len(tt.wantSelected), len(filtered),
|
||||||
|
"FilterSelected returned wrong number of routes")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user