diff --git a/client/internal/routeselector/routeselector.go b/client/internal/routeselector/routeselector.go index 72c4758f4..8ebdc63e5 100644 --- a/client/internal/routeselector/routeselector.go +++ b/client/internal/routeselector/routeselector.go @@ -14,23 +14,15 @@ import ( ) type RouteSelector struct { - mu sync.RWMutex - selectedRoutes map[route.NetID]struct{} - selectAll bool - - // Indicates if new routes should be automatically selected - includeNewRoutes bool - - // All known routes at the time of deselection - knownRoutes []route.NetID + mu sync.RWMutex + deselectedRoutes map[route.NetID]struct{} + deselectAll bool } func NewRouteSelector() *RouteSelector { return &RouteSelector{ - selectedRoutes: map[route.NetID]struct{}{}, - selectAll: true, - includeNewRoutes: false, - knownRoutes: []route.NetID{}, + deselectedRoutes: map[route.NetID]struct{}{}, + deselectAll: false, } } @@ -39,8 +31,11 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al rs.mu.Lock() defer rs.mu.Unlock() - if !appendRoute { - rs.selectedRoutes = map[route.NetID]struct{}{} + if !appendRoute || rs.deselectAll { + maps.Clear(rs.deselectedRoutes) + for _, r := range allRoutes { + rs.deselectedRoutes[r] = struct{}{} + } } var err *multierror.Error @@ -49,11 +44,10 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al err = multierror.Append(err, fmt.Errorf("route '%s' is not available", route)) continue } - - rs.selectedRoutes[route] = struct{}{} + delete(rs.deselectedRoutes, route) } - rs.selectAll = false - rs.includeNewRoutes = false + + rs.deselectAll = false return errors.FormatErrorOrNil(err) } @@ -63,38 +57,26 @@ func (rs *RouteSelector) SelectAllRoutes() { rs.mu.Lock() defer rs.mu.Unlock() - rs.selectAll = true - rs.selectedRoutes = map[route.NetID]struct{}{} - rs.includeNewRoutes = false + rs.deselectAll = false + maps.Clear(rs.deselectedRoutes) } // DeselectRoutes removes specific routes from the selection. -// If the selector is in "select all" mode, it will transition to "select specific" mode -// but will keep new routes selected. func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.NetID) error { rs.mu.Lock() defer rs.mu.Unlock() - if rs.selectAll { - rs.selectAll = false - rs.includeNewRoutes = true - rs.knownRoutes = make([]route.NetID, len(allRoutes)) - copy(rs.knownRoutes, allRoutes) - - rs.selectedRoutes = map[route.NetID]struct{}{} - for _, route := range allRoutes { - rs.selectedRoutes[route] = struct{}{} - } + if rs.deselectAll { + return nil } var err *multierror.Error - for _, route := range routes { if !slices.Contains(allRoutes, route) { err = multierror.Append(err, fmt.Errorf("route '%s' is not available", route)) continue } - delete(rs.selectedRoutes, route) + rs.deselectedRoutes[route] = struct{}{} } return errors.FormatErrorOrNil(err) @@ -105,9 +87,8 @@ func (rs *RouteSelector) DeselectAllRoutes() { rs.mu.Lock() defer rs.mu.Unlock() - rs.selectAll = false - rs.includeNewRoutes = false - rs.selectedRoutes = map[route.NetID]struct{}{} + rs.deselectAll = true + maps.Clear(rs.deselectedRoutes) } // IsSelected checks if a specific route is selected. @@ -115,23 +96,12 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool { rs.mu.RLock() defer rs.mu.RUnlock() - if rs.selectAll { - return true + if rs.deselectAll { + return false } - // Check if the route exists in selectedRoutes - _, selected := rs.selectedRoutes[routeID] - if selected { - return true - } - - // If includeNewRoutes is true and this is a new route (not in knownRoutes), - // then it should be selected - if rs.includeNewRoutes && !slices.Contains(rs.knownRoutes, routeID) { - return true - } - - return false + _, deselected := rs.deselectedRoutes[routeID] + return !deselected } // FilterSelected removes unselected routes from the provided map. @@ -139,17 +109,15 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap { rs.mu.RLock() defer rs.mu.RUnlock() - if rs.selectAll { - return maps.Clone(routes) + if rs.deselectAll { + return route.HAMap{} } filtered := route.HAMap{} for id, rt := range routes { netID := id.NetID() - _, selected := rs.selectedRoutes[netID] - - // Include if directly selected or if it's a new route and includeNewRoutes is true - if selected || (rs.includeNewRoutes && !slices.Contains(rs.knownRoutes, netID)) { + _, deselected := rs.deselectedRoutes[netID] + if !deselected { filtered[id] = rt } } @@ -162,15 +130,11 @@ func (rs *RouteSelector) MarshalJSON() ([]byte, error) { defer rs.mu.RUnlock() return json.Marshal(struct { - SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"` - SelectAll bool `json:"select_all"` - IncludeNewRoutes bool `json:"include_new_routes"` - KnownRoutes []route.NetID `json:"known_routes"` + DeselectedRoutes map[route.NetID]struct{} `json:"deselected_routes"` + DeselectAll bool `json:"deselect_all"` }{ - SelectAll: rs.selectAll, - SelectedRoutes: rs.selectedRoutes, - IncludeNewRoutes: rs.includeNewRoutes, - KnownRoutes: rs.knownRoutes, + DeselectedRoutes: rs.deselectedRoutes, + DeselectAll: rs.deselectAll, }) } @@ -182,34 +146,25 @@ func (rs *RouteSelector) UnmarshalJSON(data []byte) error { // Check for null or empty JSON if len(data) == 0 || string(data) == "null" { - rs.selectedRoutes = map[route.NetID]struct{}{} - rs.selectAll = true - rs.includeNewRoutes = false - rs.knownRoutes = []route.NetID{} + rs.deselectedRoutes = map[route.NetID]struct{}{} + rs.deselectAll = false return nil } var temp struct { - SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"` - SelectAll bool `json:"select_all"` - IncludeNewRoutes bool `json:"include_new_routes"` - KnownRoutes []route.NetID `json:"known_routes"` + DeselectedRoutes map[route.NetID]struct{} `json:"deselected_routes"` + DeselectAll bool `json:"deselect_all"` } if err := json.Unmarshal(data, &temp); err != nil { return err } - rs.selectedRoutes = temp.SelectedRoutes - rs.selectAll = temp.SelectAll - rs.includeNewRoutes = temp.IncludeNewRoutes - rs.knownRoutes = temp.KnownRoutes + rs.deselectedRoutes = temp.DeselectedRoutes + rs.deselectAll = temp.DeselectAll - if rs.selectedRoutes == nil { - rs.selectedRoutes = map[route.NetID]struct{}{} - } - if rs.knownRoutes == nil { - rs.knownRoutes = []route.NetID{} + if rs.deselectedRoutes == nil { + rs.deselectedRoutes = map[route.NetID]struct{}{} } return nil diff --git a/client/internal/routeselector/routeselector_test.go b/client/internal/routeselector/routeselector_test.go index a1461dff6..cfa723246 100644 --- a/client/internal/routeselector/routeselector_test.go +++ b/client/internal/routeselector/routeselector_test.go @@ -66,12 +66,10 @@ func TestRouteSelector_SelectRoutes(t *testing.T) { t.Run(tt.name, func(t *testing.T) { rs := routeselector.NewRouteSelector() - if tt.initialSelected != nil { - err := rs.SelectRoutes(tt.initialSelected, false, allRoutes) - require.NoError(t, err) - } + err := rs.SelectRoutes(tt.initialSelected, false, allRoutes) + require.NoError(t, err) - err := rs.SelectRoutes(tt.selectRoutes, tt.append, allRoutes) + err = rs.SelectRoutes(tt.selectRoutes, tt.append, allRoutes) if tt.wantError { assert.Error(t, err) } else { @@ -251,7 +249,8 @@ func TestRouteSelector_IsSelected(t *testing.T) { assert.True(t, rs.IsSelected("route1")) assert.True(t, rs.IsSelected("route2")) assert.False(t, rs.IsSelected("route3")) - assert.False(t, rs.IsSelected("route4")) + // Unknown route is selected by default + assert.True(t, rs.IsSelected("route4")) } func TestRouteSelector_FilterSelected(t *testing.T) { @@ -297,8 +296,8 @@ func TestRouteSelector_NewRoutesBehavior(t *testing.T) { initialState: func(rs *routeselector.RouteSelector) error { return rs.SelectRoutes([]route.NetID{"route1", "route2"}, false, initialRoutes) }, - // When specific routes were selected, new routes should remain unselected - wantNewSelected: []route.NetID{"route1", "route2"}, + // When specific routes were selected, new routes should be selected + wantNewSelected: []route.NetID{"route1", "route2", "route4", "route5"}, }, { name: "New routes after deselect all", @@ -315,7 +314,7 @@ func TestRouteSelector_NewRoutesBehavior(t *testing.T) { rs.SelectAllRoutes() return rs.DeselectRoutes([]route.NetID{"route1"}, initialRoutes) }, - // After deselecting specific routes, new routes should remain unselected + // After deselecting specific routes, new routes should be selected wantNewSelected: []route.NetID{"route2", "route3", "route4", "route5"}, }, { @@ -323,8 +322,8 @@ func TestRouteSelector_NewRoutesBehavior(t *testing.T) { initialState: func(rs *routeselector.RouteSelector) error { return rs.SelectRoutes([]route.NetID{"route1"}, true, initialRoutes) }, - // When routes were appended, new routes should remain unselected - wantNewSelected: []route.NetID{"route1"}, + // When routes were appended, new routes should be selected + wantNewSelected: []route.NetID{"route1", "route2", "route3", "route4", "route5"}, }, } @@ -428,3 +427,213 @@ func TestRouteSelector_MixedSelectionDeselection(t *testing.T) { }) } } + +func TestRouteSelector_AfterDeselectAll(t *testing.T) { + allRoutes := []route.NetID{"route1", "route2", "route3"} + + tests := []struct { + name string + initialAction func(rs *routeselector.RouteSelector) error + secondAction func(rs *routeselector.RouteSelector) error + wantSelected []route.NetID + wantError bool + }{ + { + name: "Deselect all -> select specific routes", + initialAction: func(rs *routeselector.RouteSelector) error { + rs.DeselectAllRoutes() + return nil + }, + secondAction: func(rs *routeselector.RouteSelector) error { + return rs.SelectRoutes([]route.NetID{"route1", "route2"}, false, allRoutes) + }, + wantSelected: []route.NetID{"route1", "route2"}, + }, + { + name: "Deselect all -> select with append", + initialAction: func(rs *routeselector.RouteSelector) error { + rs.DeselectAllRoutes() + return nil + }, + secondAction: func(rs *routeselector.RouteSelector) error { + return rs.SelectRoutes([]route.NetID{"route1"}, true, allRoutes) + }, + wantSelected: []route.NetID{"route1"}, + }, + { + name: "Deselect all -> deselect specific", + initialAction: func(rs *routeselector.RouteSelector) error { + rs.DeselectAllRoutes() + return nil + }, + secondAction: func(rs *routeselector.RouteSelector) error { + return rs.DeselectRoutes([]route.NetID{"route1"}, allRoutes) + }, + wantSelected: []route.NetID{}, + }, + { + name: "Deselect all -> select all", + initialAction: func(rs *routeselector.RouteSelector) error { + rs.DeselectAllRoutes() + return nil + }, + secondAction: func(rs *routeselector.RouteSelector) error { + rs.SelectAllRoutes() + return nil + }, + wantSelected: []route.NetID{"route1", "route2", "route3"}, + }, + { + name: "Deselect all -> deselect non-existent route", + initialAction: func(rs *routeselector.RouteSelector) error { + rs.DeselectAllRoutes() + return nil + }, + secondAction: func(rs *routeselector.RouteSelector) error { + return rs.DeselectRoutes([]route.NetID{"route4"}, allRoutes) + }, + wantSelected: []route.NetID{}, + wantError: false, + }, + { + name: "Select specific -> deselect all -> select different", + initialAction: func(rs *routeselector.RouteSelector) error { + err := rs.SelectRoutes([]route.NetID{"route1"}, false, allRoutes) + if err != nil { + return err + } + rs.DeselectAllRoutes() + return nil + }, + secondAction: func(rs *routeselector.RouteSelector) error { + return rs.SelectRoutes([]route.NetID{"route2", "route3"}, false, allRoutes) + }, + wantSelected: []route.NetID{"route2", "route3"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rs := routeselector.NewRouteSelector() + + err := tt.initialAction(rs) + require.NoError(t, err) + + err = tt.secondAction(rs) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + + for _, id := range allRoutes { + expected := slices.Contains(tt.wantSelected, id) + assert.Equal(t, expected, rs.IsSelected(id), + "Route %s selection state incorrect, expected %v", id, expected) + } + + routes := route.HAMap{ + "route1|10.0.0.0/8": {}, + "route2|192.168.0.0/16": {}, + "route3|172.16.0.0/12": {}, + } + + filtered := rs.FilterSelected(routes) + assert.Equal(t, len(tt.wantSelected), len(filtered), + "FilterSelected returned wrong number of routes") + }) + } +} + +func TestRouteSelector_ComplexScenarios(t *testing.T) { + allRoutes := []route.NetID{"route1", "route2", "route3", "route4"} + + tests := []struct { + name string + actions []func(rs *routeselector.RouteSelector) error + wantSelected []route.NetID + }{ + { + name: "Select all -> deselect specific -> select different with append", + actions: []func(rs *routeselector.RouteSelector) error{ + func(rs *routeselector.RouteSelector) error { + rs.SelectAllRoutes() + return nil + }, + func(rs *routeselector.RouteSelector) error { + return rs.DeselectRoutes([]route.NetID{"route1", "route2"}, allRoutes) + }, + func(rs *routeselector.RouteSelector) error { + return rs.SelectRoutes([]route.NetID{"route1"}, true, allRoutes) + }, + }, + wantSelected: []route.NetID{"route1", "route3", "route4"}, + }, + { + name: "Deselect all -> select specific -> deselect one -> select different with append", + actions: []func(rs *routeselector.RouteSelector) error{ + func(rs *routeselector.RouteSelector) error { + rs.DeselectAllRoutes() + return nil + }, + func(rs *routeselector.RouteSelector) error { + return rs.SelectRoutes([]route.NetID{"route1", "route2"}, false, allRoutes) + }, + func(rs *routeselector.RouteSelector) error { + return rs.DeselectRoutes([]route.NetID{"route2"}, allRoutes) + }, + func(rs *routeselector.RouteSelector) error { + return rs.SelectRoutes([]route.NetID{"route3"}, true, allRoutes) + }, + }, + wantSelected: []route.NetID{"route1", "route3"}, + }, + { + name: "Select specific -> deselect specific -> select all -> deselect different", + actions: []func(rs *routeselector.RouteSelector) error{ + func(rs *routeselector.RouteSelector) error { + return rs.SelectRoutes([]route.NetID{"route1", "route2"}, false, allRoutes) + }, + func(rs *routeselector.RouteSelector) error { + return rs.DeselectRoutes([]route.NetID{"route2"}, allRoutes) + }, + func(rs *routeselector.RouteSelector) error { + rs.SelectAllRoutes() + return nil + }, + func(rs *routeselector.RouteSelector) error { + return rs.DeselectRoutes([]route.NetID{"route3", "route4"}, allRoutes) + }, + }, + wantSelected: []route.NetID{"route1", "route2"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rs := routeselector.NewRouteSelector() + + for i, action := range tt.actions { + err := action(rs) + require.NoError(t, err, "Action %d failed", i) + } + + for _, id := range allRoutes { + expected := slices.Contains(tt.wantSelected, id) + assert.Equal(t, expected, rs.IsSelected(id), + "Route %s selection state incorrect", id) + } + + routes := route.HAMap{ + "route1|10.0.0.0/8": {}, + "route2|192.168.0.0/16": {}, + "route3|172.16.0.0/12": {}, + "route4|10.10.0.0/16": {}, + } + + filtered := rs.FilterSelected(routes) + assert.Equal(t, len(tt.wantSelected), len(filtered), + "FilterSelected returned wrong number of routes") + }) + } +}