From e586eca16cf4aa95b18afd4585e042cc693a5aa9 Mon Sep 17 00:00:00 2001 From: Yury Gargay Date: Tue, 22 Aug 2023 17:56:39 +0200 Subject: [PATCH] Improve account copying (#1069) With this fix, all nested slices and pointers will be copied by value. Also, this fixes tests to compare the original and copy account by their values by marshaling them to JSON strings. Before that, they were copying the pointers that also passed the simple `=` compassion (as the addresses match). --- dns/nameserver.go | 14 +++++--- management/server/account_test.go | 34 ++++++++++++++++--- management/server/group.go | 6 ++-- .../server/http/nameservers_handler_test.go | 1 + management/server/peer.go | 8 +++-- management/server/personal_access_token.go | 12 +++++++ management/server/policy.go | 17 ++++++---- management/server/rule.go | 9 +++-- management/server/setupkey.go | 4 +-- management/server/user.go | 4 +-- route/route.go | 9 +++-- 11 files changed, 88 insertions(+), 30 deletions(-) diff --git a/dns/nameserver.go b/dns/nameserver.go index 807df5907..7751f8e1c 100644 --- a/dns/nameserver.go +++ b/dns/nameserver.go @@ -130,16 +130,22 @@ func ParseNameServerURL(nsURL string) (NameServer, error) { // Copy copies a nameserver group object func (g *NameServerGroup) Copy() *NameServerGroup { - return &NameServerGroup{ + nsGroup := &NameServerGroup{ ID: g.ID, Name: g.Name, Description: g.Description, - NameServers: g.NameServers, - Groups: g.Groups, + NameServers: make([]NameServer, len(g.NameServers)), + Groups: make([]string, len(g.Groups)), Enabled: g.Enabled, Primary: g.Primary, - Domains: g.Domains, + Domains: make([]string, len(g.Domains)), } + + copy(nsGroup.NameServers, g.NameServers) + copy(nsGroup.Groups, g.Groups) + copy(nsGroup.Domains, g.Domains) + + return nsGroup } // IsEqual compares one nameserver group with the other diff --git a/management/server/account_test.go b/management/server/account_test.go index 119828e20..29af8514a 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -3,6 +3,7 @@ package server import ( "crypto/sha256" b64 "encoding/base64" + "encoding/json" "fmt" "net" "reflect" @@ -1348,6 +1349,11 @@ func TestAccount_Copy(t *testing.T) { Peers: map[string]*Peer{ "peer1": { Key: "key1", + Status: &PeerStatus{ + LastSeen: time.Now(), + Connected: true, + LoginExpired: false, + }, }, }, Users: map[string]*User{ @@ -1370,28 +1376,36 @@ func TestAccount_Copy(t *testing.T) { }, Groups: map[string]*Group{ "group1": { - ID: "group1", + ID: "group1", + Peers: []string{"peer1"}, }, }, Rules: map[string]*Rule{ "rule1": { - ID: "rule1", + ID: "rule1", + Destination: []string{}, + Source: []string{}, }, }, Policies: []*Policy{ { ID: "policy1", Enabled: true, + Rules: make([]*PolicyRule, 0), }, }, Routes: map[string]*route.Route{ "route1": { - ID: "route1", + ID: "route1", + Groups: []string{"group1"}, }, }, NameServerGroups: map[string]*nbdns.NameServerGroup{ "nsGroup1": { - ID: "nsGroup1", + ID: "nsGroup1", + Domains: []string{}, + Groups: []string{}, + NameServers: []nbdns.NameServer{}, }, }, DNSSettings: &DNSSettings{DisabledManagementGroups: []string{}}, @@ -1402,10 +1416,20 @@ func TestAccount_Copy(t *testing.T) { t.Fatal(err) } accountCopy := account.Copy() - assert.Equal(t, account, accountCopy, "account copy returned a different value than expected") + accBytes, err := json.Marshal(account) + if err != nil { + t.Fatal(err) + } + account.Peers["peer1"].Status.Connected = false // we change original object to confirm that copy wont change + accCopyBytes, err := json.Marshal(accountCopy) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, string(accBytes), string(accCopyBytes), "account copy returned a different value than expected") } // hasNilField validates pointers, maps and slices if they are nil +// TODO: make it check nested fields too func hasNilField(x interface{}) error { rv := reflect.ValueOf(x) rv = rv.Elem() diff --git a/management/server/group.go b/management/server/group.go index 53571e099..5b1d2ac9f 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -59,12 +59,14 @@ func (g *Group) EventMeta() map[string]any { } func (g *Group) Copy() *Group { - return &Group{ + group := &Group{ ID: g.ID, Name: g.Name, Issued: g.Issued, - Peers: g.Peers[:], + Peers: make([]string, len(g.Peers)), } + copy(group.Peers, g.Peers) + return group } // GetGroup object of the peers diff --git a/management/server/http/nameservers_handler_test.go b/management/server/http/nameservers_handler_test.go index 01c3cbe79..75fcb4c1c 100644 --- a/management/server/http/nameservers_handler_test.go +++ b/management/server/http/nameservers_handler_test.go @@ -54,6 +54,7 @@ var baseExistingNSGroup = &nbdns.NameServerGroup{ }, }, Groups: []string{"testing"}, + Domains: []string{"domain"}, Enabled: true, } diff --git a/management/server/peer.go b/management/server/peer.go index b2d4e436f..90377b1e8 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -108,6 +108,10 @@ func (p *Peer) AddedWithSSOLogin() bool { // Copy copies Peer object func (p *Peer) Copy() *Peer { + peerStatus := p.Status + if peerStatus != nil { + peerStatus = p.Status.Copy() + } return &Peer{ ID: p.ID, Key: p.Key, @@ -115,11 +119,11 @@ func (p *Peer) Copy() *Peer { IP: p.IP, Meta: p.Meta, Name: p.Name, - Status: p.Status, + DNSLabel: p.DNSLabel, + Status: peerStatus, UserID: p.UserID, SSHKey: p.SSHKey, SSHEnabled: p.SSHEnabled, - DNSLabel: p.DNSLabel, LoginExpirationEnabled: p.LoginExpirationEnabled, LastLogin: p.LastLogin, } diff --git a/management/server/personal_access_token.go b/management/server/personal_access_token.go index 0a55f3237..c7deca9de 100644 --- a/management/server/personal_access_token.go +++ b/management/server/personal_access_token.go @@ -36,6 +36,18 @@ type PersonalAccessToken struct { LastUsed time.Time } +func (t *PersonalAccessToken) Copy() *PersonalAccessToken { + return &PersonalAccessToken{ + ID: t.ID, + Name: t.Name, + HashedToken: t.HashedToken, + ExpirationDate: t.ExpirationDate, + CreatedBy: t.CreatedBy, + CreatedAt: t.CreatedAt, + LastUsed: t.LastUsed, + } +} + // PersonalAccessTokenGenerated holds the new PersonalAccessToken and the plain text version of it type PersonalAccessTokenGenerated struct { PlainToken string diff --git a/management/server/policy.go b/management/server/policy.go index 54158eeac..dde0b46d8 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -95,18 +95,22 @@ type PolicyRule struct { // Copy returns a copy of a policy rule func (pm *PolicyRule) Copy() *PolicyRule { - return &PolicyRule{ + rule := &PolicyRule{ ID: pm.ID, Name: pm.Name, Description: pm.Description, Enabled: pm.Enabled, Action: pm.Action, - Destinations: pm.Destinations[:], - Sources: pm.Sources[:], + Destinations: make([]string, len(pm.Destinations)), + Sources: make([]string, len(pm.Sources)), Bidirectional: pm.Bidirectional, Protocol: pm.Protocol, - Ports: pm.Ports[:], + Ports: make([]string, len(pm.Ports)), } + copy(rule.Destinations, pm.Destinations) + copy(rule.Sources, pm.Sources) + copy(rule.Ports, pm.Ports) + return rule } // ToRule converts the PolicyRule to a legacy representation of the Rule (for backwards compatibility) @@ -147,9 +151,10 @@ func (p *Policy) Copy() *Policy { Name: p.Name, Description: p.Description, Enabled: p.Enabled, + Rules: make([]*PolicyRule, len(p.Rules)), } - for _, r := range p.Rules { - c.Rules = append(c.Rules, r.Copy()) + for i, r := range p.Rules { + c.Rules[i] = r.Copy() } return c } diff --git a/management/server/rule.go b/management/server/rule.go index 68b1cc4fb..cb85d633d 100644 --- a/management/server/rule.go +++ b/management/server/rule.go @@ -45,15 +45,18 @@ type Rule struct { } func (r *Rule) Copy() *Rule { - return &Rule{ + rule := &Rule{ ID: r.ID, Name: r.Name, Description: r.Description, Disabled: r.Disabled, - Source: r.Source[:], - Destination: r.Destination[:], + Source: make([]string, len(r.Source)), + Destination: make([]string, len(r.Destination)), Flow: r.Flow, } + copy(rule.Source, r.Source) + copy(rule.Destination, r.Destination) + return rule } // EventMeta returns activity event meta related to this rule diff --git a/management/server/setupkey.go b/management/server/setupkey.go index bfba05839..ffdd822e3 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -90,8 +90,8 @@ type SetupKey struct { // Copy copies SetupKey to a new object func (key *SetupKey) Copy() *SetupKey { - autoGroups := make([]string, 0) - autoGroups = append(autoGroups, key.AutoGroups...) + autoGroups := make([]string, len(key.AutoGroups)) + copy(autoGroups, key.AutoGroups) if key.UpdatedAt.IsZero() { key.UpdatedAt = key.CreatedAt } diff --git a/management/server/user.go b/management/server/user.go index 3d0c0313e..b3556957d 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -120,9 +120,7 @@ func (u *User) Copy() *User { copy(autoGroups, u.AutoGroups) pats := make(map[string]*PersonalAccessToken, len(u.PATs)) for k, v := range u.PATs { - patCopy := new(PersonalAccessToken) - *patCopy = *v - pats[k] = patCopy + pats[k] = v.Copy() } return &User{ Id: u.Id, diff --git a/route/route.go b/route/route.go index fbd077bc2..5c45e2cf5 100644 --- a/route/route.go +++ b/route/route.go @@ -1,8 +1,9 @@ package route import ( - "github.com/netbirdio/netbird/management/server/status" "net/netip" + + "github.com/netbirdio/netbird/management/server/status" ) // Windows has some limitation regarding metric size that differ from Unix-like systems. @@ -83,7 +84,7 @@ func (r *Route) EventMeta() map[string]any { // Copy copies a route object func (r *Route) Copy() *Route { - return &Route{ + route := &Route{ ID: r.ID, Description: r.Description, NetID: r.NetID, @@ -93,8 +94,10 @@ func (r *Route) Copy() *Route { Metric: r.Metric, Masquerade: r.Masquerade, Enabled: r.Enabled, - Groups: r.Groups, + Groups: make([]string, len(r.Groups)), } + copy(route.Groups, r.Groups) + return route } // IsEqual compares one route with the other