Improve account copying (#1069)

With this fix, all nested slices and pointers will be copied by value.
Also, this fixes tests to compare the original and copy account by their
values by marshaling them to JSON strings.

Before that, they were copying the pointers that also passed the simple `=` compassion
(as the addresses match).
This commit is contained in:
Yury Gargay 2023-08-22 17:56:39 +02:00 committed by GitHub
parent 892db25021
commit e586eca16c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 88 additions and 30 deletions

View File

@ -130,16 +130,22 @@ func ParseNameServerURL(nsURL string) (NameServer, error) {
// Copy copies a nameserver group object // Copy copies a nameserver group object
func (g *NameServerGroup) Copy() *NameServerGroup { func (g *NameServerGroup) Copy() *NameServerGroup {
return &NameServerGroup{ nsGroup := &NameServerGroup{
ID: g.ID, ID: g.ID,
Name: g.Name, Name: g.Name,
Description: g.Description, Description: g.Description,
NameServers: g.NameServers, NameServers: make([]NameServer, len(g.NameServers)),
Groups: g.Groups, Groups: make([]string, len(g.Groups)),
Enabled: g.Enabled, Enabled: g.Enabled,
Primary: g.Primary, Primary: g.Primary,
Domains: g.Domains, Domains: make([]string, len(g.Domains)),
} }
copy(nsGroup.NameServers, g.NameServers)
copy(nsGroup.Groups, g.Groups)
copy(nsGroup.Domains, g.Domains)
return nsGroup
} }
// IsEqual compares one nameserver group with the other // IsEqual compares one nameserver group with the other

View File

@ -3,6 +3,7 @@ package server
import ( import (
"crypto/sha256" "crypto/sha256"
b64 "encoding/base64" b64 "encoding/base64"
"encoding/json"
"fmt" "fmt"
"net" "net"
"reflect" "reflect"
@ -1348,6 +1349,11 @@ func TestAccount_Copy(t *testing.T) {
Peers: map[string]*Peer{ Peers: map[string]*Peer{
"peer1": { "peer1": {
Key: "key1", Key: "key1",
Status: &PeerStatus{
LastSeen: time.Now(),
Connected: true,
LoginExpired: false,
},
}, },
}, },
Users: map[string]*User{ Users: map[string]*User{
@ -1370,28 +1376,36 @@ func TestAccount_Copy(t *testing.T) {
}, },
Groups: map[string]*Group{ Groups: map[string]*Group{
"group1": { "group1": {
ID: "group1", ID: "group1",
Peers: []string{"peer1"},
}, },
}, },
Rules: map[string]*Rule{ Rules: map[string]*Rule{
"rule1": { "rule1": {
ID: "rule1", ID: "rule1",
Destination: []string{},
Source: []string{},
}, },
}, },
Policies: []*Policy{ Policies: []*Policy{
{ {
ID: "policy1", ID: "policy1",
Enabled: true, Enabled: true,
Rules: make([]*PolicyRule, 0),
}, },
}, },
Routes: map[string]*route.Route{ Routes: map[string]*route.Route{
"route1": { "route1": {
ID: "route1", ID: "route1",
Groups: []string{"group1"},
}, },
}, },
NameServerGroups: map[string]*nbdns.NameServerGroup{ NameServerGroups: map[string]*nbdns.NameServerGroup{
"nsGroup1": { "nsGroup1": {
ID: "nsGroup1", ID: "nsGroup1",
Domains: []string{},
Groups: []string{},
NameServers: []nbdns.NameServer{},
}, },
}, },
DNSSettings: &DNSSettings{DisabledManagementGroups: []string{}}, DNSSettings: &DNSSettings{DisabledManagementGroups: []string{}},
@ -1402,10 +1416,20 @@ func TestAccount_Copy(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
accountCopy := account.Copy() accountCopy := account.Copy()
assert.Equal(t, account, accountCopy, "account copy returned a different value than expected") accBytes, err := json.Marshal(account)
if err != nil {
t.Fatal(err)
}
account.Peers["peer1"].Status.Connected = false // we change original object to confirm that copy wont change
accCopyBytes, err := json.Marshal(accountCopy)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, string(accBytes), string(accCopyBytes), "account copy returned a different value than expected")
} }
// hasNilField validates pointers, maps and slices if they are nil // hasNilField validates pointers, maps and slices if they are nil
// TODO: make it check nested fields too
func hasNilField(x interface{}) error { func hasNilField(x interface{}) error {
rv := reflect.ValueOf(x) rv := reflect.ValueOf(x)
rv = rv.Elem() rv = rv.Elem()

View File

@ -59,12 +59,14 @@ func (g *Group) EventMeta() map[string]any {
} }
func (g *Group) Copy() *Group { func (g *Group) Copy() *Group {
return &Group{ group := &Group{
ID: g.ID, ID: g.ID,
Name: g.Name, Name: g.Name,
Issued: g.Issued, Issued: g.Issued,
Peers: g.Peers[:], Peers: make([]string, len(g.Peers)),
} }
copy(group.Peers, g.Peers)
return group
} }
// GetGroup object of the peers // GetGroup object of the peers

View File

@ -54,6 +54,7 @@ var baseExistingNSGroup = &nbdns.NameServerGroup{
}, },
}, },
Groups: []string{"testing"}, Groups: []string{"testing"},
Domains: []string{"domain"},
Enabled: true, Enabled: true,
} }

