[client] Keep new routes selected unless all are deselected (#3692)

This commit is contained in:
Viktor Liu 2025-04-23 01:07:04 +02:00 committed by GitHub
parent 1a6d6b3109
commit 3b7b9d25bc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 260 additions and 96 deletions

View File

@ -15,22 +15,14 @@ import (
type RouteSelector struct { type RouteSelector struct {
mu sync.RWMutex mu sync.RWMutex
selectedRoutes map[route.NetID]struct{} deselectedRoutes map[route.NetID]struct{}
selectAll bool deselectAll bool
// Indicates if new routes should be automatically selected
includeNewRoutes bool
// All known routes at the time of deselection
knownRoutes []route.NetID
} }
func NewRouteSelector() *RouteSelector { func NewRouteSelector() *RouteSelector {
return &RouteSelector{ return &RouteSelector{
selectedRoutes: map[route.NetID]struct{}{}, deselectedRoutes: map[route.NetID]struct{}{},
selectAll: true, deselectAll: false,
includeNewRoutes: false,
knownRoutes: []route.NetID{},
} }
} }
@ -39,8 +31,11 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al
rs.mu.Lock() rs.mu.Lock()
defer rs.mu.Unlock() defer rs.mu.Unlock()
if !appendRoute { if !appendRoute || rs.deselectAll {
rs.selectedRoutes = map[route.NetID]struct{}{} maps.Clear(rs.deselectedRoutes)
for _, r := range allRoutes {
rs.deselectedRoutes[r] = struct{}{}
}
} }
var err *multierror.Error var err *multierror.Error
@ -49,11 +44,10 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al
err = multierror.Append(err, fmt.Errorf("route '%s' is not available", route)) err = multierror.Append(err, fmt.Errorf("route '%s' is not available", route))
continue continue
} }
delete(rs.deselectedRoutes, route)
rs.selectedRoutes[route] = struct{}{}
} }
rs.selectAll = false
rs.includeNewRoutes = false rs.deselectAll = false
return errors.FormatErrorOrNil(err) return errors.FormatErrorOrNil(err)
} }
@ -63,38 +57,26 @@ func (rs *RouteSelector) SelectAllRoutes() {
rs.mu.Lock() rs.mu.Lock()
defer rs.mu.Unlock() defer rs.mu.Unlock()
rs.selectAll = true rs.deselectAll = false
rs.selectedRoutes = map[route.NetID]struct{}{} maps.Clear(rs.deselectedRoutes)
rs.includeNewRoutes = false
} }
// DeselectRoutes removes specific routes from the selection. // DeselectRoutes removes specific routes from the selection.
// If the selector is in "select all" mode, it will transition to "select specific" mode
// but will keep new routes selected.
func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.NetID) error { func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.NetID) error {
rs.mu.Lock() rs.mu.Lock()
defer rs.mu.Unlock() defer rs.mu.Unlock()
if rs.selectAll { if rs.deselectAll {
rs.selectAll = false return nil
rs.includeNewRoutes = true
rs.knownRoutes = make([]route.NetID, len(allRoutes))
copy(rs.knownRoutes, allRoutes)
rs.selectedRoutes = map[route.NetID]struct{}{}
for _, route := range allRoutes {
rs.selectedRoutes[route] = struct{}{}
}
} }
var err *multierror.Error var err *multierror.Error
for _, route := range routes { for _, route := range routes {
if !slices.Contains(allRoutes, route) { if !slices.Contains(allRoutes, route) {
err = multierror.Append(err, fmt.Errorf("route '%s' is not available", route)) err = multierror.Append(err, fmt.Errorf("route '%s' is not available", route))
continue continue
} }
delete(rs.selectedRoutes, route) rs.deselectedRoutes[route] = struct{}{}
} }
return errors.FormatErrorOrNil(err) return errors.FormatErrorOrNil(err)
@ -105,9 +87,8 @@ func (rs *RouteSelector) DeselectAllRoutes() {
rs.mu.Lock() rs.mu.Lock()
defer rs.mu.Unlock() defer rs.mu.Unlock()
rs.selectAll = false rs.deselectAll = true
rs.includeNewRoutes = false maps.Clear(rs.deselectedRoutes)
rs.selectedRoutes = map[route.NetID]struct{}{}
} }
// IsSelected checks if a specific route is selected. // IsSelected checks if a specific route is selected.
@ -115,41 +96,28 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
rs.mu.RLock() rs.mu.RLock()
defer rs.mu.RUnlock() defer rs.mu.RUnlock()
if rs.selectAll { if rs.deselectAll {
return true
}
// Check if the route exists in selectedRoutes
_, selected := rs.selectedRoutes[routeID]
if selected {
return true
}
// If includeNewRoutes is true and this is a new route (not in knownRoutes),
// then it should be selected
if rs.includeNewRoutes && !slices.Contains(rs.knownRoutes, routeID) {
return true
}
return false return false
} }
_, deselected := rs.deselectedRoutes[routeID]
return !deselected
}
// FilterSelected removes unselected routes from the provided map. // FilterSelected removes unselected routes from the provided map.
func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap { func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
rs.mu.RLock() rs.mu.RLock()
defer rs.mu.RUnlock() defer rs.mu.RUnlock()
if rs.selectAll { if rs.deselectAll {
return maps.Clone(routes) return route.HAMap{}
} }
filtered := route.HAMap{} filtered := route.HAMap{}
for id, rt := range routes { for id, rt := range routes {
netID := id.NetID() netID := id.NetID()
_, selected := rs.selectedRoutes[netID] _, deselected := rs.deselectedRoutes[netID]
if !deselected {
// Include if directly selected or if it's a new route and includeNewRoutes is true
if selected || (rs.includeNewRoutes && !slices.Contains(rs.knownRoutes, netID)) {
filtered[id] = rt filtered[id] = rt
} }
} }
@ -162,15 +130,11 @@ func (rs *RouteSelector) MarshalJSON() ([]byte, error) {
defer rs.mu.RUnlock() defer rs.mu.RUnlock()
return json.Marshal(struct { return json.Marshal(struct {
SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"` DeselectedRoutes map[route.NetID]struct{} `json:"deselected_routes"`
SelectAll bool `json:"select_all"` DeselectAll bool `json:"deselect_all"`
IncludeNewRoutes bool `json:"include_new_routes"`
KnownRoutes []route.NetID `json:"known_routes"`
}{ }{
SelectAll: rs.selectAll, DeselectedRoutes: rs.deselectedRoutes,
SelectedRoutes: rs.selectedRoutes, DeselectAll: rs.deselectAll,
IncludeNewRoutes: rs.includeNewRoutes,
KnownRoutes: rs.knownRoutes,
}) })
} }
@ -182,34 +146,25 @@ func (rs *RouteSelector) UnmarshalJSON(data []byte) error {
// Check for null or empty JSON // Check for null or empty JSON
if len(data) == 0 || string(data) == "null" { if len(data) == 0 || string(data) == "null" {
rs.selectedRoutes = map[route.NetID]struct{}{} rs.deselectedRoutes = map[route.NetID]struct{}{}
rs.selectAll = true rs.deselectAll = false
rs.includeNewRoutes = false
rs.knownRoutes = []route.NetID{}
return nil return nil
} }
var temp struct { var temp struct {
SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"` DeselectedRoutes map[route.NetID]struct{} `json:"deselected_routes"`
SelectAll bool `json:"select_all"` DeselectAll bool `json:"deselect_all"`
IncludeNewRoutes bool `json:"include_new_routes"`
KnownRoutes []route.NetID `json:"known_routes"`
} }
if err := json.Unmarshal(data, &temp); err != nil { if err := json.Unmarshal(data, &temp); err != nil {
return err return err
} }
rs.selectedRoutes = temp.SelectedRoutes rs.deselectedRoutes = temp.DeselectedRoutes
rs.selectAll = temp.SelectAll rs.deselectAll = temp.DeselectAll
rs.includeNewRoutes = temp.IncludeNewRoutes
rs.knownRoutes = temp.KnownRoutes
if rs.selectedRoutes == nil { if rs.deselectedRoutes == nil {
rs.selectedRoutes = map[route.NetID]struct{}{} rs.deselectedRoutes = map[route.NetID]struct{}{}
}
if rs.knownRoutes == nil {
rs.knownRoutes = []route.NetID{}
} }
return nil return nil

View File

@ -66,12 +66,10 @@ func TestRouteSelector_SelectRoutes(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
rs := routeselector.NewRouteSelector() rs := routeselector.NewRouteSelector()
if tt.initialSelected != nil {
err := rs.SelectRoutes(tt.initialSelected, false, allRoutes) err := rs.SelectRoutes(tt.initialSelected, false, allRoutes)
require.NoError(t, err) require.NoError(t, err)
}
err := rs.SelectRoutes(tt.selectRoutes, tt.append, allRoutes) err = rs.SelectRoutes(tt.selectRoutes, tt.append, allRoutes)
if tt.wantError { if tt.wantError {
assert.Error(t, err) assert.Error(t, err)
} else { } else {
@ -251,7 +249,8 @@ func TestRouteSelector_IsSelected(t *testing.T) {
assert.True(t, rs.IsSelected("route1")) assert.True(t, rs.IsSelected("route1"))
assert.True(t, rs.IsSelected("route2")) assert.True(t, rs.IsSelected("route2"))
assert.False(t, rs.IsSelected("route3")) assert.False(t, rs.IsSelected("route3"))
assert.False(t, rs.IsSelected("route4")) // Unknown route is selected by default
assert.True(t, rs.IsSelected("route4"))
} }
func TestRouteSelector_FilterSelected(t *testing.T) { func TestRouteSelector_FilterSelected(t *testing.T) {
@ -297,8 +296,8 @@ func TestRouteSelector_NewRoutesBehavior(t *testing.T) {
initialState: func(rs *routeselector.RouteSelector) error { initialState: func(rs *routeselector.RouteSelector) error {
return rs.SelectRoutes([]route.NetID{"route1", "route2"}, false, initialRoutes) return rs.SelectRoutes([]route.NetID{"route1", "route2"}, false, initialRoutes)
}, },
// When specific routes were selected, new routes should remain unselected // When specific routes were selected, new routes should be selected
wantNewSelected: []route.NetID{"route1", "route2"}, wantNewSelected: []route.NetID{"route1", "route2", "route4", "route5"},
}, },
{ {
name: "New routes after deselect all", name: "New routes after deselect all",
@ -315,7 +314,7 @@ func TestRouteSelector_NewRoutesBehavior(t *testing.T) {
rs.SelectAllRoutes() rs.SelectAllRoutes()
return rs.DeselectRoutes([]route.NetID{"route1"}, initialRoutes) return rs.DeselectRoutes([]route.NetID{"route1"}, initialRoutes)
}, },
// After deselecting specific routes, new routes should remain unselected // After deselecting specific routes, new routes should be selected
wantNewSelected: []route.NetID{"route2", "route3", "route4", "route5"}, wantNewSelected: []route.NetID{"route2", "route3", "route4", "route5"},
}, },
{ {
@ -323,8 +322,8 @@ func TestRouteSelector_NewRoutesBehavior(t *testing.T) {
initialState: func(rs *routeselector.RouteSelector) error { initialState: func(rs *routeselector.RouteSelector) error {
return rs.SelectRoutes([]route.NetID{"route1"}, true, initialRoutes) return rs.SelectRoutes([]route.NetID{"route1"}, true, initialRoutes)
}, },
// When routes were appended, new routes should remain unselected // When routes were appended, new routes should be selected
wantNewSelected: []route.NetID{"route1"}, wantNewSelected: []route.NetID{"route1", "route2", "route3", "route4", "route5"},
}, },
} }
@ -428,3 +427,213 @@ func TestRouteSelector_MixedSelectionDeselection(t *testing.T) {
}) })
} }
} }
func TestRouteSelector_AfterDeselectAll(t *testing.T) {
allRoutes := []route.NetID{"route1", "route2", "route3"}
tests := []struct {
name string
initialAction func(rs *routeselector.RouteSelector) error
secondAction func(rs *routeselector.RouteSelector) error
wantSelected []route.NetID
wantError bool
}{
{
name: "Deselect all -> select specific routes",
initialAction: func(rs *routeselector.RouteSelector) error {
rs.DeselectAllRoutes()
return nil
},
secondAction: func(rs *routeselector.RouteSelector) error {
return rs.SelectRoutes([]route.NetID{"route1", "route2"}, false, allRoutes)
},
wantSelected: []route.NetID{"route1", "route2"},
},
{
name: "Deselect all -> select with append",
initialAction: func(rs *routeselector.RouteSelector) error {
rs.DeselectAllRoutes()
return nil
},
secondAction: func(rs *routeselector.RouteSelector) error {
return rs.SelectRoutes([]route.NetID{"route1"}, true, allRoutes)
},
wantSelected: []route.NetID{"route1"},
},
{
name: "Deselect all -> deselect specific",
initialAction: func(rs *routeselector.RouteSelector) error {
rs.DeselectAllRoutes()
return nil
},
secondAction: func(rs *routeselector.RouteSelector) error {
return rs.DeselectRoutes([]route.NetID{"route1"}, allRoutes)
},
wantSelected: []route.NetID{},
},
{
name: "Deselect all -> select all",
initialAction: func(rs *routeselector.RouteSelector) error {
rs.DeselectAllRoutes()
return nil
},
secondAction: func(rs *routeselector.RouteSelector) error {
rs.SelectAllRoutes()
return nil
},
wantSelected: []route.NetID{"route1", "route2", "route3"},
},
{
name: "Deselect all -> deselect non-existent route",
initialAction: func(rs *routeselector.RouteSelector) error {
rs.DeselectAllRoutes()
return nil
},
secondAction: func(rs *routeselector.RouteSelector) error {
return rs.DeselectRoutes([]route.NetID{"route4"}, allRoutes)
},
wantSelected: []route.NetID{},
wantError: false,
},
{
name: "Select specific -> deselect all -> select different",
initialAction: func(rs *routeselector.RouteSelector) error {
err := rs.SelectRoutes([]route.NetID{"route1"}, false, allRoutes)
if err != nil {
return err
}
rs.DeselectAllRoutes()
return nil
},
secondAction: func(rs *routeselector.RouteSelector) error {
return rs.SelectRoutes([]route.NetID{"route2", "route3"}, false, allRoutes)
},
wantSelected: []route.NetID{"route2", "route3"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rs := routeselector.NewRouteSelector()
err := tt.initialAction(rs)
require.NoError(t, err)
err = tt.secondAction(rs)
if tt.wantError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
for _, id := range allRoutes {
expected := slices.Contains(tt.wantSelected, id)
assert.Equal(t, expected, rs.IsSelected(id),
"Route %s selection state incorrect, expected %v", id, expected)
}
routes := route.HAMap{
"route1|10.0.0.0/8": {},
"route2|192.168.0.0/16": {},
"route3|172.16.0.0/12": {},
}
filtered := rs.FilterSelected(routes)
assert.Equal(t, len(tt.wantSelected), len(filtered),
"FilterSelected returned wrong number of routes")
})
}
}
func TestRouteSelector_ComplexScenarios(t *testing.T) {
allRoutes := []route.NetID{"route1", "route2", "route3", "route4"}
tests := []struct {
name string
actions []func(rs *routeselector.RouteSelector) error
wantSelected []route.NetID
}{
{
name: "Select all -> deselect specific -> select different with append",
actions: []func(rs *routeselector.RouteSelector) error{
func(rs *routeselector.RouteSelector) error {
rs.SelectAllRoutes()
return nil
},
func(rs *routeselector.RouteSelector) error {
return rs.DeselectRoutes([]route.NetID{"route1", "route2"}, allRoutes)
},
func(rs *routeselector.RouteSelector) error {
return rs.SelectRoutes([]route.NetID{"route1"}, true, allRoutes)
},
},
wantSelected: []route.NetID{"route1", "route3", "route4"},
},
{
name: "Deselect all -> select specific -> deselect one -> select different with append",
actions: []func(rs *routeselector.RouteSelector) error{
func(rs *routeselector.RouteSelector) error {
rs.DeselectAllRoutes()
return nil
},
func(rs *routeselector.RouteSelector) error {
return rs.SelectRoutes([]route.NetID{"route1", "route2"}, false, allRoutes)
},
func(rs *routeselector.RouteSelector) error {
return rs.DeselectRoutes([]route.NetID{"route2"}, allRoutes)
},
func(rs *routeselector.RouteSelector) error {
return rs.SelectRoutes([]route.NetID{"route3"}, true, allRoutes)
},
},
wantSelected: []route.NetID{"route1", "route3"},
},
{
name: "Select specific -> deselect specific -> select all -> deselect different",
actions: []func(rs *routeselector.RouteSelector) error{
func(rs *routeselector.RouteSelector) error {
return rs.SelectRoutes([]route.NetID{"route1", "route2"}, false, allRoutes)
},
func(rs *routeselector.RouteSelector) error {
return rs.DeselectRoutes([]route.NetID{"route2"}, allRoutes)
},
func(rs *routeselector.RouteSelector) error {
rs.SelectAllRoutes()
return nil
},
func(rs *routeselector.RouteSelector) error {
return rs.DeselectRoutes([]route.NetID{"route3", "route4"}, allRoutes)
},
},
wantSelected: []route.NetID{"route1", "route2"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rs := routeselector.NewRouteSelector()
for i, action := range tt.actions {
err := action(rs)
require.NoError(t, err, "Action %d failed", i)
}
for _, id := range allRoutes {
expected := slices.Contains(tt.wantSelected, id)
assert.Equal(t, expected, rs.IsSelected(id),
"Route %s selection state incorrect", id)
}
routes := route.HAMap{
"route1|10.0.0.0/8": {},
"route2|192.168.0.0/16": {},
"route3|172.16.0.0/12": {},
"route4|10.10.0.0/16": {},
}
filtered := rs.FilterSelected(routes)
assert.Equal(t, len(tt.wantSelected), len(filtered),
"FilterSelected returned wrong number of routes")
})
}
}