mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-20 17:58:02 +02:00
[client] Keep selecting new networks after first deselection (#3671)
This commit is contained in:
parent
a675531b5c
commit
0c93bd3d06
@ -10,20 +10,27 @@ import (
|
|||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/errors"
|
"github.com/netbirdio/netbird/client/errors"
|
||||||
route "github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RouteSelector struct {
|
type RouteSelector struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
selectedRoutes map[route.NetID]struct{}
|
selectedRoutes map[route.NetID]struct{}
|
||||||
selectAll bool
|
selectAll bool
|
||||||
|
|
||||||
|
// Indicates if new routes should be automatically selected
|
||||||
|
includeNewRoutes bool
|
||||||
|
|
||||||
|
// All known routes at the time of deselection
|
||||||
|
knownRoutes []route.NetID
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRouteSelector() *RouteSelector {
|
func NewRouteSelector() *RouteSelector {
|
||||||
return &RouteSelector{
|
return &RouteSelector{
|
||||||
selectedRoutes: map[route.NetID]struct{}{},
|
selectedRoutes: map[route.NetID]struct{}{},
|
||||||
// default selects all routes
|
|
||||||
selectAll: true,
|
selectAll: true,
|
||||||
|
includeNewRoutes: false,
|
||||||
|
knownRoutes: []route.NetID{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -46,6 +53,7 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al
|
|||||||
rs.selectedRoutes[route] = struct{}{}
|
rs.selectedRoutes[route] = struct{}{}
|
||||||
}
|
}
|
||||||
rs.selectAll = false
|
rs.selectAll = false
|
||||||
|
rs.includeNewRoutes = false
|
||||||
|
|
||||||
return errors.FormatErrorOrNil(err)
|
return errors.FormatErrorOrNil(err)
|
||||||
}
|
}
|
||||||
@ -57,16 +65,22 @@ func (rs *RouteSelector) SelectAllRoutes() {
|
|||||||
|
|
||||||
rs.selectAll = true
|
rs.selectAll = true
|
||||||
rs.selectedRoutes = map[route.NetID]struct{}{}
|
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||||
|
rs.includeNewRoutes = false
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeselectRoutes removes specific routes from the selection.
|
// DeselectRoutes removes specific routes from the selection.
|
||||||
// If the selector is in "select all" mode, it will transition to "select specific" mode.
|
// If the selector is in "select all" mode, it will transition to "select specific" mode
|
||||||
|
// but will keep new routes selected.
|
||||||
func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.NetID) error {
|
func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.NetID) error {
|
||||||
rs.mu.Lock()
|
rs.mu.Lock()
|
||||||
defer rs.mu.Unlock()
|
defer rs.mu.Unlock()
|
||||||
|
|
||||||
if rs.selectAll {
|
if rs.selectAll {
|
||||||
rs.selectAll = false
|
rs.selectAll = false
|
||||||
|
rs.includeNewRoutes = true
|
||||||
|
rs.knownRoutes = make([]route.NetID, len(allRoutes))
|
||||||
|
copy(rs.knownRoutes, allRoutes)
|
||||||
|
|
||||||
rs.selectedRoutes = map[route.NetID]struct{}{}
|
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||||
for _, route := range allRoutes {
|
for _, route := range allRoutes {
|
||||||
rs.selectedRoutes[route] = struct{}{}
|
rs.selectedRoutes[route] = struct{}{}
|
||||||
@ -92,6 +106,7 @@ func (rs *RouteSelector) DeselectAllRoutes() {
|
|||||||
defer rs.mu.Unlock()
|
defer rs.mu.Unlock()
|
||||||
|
|
||||||
rs.selectAll = false
|
rs.selectAll = false
|
||||||
|
rs.includeNewRoutes = false
|
||||||
rs.selectedRoutes = map[route.NetID]struct{}{}
|
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -103,8 +118,20 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
|
|||||||
if rs.selectAll {
|
if rs.selectAll {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if the route exists in selectedRoutes
|
||||||
_, selected := rs.selectedRoutes[routeID]
|
_, selected := rs.selectedRoutes[routeID]
|
||||||
return selected
|
if selected {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// If includeNewRoutes is true and this is a new route (not in knownRoutes),
|
||||||
|
// then it should be selected
|
||||||
|
if rs.includeNewRoutes && !slices.Contains(rs.knownRoutes, routeID) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// FilterSelected removes unselected routes from the provided map.
|
// FilterSelected removes unselected routes from the provided map.
|
||||||
@ -118,7 +145,11 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
|
|||||||
|
|
||||||
filtered := route.HAMap{}
|
filtered := route.HAMap{}
|
||||||
for id, rt := range routes {
|
for id, rt := range routes {
|
||||||
if rs.IsSelected(id.NetID()) {
|
netID := id.NetID()
|
||||||
|
_, selected := rs.selectedRoutes[netID]
|
||||||
|
|
||||||
|
// Include if directly selected or if it's a new route and includeNewRoutes is true
|
||||||
|
if selected || (rs.includeNewRoutes && !slices.Contains(rs.knownRoutes, netID)) {
|
||||||
filtered[id] = rt
|
filtered[id] = rt
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -133,9 +164,13 @@ func (rs *RouteSelector) MarshalJSON() ([]byte, error) {
|
|||||||
return json.Marshal(struct {
|
return json.Marshal(struct {
|
||||||
SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"`
|
SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"`
|
||||||
SelectAll bool `json:"select_all"`
|
SelectAll bool `json:"select_all"`
|
||||||
|
IncludeNewRoutes bool `json:"include_new_routes"`
|
||||||
|
KnownRoutes []route.NetID `json:"known_routes"`
|
||||||
}{
|
}{
|
||||||
SelectAll: rs.selectAll,
|
SelectAll: rs.selectAll,
|
||||||
SelectedRoutes: rs.selectedRoutes,
|
SelectedRoutes: rs.selectedRoutes,
|
||||||
|
IncludeNewRoutes: rs.includeNewRoutes,
|
||||||
|
KnownRoutes: rs.knownRoutes,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -149,12 +184,16 @@ func (rs *RouteSelector) UnmarshalJSON(data []byte) error {
|
|||||||
if len(data) == 0 || string(data) == "null" {
|
if len(data) == 0 || string(data) == "null" {
|
||||||
rs.selectedRoutes = map[route.NetID]struct{}{}
|
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||||
rs.selectAll = true
|
rs.selectAll = true
|
||||||
|
rs.includeNewRoutes = false
|
||||||
|
rs.knownRoutes = []route.NetID{}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var temp struct {
|
var temp struct {
|
||||||
SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"`
|
SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"`
|
||||||
SelectAll bool `json:"select_all"`
|
SelectAll bool `json:"select_all"`
|
||||||
|
IncludeNewRoutes bool `json:"include_new_routes"`
|
||||||
|
KnownRoutes []route.NetID `json:"known_routes"`
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := json.Unmarshal(data, &temp); err != nil {
|
if err := json.Unmarshal(data, &temp); err != nil {
|
||||||
@ -163,10 +202,15 @@ func (rs *RouteSelector) UnmarshalJSON(data []byte) error {
|
|||||||
|
|
||||||
rs.selectedRoutes = temp.SelectedRoutes
|
rs.selectedRoutes = temp.SelectedRoutes
|
||||||
rs.selectAll = temp.SelectAll
|
rs.selectAll = temp.SelectAll
|
||||||
|
rs.includeNewRoutes = temp.IncludeNewRoutes
|
||||||
|
rs.knownRoutes = temp.KnownRoutes
|
||||||
|
|
||||||
if rs.selectedRoutes == nil {
|
if rs.selectedRoutes == nil {
|
||||||
rs.selectedRoutes = map[route.NetID]struct{}{}
|
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||||
}
|
}
|
||||||
|
if rs.knownRoutes == nil {
|
||||||
|
rs.knownRoutes = []route.NetID{}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -316,7 +316,7 @@ func TestRouteSelector_NewRoutesBehavior(t *testing.T) {
|
|||||||
return rs.DeselectRoutes([]route.NetID{"route1"}, initialRoutes)
|
return rs.DeselectRoutes([]route.NetID{"route1"}, initialRoutes)
|
||||||
},
|
},
|
||||||
// After deselecting specific routes, new routes should remain unselected
|
// After deselecting specific routes, new routes should remain unselected
|
||||||
wantNewSelected: []route.NetID{"route2", "route3"},
|
wantNewSelected: []route.NetID{"route2", "route3", "route4", "route5"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "New routes after selecting with append",
|
name: "New routes after selecting with append",
|
||||||
@ -358,3 +358,73 @@ func TestRouteSelector_NewRoutesBehavior(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRouteSelector_MixedSelectionDeselection(t *testing.T) {
|
||||||
|
allRoutes := []route.NetID{"route1", "route2", "route3"}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
routesToSelect []route.NetID
|
||||||
|
selectAppend bool
|
||||||
|
routesToDeselect []route.NetID
|
||||||
|
selectFirst bool
|
||||||
|
wantSelectedFinal []route.NetID
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "1. Select A, then Deselect B",
|
||||||
|
routesToSelect: []route.NetID{"route1"},
|
||||||
|
selectAppend: false,
|
||||||
|
routesToDeselect: []route.NetID{"route2"},
|
||||||
|
selectFirst: true,
|
||||||
|
wantSelectedFinal: []route.NetID{"route1"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "2. Select A, then Deselect A",
|
||||||
|
routesToSelect: []route.NetID{"route1"},
|
||||||
|
selectAppend: false,
|
||||||
|
routesToDeselect: []route.NetID{"route1"},
|
||||||
|
selectFirst: true,
|
||||||
|
wantSelectedFinal: []route.NetID{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "3. Deselect A (from all), then Select B",
|
||||||
|
routesToSelect: []route.NetID{"route2"},
|
||||||
|
selectAppend: false,
|
||||||
|
routesToDeselect: []route.NetID{"route1"},
|
||||||
|
selectFirst: false,
|
||||||
|
wantSelectedFinal: []route.NetID{"route2"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "4. Deselect A (from all), then Select A",
|
||||||
|
routesToSelect: []route.NetID{"route1"},
|
||||||
|
selectAppend: false,
|
||||||
|
routesToDeselect: []route.NetID{"route1"},
|
||||||
|
selectFirst: false,
|
||||||
|
wantSelectedFinal: []route.NetID{"route1"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
rs := routeselector.NewRouteSelector()
|
||||||
|
|
||||||
|
var err1, err2 error
|
||||||
|
|
||||||
|
if tt.selectFirst {
|
||||||
|
err1 = rs.SelectRoutes(tt.routesToSelect, tt.selectAppend, allRoutes)
|
||||||
|
require.NoError(t, err1)
|
||||||
|
err2 = rs.DeselectRoutes(tt.routesToDeselect, allRoutes)
|
||||||
|
require.NoError(t, err2)
|
||||||
|
} else {
|
||||||
|
err1 = rs.DeselectRoutes(tt.routesToDeselect, allRoutes)
|
||||||
|
require.NoError(t, err1)
|
||||||
|
err2 = rs.SelectRoutes(tt.routesToSelect, tt.selectAppend, allRoutes)
|
||||||
|
require.NoError(t, err2)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, r := range allRoutes {
|
||||||
|
assert.Equal(t, slices.Contains(tt.wantSelectedFinal, r), rs.IsSelected(r), "Route %s final state mismatch", r)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user