diff --git a/client/internal/routeselector/routeselector.go b/client/internal/routeselector/routeselector.go index 2874604fd..72c4758f4 100644 --- a/client/internal/routeselector/routeselector.go +++ b/client/internal/routeselector/routeselector.go @@ -10,20 +10,27 @@ import ( "golang.org/x/exp/maps" "github.com/netbirdio/netbird/client/errors" - route "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/route" ) type RouteSelector struct { mu sync.RWMutex selectedRoutes map[route.NetID]struct{} selectAll bool + + // Indicates if new routes should be automatically selected + includeNewRoutes bool + + // All known routes at the time of deselection + knownRoutes []route.NetID } func NewRouteSelector() *RouteSelector { return &RouteSelector{ - selectedRoutes: map[route.NetID]struct{}{}, - // default selects all routes - selectAll: true, + selectedRoutes: map[route.NetID]struct{}{}, + selectAll: true, + includeNewRoutes: false, + knownRoutes: []route.NetID{}, } } @@ -46,6 +53,7 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al rs.selectedRoutes[route] = struct{}{} } rs.selectAll = false + rs.includeNewRoutes = false return errors.FormatErrorOrNil(err) } @@ -57,16 +65,22 @@ func (rs *RouteSelector) SelectAllRoutes() { rs.selectAll = true rs.selectedRoutes = map[route.NetID]struct{}{} + rs.includeNewRoutes = false } // DeselectRoutes removes specific routes from the selection. -// If the selector is in "select all" mode, it will transition to "select specific" mode. +// If the selector is in "select all" mode, it will transition to "select specific" mode +// but will keep new routes selected. func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.NetID) error { rs.mu.Lock() defer rs.mu.Unlock() if rs.selectAll { rs.selectAll = false + rs.includeNewRoutes = true + rs.knownRoutes = make([]route.NetID, len(allRoutes)) + copy(rs.knownRoutes, allRoutes) + rs.selectedRoutes = map[route.NetID]struct{}{} for _, route := range allRoutes { rs.selectedRoutes[route] = struct{}{} @@ -92,6 +106,7 @@ func (rs *RouteSelector) DeselectAllRoutes() { defer rs.mu.Unlock() rs.selectAll = false + rs.includeNewRoutes = false rs.selectedRoutes = map[route.NetID]struct{}{} } @@ -103,8 +118,20 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool { if rs.selectAll { return true } + + // Check if the route exists in selectedRoutes _, selected := rs.selectedRoutes[routeID] - return selected + if selected { + return true + } + + // If includeNewRoutes is true and this is a new route (not in knownRoutes), + // then it should be selected + if rs.includeNewRoutes && !slices.Contains(rs.knownRoutes, routeID) { + return true + } + + return false } // FilterSelected removes unselected routes from the provided map. @@ -118,7 +145,11 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap { filtered := route.HAMap{} for id, rt := range routes { - if rs.IsSelected(id.NetID()) { + netID := id.NetID() + _, selected := rs.selectedRoutes[netID] + + // Include if directly selected or if it's a new route and includeNewRoutes is true + if selected || (rs.includeNewRoutes && !slices.Contains(rs.knownRoutes, netID)) { filtered[id] = rt } } @@ -131,11 +162,15 @@ func (rs *RouteSelector) MarshalJSON() ([]byte, error) { defer rs.mu.RUnlock() return json.Marshal(struct { - SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"` - SelectAll bool `json:"select_all"` + SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"` + SelectAll bool `json:"select_all"` + IncludeNewRoutes bool `json:"include_new_routes"` + KnownRoutes []route.NetID `json:"known_routes"` }{ - SelectAll: rs.selectAll, - SelectedRoutes: rs.selectedRoutes, + SelectAll: rs.selectAll, + SelectedRoutes: rs.selectedRoutes, + IncludeNewRoutes: rs.includeNewRoutes, + KnownRoutes: rs.knownRoutes, }) } @@ -149,12 +184,16 @@ func (rs *RouteSelector) UnmarshalJSON(data []byte) error { if len(data) == 0 || string(data) == "null" { rs.selectedRoutes = map[route.NetID]struct{}{} rs.selectAll = true + rs.includeNewRoutes = false + rs.knownRoutes = []route.NetID{} return nil } var temp struct { - SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"` - SelectAll bool `json:"select_all"` + SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"` + SelectAll bool `json:"select_all"` + IncludeNewRoutes bool `json:"include_new_routes"` + KnownRoutes []route.NetID `json:"known_routes"` } if err := json.Unmarshal(data, &temp); err != nil { @@ -163,10 +202,15 @@ func (rs *RouteSelector) UnmarshalJSON(data []byte) error { rs.selectedRoutes = temp.SelectedRoutes rs.selectAll = temp.SelectAll + rs.includeNewRoutes = temp.IncludeNewRoutes + rs.knownRoutes = temp.KnownRoutes if rs.selectedRoutes == nil { rs.selectedRoutes = map[route.NetID]struct{}{} } + if rs.knownRoutes == nil { + rs.knownRoutes = []route.NetID{} + } return nil } diff --git a/client/internal/routeselector/routeselector_test.go b/client/internal/routeselector/routeselector_test.go index b1671f254..a1461dff6 100644 --- a/client/internal/routeselector/routeselector_test.go +++ b/client/internal/routeselector/routeselector_test.go @@ -316,7 +316,7 @@ func TestRouteSelector_NewRoutesBehavior(t *testing.T) { return rs.DeselectRoutes([]route.NetID{"route1"}, initialRoutes) }, // After deselecting specific routes, new routes should remain unselected - wantNewSelected: []route.NetID{"route2", "route3"}, + wantNewSelected: []route.NetID{"route2", "route3", "route4", "route5"}, }, { name: "New routes after selecting with append", @@ -358,3 +358,73 @@ func TestRouteSelector_NewRoutesBehavior(t *testing.T) { }) } } + +func TestRouteSelector_MixedSelectionDeselection(t *testing.T) { + allRoutes := []route.NetID{"route1", "route2", "route3"} + + tests := []struct { + name string + routesToSelect []route.NetID + selectAppend bool + routesToDeselect []route.NetID + selectFirst bool + wantSelectedFinal []route.NetID + }{ + { + name: "1. Select A, then Deselect B", + routesToSelect: []route.NetID{"route1"}, + selectAppend: false, + routesToDeselect: []route.NetID{"route2"}, + selectFirst: true, + wantSelectedFinal: []route.NetID{"route1"}, + }, + { + name: "2. Select A, then Deselect A", + routesToSelect: []route.NetID{"route1"}, + selectAppend: false, + routesToDeselect: []route.NetID{"route1"}, + selectFirst: true, + wantSelectedFinal: []route.NetID{}, + }, + { + name: "3. Deselect A (from all), then Select B", + routesToSelect: []route.NetID{"route2"}, + selectAppend: false, + routesToDeselect: []route.NetID{"route1"}, + selectFirst: false, + wantSelectedFinal: []route.NetID{"route2"}, + }, + { + name: "4. Deselect A (from all), then Select A", + routesToSelect: []route.NetID{"route1"}, + selectAppend: false, + routesToDeselect: []route.NetID{"route1"}, + selectFirst: false, + wantSelectedFinal: []route.NetID{"route1"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rs := routeselector.NewRouteSelector() + + var err1, err2 error + + if tt.selectFirst { + err1 = rs.SelectRoutes(tt.routesToSelect, tt.selectAppend, allRoutes) + require.NoError(t, err1) + err2 = rs.DeselectRoutes(tt.routesToDeselect, allRoutes) + require.NoError(t, err2) + } else { + err1 = rs.DeselectRoutes(tt.routesToDeselect, allRoutes) + require.NoError(t, err1) + err2 = rs.SelectRoutes(tt.routesToSelect, tt.selectAppend, allRoutes) + require.NoError(t, err2) + } + + for _, r := range allRoutes { + assert.Equal(t, slices.Contains(tt.wantSelectedFinal, r), rs.IsSelected(r), "Route %s final state mismatch", r) + } + }) + } +}