View File

@ -108,6 +108,10 @@ func (p *Peer) AddedWithSSOLogin() bool {
// Copy copies Peer object // Copy copies Peer object
func (p *Peer) Copy() *Peer { func (p *Peer) Copy() *Peer {
peerStatus := p.Status
if peerStatus != nil {
peerStatus = p.Status.Copy()
}
return &Peer{ return &Peer{
ID: p.ID, ID: p.ID,
Key: p.Key, Key: p.Key,
@ -115,11 +119,11 @@ func (p *Peer) Copy() *Peer {
IP: p.IP, IP: p.IP,
Meta: p.Meta, Meta: p.Meta,
Name: p.Name, Name: p.Name,
Status: p.Status, DNSLabel: p.DNSLabel,
Status: peerStatus,
UserID: p.UserID, UserID: p.UserID,
SSHKey: p.SSHKey, SSHKey: p.SSHKey,
SSHEnabled: p.SSHEnabled, SSHEnabled: p.SSHEnabled,
DNSLabel: p.DNSLabel,
LoginExpirationEnabled: p.LoginExpirationEnabled, LoginExpirationEnabled: p.LoginExpirationEnabled,
LastLogin: p.LastLogin, LastLogin: p.LastLogin,
} }

View File

@ -36,6 +36,18 @@ type PersonalAccessToken struct {
LastUsed time.Time LastUsed time.Time
} }
func (t *PersonalAccessToken) Copy() *PersonalAccessToken {
return &PersonalAccessToken{
ID: t.ID,
Name: t.Name,
HashedToken: t.HashedToken,
ExpirationDate: t.ExpirationDate,
CreatedBy: t.CreatedBy,
CreatedAt: t.CreatedAt,
LastUsed: t.LastUsed,
}
}
// PersonalAccessTokenGenerated holds the new PersonalAccessToken and the plain text version of it // PersonalAccessTokenGenerated holds the new PersonalAccessToken and the plain text version of it
type PersonalAccessTokenGenerated struct { type PersonalAccessTokenGenerated struct {
PlainToken string PlainToken string

View File

@ -95,18 +95,22 @@ type PolicyRule struct {
// Copy returns a copy of a policy rule // Copy returns a copy of a policy rule
func (pm *PolicyRule) Copy() *PolicyRule { func (pm *PolicyRule) Copy() *PolicyRule {
return &PolicyRule{ rule := &PolicyRule{
ID: pm.ID, ID: pm.ID,
Name: pm.Name, Name: pm.Name,
Description: pm.Description, Description: pm.Description,
Enabled: pm.Enabled, Enabled: pm.Enabled,
Action: pm.Action, Action: pm.Action,
Destinations: pm.Destinations[:], Destinations: make([]string, len(pm.Destinations)),
Sources: pm.Sources[:], Sources: make([]string, len(pm.Sources)),
Bidirectional: pm.Bidirectional, Bidirectional: pm.Bidirectional,
Protocol: pm.Protocol, Protocol: pm.Protocol,
Ports: pm.Ports[:], Ports: make([]string, len(pm.Ports)),
} }
copy(rule.Destinations, pm.Destinations)
copy(rule.Sources, pm.Sources)
copy(rule.Ports, pm.Ports)
return rule
} }
// ToRule converts the PolicyRule to a legacy representation of the Rule (for backwards compatibility) // ToRule converts the PolicyRule to a legacy representation of the Rule (for backwards compatibility)
@ -147,9 +151,10 @@ func (p *Policy) Copy() *Policy {
Name: p.Name, Name: p.Name,
Description: p.Description, Description: p.Description,
Enabled: p.Enabled, Enabled: p.Enabled,
Rules: make([]*PolicyRule, len(p.Rules)),
} }
for _, r := range p.Rules { for i, r := range p.Rules {
c.Rules = append(c.Rules, r.Copy()) c.Rules[i] = r.Copy()
} }
return c return c
} }

View File

@ -45,15 +45,18 @@ type Rule struct {
} }
func (r *Rule) Copy() *Rule { func (r *Rule) Copy() *Rule {
return &Rule{ rule := &Rule{
ID: r.ID, ID: r.ID,
Name: r.Name, Name: r.Name,
Description: r.Description, Description: r.Description,
Disabled: r.Disabled, Disabled: r.Disabled,
Source: r.Source[:], Source: make([]string, len(r.Source)),
Destination: r.Destination[:], Destination: make([]string, len(r.Destination)),
Flow: r.Flow, Flow: r.Flow,
} }
copy(rule.Source, r.Source)
copy(rule.Destination, r.Destination)
return rule
} }
// EventMeta returns activity event meta related to this rule // EventMeta returns activity event meta related to this rule

View File

@ -90,8 +90,8 @@ type SetupKey struct {
// Copy copies SetupKey to a new object // Copy copies SetupKey to a new object
func (key *SetupKey) Copy() *SetupKey { func (key *SetupKey) Copy() *SetupKey {
autoGroups := make([]string, 0) autoGroups := make([]string, len(key.AutoGroups))
autoGroups = append(autoGroups, key.AutoGroups...) copy(autoGroups, key.AutoGroups)
if key.UpdatedAt.IsZero() { if key.UpdatedAt.IsZero() {
key.UpdatedAt = key.CreatedAt key.UpdatedAt = key.CreatedAt
} }

View File

@ -120,9 +120,7 @@ func (u *User) Copy() *User {
copy(autoGroups, u.AutoGroups) copy(autoGroups, u.AutoGroups)
pats := make(map[string]*PersonalAccessToken, len(u.PATs)) pats := make(map[string]*PersonalAccessToken, len(u.PATs))
for k, v := range u.PATs { for k, v := range u.PATs {
patCopy := new(PersonalAccessToken) pats[k] = v.Copy()
*patCopy = *v
pats[k] = patCopy
} }
return &User{ return &User{
Id: u.Id, Id: u.Id,

View File

@ -1,8 +1,9 @@
package route package route
import ( import (
"github.com/netbirdio/netbird/management/server/status"
"net/netip" "net/netip"
"github.com/netbirdio/netbird/management/server/status"
) )
// Windows has some limitation regarding metric size that differ from Unix-like systems. // Windows has some limitation regarding metric size that differ from Unix-like systems.
@ -83,7 +84,7 @@ func (r *Route) EventMeta() map[string]any {
// Copy copies a route object // Copy copies a route object
func (r *Route) Copy() *Route { func (r *Route) Copy() *Route {
return &Route{ route := &Route{
ID: r.ID, ID: r.ID,
Description: r.Description, Description: r.Description,
NetID: r.NetID, NetID: r.NetID,
@ -93,8 +94,10 @@ func (r *Route) Copy() *Route {
Metric: r.Metric, Metric: r.Metric,
Masquerade: r.Masquerade, Masquerade: r.Masquerade,
Enabled: r.Enabled, Enabled: r.Enabled,
Groups: r.Groups, Groups: make([]string, len(r.Groups)),
} }
copy(route.Groups, r.Groups)
return route
} }
// IsEqual compares one route with the other // IsEqual compares one route with the other