[client] Keep selecting new networks after first deselection (#3671)

This commit is contained in:
Viktor Liu 2025-04-16 13:55:26 +02:00 committed by GitHub
parent a675531b5c
commit 0c93bd3d06
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 128 additions and 14 deletions

View File

@ -10,20 +10,27 @@ import (
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/errors"
route "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
type RouteSelector struct { type RouteSelector struct {
mu sync.RWMutex mu sync.RWMutex
selectedRoutes map[route.NetID]struct{} selectedRoutes map[route.NetID]struct{}
selectAll bool selectAll bool
// Indicates if new routes should be automatically selected
includeNewRoutes bool
// All known routes at the time of deselection
knownRoutes []route.NetID
} }
func NewRouteSelector() *RouteSelector { func NewRouteSelector() *RouteSelector {
return &RouteSelector{ return &RouteSelector{
selectedRoutes: map[route.NetID]struct{}{}, selectedRoutes: map[route.NetID]struct{}{},
// default selects all routes
selectAll: true, selectAll: true,
includeNewRoutes: false,
knownRoutes: []route.NetID{},
} }
} }
@ -46,6 +53,7 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al
rs.selectedRoutes[route] = struct{}{} rs.selectedRoutes[route] = struct{}{}
} }
rs.selectAll = false rs.selectAll = false
rs.includeNewRoutes = false
return errors.FormatErrorOrNil(err) return errors.FormatErrorOrNil(err)
} }
@ -57,16 +65,22 @@ func (rs *RouteSelector) SelectAllRoutes() {
rs.selectAll = true rs.selectAll = true
rs.selectedRoutes = map[route.NetID]struct{}{} rs.selectedRoutes = map[route.NetID]struct{}{}
rs.includeNewRoutes = false
} }
// DeselectRoutes removes specific routes from the selection. // DeselectRoutes removes specific routes from the selection.
// If the selector is in "select all" mode, it will transition to "select specific" mode. // If the selector is in "select all" mode, it will transition to "select specific" mode
// but will keep new routes selected.
func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.NetID) error { func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.NetID) error {
rs.mu.Lock() rs.mu.Lock()
defer rs.mu.Unlock() defer rs.mu.Unlock()
if rs.selectAll { if rs.selectAll {
rs.selectAll = false rs.selectAll = false
rs.includeNewRoutes = true
rs.knownRoutes = make([]route.NetID, len(allRoutes))
copy(rs.knownRoutes, allRoutes)
rs.selectedRoutes = map[route.NetID]struct{}{} rs.selectedRoutes = map[route.NetID]struct{}{}
for _, route := range allRoutes { for _, route := range allRoutes {
rs.selectedRoutes[route] = struct{}{} rs.selectedRoutes[route] = struct{}{}
@ -92,6 +106,7 @@ func (rs *RouteSelector) DeselectAllRoutes() {
defer rs.mu.Unlock() defer rs.mu.Unlock()
rs.selectAll = false rs.selectAll = false
rs.includeNewRoutes = false
rs.selectedRoutes = map[route.NetID]struct{}{} rs.selectedRoutes = map[route.NetID]struct{}{}
} }
@ -103,8 +118,20 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
if rs.selectAll { if rs.selectAll {
return true return true
} }
// Check if the route exists in selectedRoutes
_, selected := rs.selectedRoutes[routeID] _, selected := rs.selectedRoutes[routeID]
return selected if selected {
return true
}
// If includeNewRoutes is true and this is a new route (not in knownRoutes),
// then it should be selected
if rs.includeNewRoutes && !slices.Contains(rs.knownRoutes, routeID) {
return true
}
return false
} }
// FilterSelected removes unselected routes from the provided map. // FilterSelected removes unselected routes from the provided map.
@ -118,7 +145,11 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
filtered := route.HAMap{} filtered := route.HAMap{}
for id, rt := range routes { for id, rt := range routes {
if rs.IsSelected(id.NetID()) { netID := id.NetID()
_, selected := rs.selectedRoutes[netID]
// Include if directly selected or if it's a new route and includeNewRoutes is true
if selected || (rs.includeNewRoutes && !slices.Contains(rs.knownRoutes, netID)) {
filtered[id] = rt filtered[id] = rt
} }
} }
@ -133,9 +164,13 @@ func (rs *RouteSelector) MarshalJSON() ([]byte, error) {
return json.Marshal(struct { return json.Marshal(struct {
SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"` SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"`
SelectAll bool `json:"select_all"` SelectAll bool `json:"select_all"`
IncludeNewRoutes bool `json:"include_new_routes"`
KnownRoutes []route.NetID `json:"known_routes"`
}{ }{
SelectAll: rs.selectAll, SelectAll: rs.selectAll,
SelectedRoutes: rs.selectedRoutes, SelectedRoutes: rs.selectedRoutes,
IncludeNewRoutes: rs.includeNewRoutes,
KnownRoutes: rs.knownRoutes,
}) })
} }
@ -149,12 +184,16 @@ func (rs *RouteSelector) UnmarshalJSON(data []byte) error {
if len(data) == 0 || string(data) == "null" { if len(data) == 0 || string(data) == "null" {
rs.selectedRoutes = map[route.NetID]struct{}{} rs.selectedRoutes = map[route.NetID]struct{}{}
rs.selectAll = true rs.selectAll = true
rs.includeNewRoutes = false
rs.knownRoutes = []route.NetID{}
return nil return nil
} }
var temp struct { var temp struct {
SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"` SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"`
SelectAll bool `json:"select_all"` SelectAll bool `json:"select_all"`
IncludeNewRoutes bool `json:"include_new_routes"`
KnownRoutes []route.NetID `json:"known_routes"`
} }
if err := json.Unmarshal(data, &temp); err != nil { if err := json.Unmarshal(data, &temp); err != nil {
@ -163,10 +202,15 @@ func (rs *RouteSelector) UnmarshalJSON(data []byte) error {
rs.selectedRoutes = temp.SelectedRoutes rs.selectedRoutes = temp.SelectedRoutes
rs.selectAll = temp.SelectAll rs.selectAll = temp.SelectAll
rs.includeNewRoutes = temp.IncludeNewRoutes
rs.knownRoutes = temp.KnownRoutes
if rs.selectedRoutes == nil { if rs.selectedRoutes == nil {
rs.selectedRoutes = map[route.NetID]struct{}{} rs.selectedRoutes = map[route.NetID]struct{}{}
} }
if rs.knownRoutes == nil {
rs.knownRoutes = []route.NetID{}
}
return nil return nil
} }

View File

@ -316,7 +316,7 @@ func TestRouteSelector_NewRoutesBehavior(t *testing.T) {
return rs.DeselectRoutes([]route.NetID{"route1"}, initialRoutes) return rs.DeselectRoutes([]route.NetID{"route1"}, initialRoutes)
}, },
// After deselecting specific routes, new routes should remain unselected // After deselecting specific routes, new routes should remain unselected
wantNewSelected: []route.NetID{"route2", "route3"}, wantNewSelected: []route.NetID{"route2", "route3", "route4", "route5"},
}, },
{ {
name: "New routes after selecting with append", name: "New routes after selecting with append",
@ -358,3 +358,73 @@ func TestRouteSelector_NewRoutesBehavior(t *testing.T) {
}) })
} }
} }
func TestRouteSelector_MixedSelectionDeselection(t *testing.T) {
allRoutes := []route.NetID{"route1", "route2", "route3"}
tests := []struct {
name string
routesToSelect []route.NetID
selectAppend bool
routesToDeselect []route.NetID
selectFirst bool
wantSelectedFinal []route.NetID
}{
{
name: "1. Select A, then Deselect B",
routesToSelect: []route.NetID{"route1"},
selectAppend: false,
routesToDeselect: []route.NetID{"route2"},
selectFirst: true,
wantSelectedFinal: []route.NetID{"route1"},
},
{
name: "2. Select A, then Deselect A",
routesToSelect: []route.NetID{"route1"},
selectAppend: false,
routesToDeselect: []route.NetID{"route1"},
selectFirst: true,
wantSelectedFinal: []route.NetID{},
},
{
name: "3. Deselect A (from all), then Select B",
routesToSelect: []route.NetID{"route2"},
selectAppend: false,
routesToDeselect: []route.NetID{"route1"},
selectFirst: false,
wantSelectedFinal: []route.NetID{"route2"},
},
{
name: "4. Deselect A (from all), then Select A",
routesToSelect: []route.NetID{"route1"},
selectAppend: false,
routesToDeselect: []route.NetID{"route1"},
selectFirst: false,
wantSelectedFinal: []route.NetID{"route1"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rs := routeselector.NewRouteSelector()
var err1, err2 error
if tt.selectFirst {
err1 = rs.SelectRoutes(tt.routesToSelect, tt.selectAppend, allRoutes)
require.NoError(t, err1)
err2 = rs.DeselectRoutes(tt.routesToDeselect, allRoutes)
require.NoError(t, err2)
} else {
err1 = rs.DeselectRoutes(tt.routesToDeselect, allRoutes)
require.NoError(t, err1)
err2 = rs.SelectRoutes(tt.routesToSelect, tt.selectAppend, allRoutes)
require.NoError(t, err2)
}
for _, r := range allRoutes {
assert.Equal(t, slices.Contains(tt.wantSelectedFinal, r), rs.IsSelected(r), "Route %s final state mismatch", r)
}
})
}
}