mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-21 23:53:14 +01:00
Improve account copying (#1069)
With this fix, all nested slices and pointers will be copied by value. Also, this fixes tests to compare the original and copy account by their values by marshaling them to JSON strings. Before that, they were copying the pointers that also passed the simple `=` compassion (as the addresses match).
This commit is contained in:
parent
892db25021
commit
e586eca16c
@ -130,16 +130,22 @@ func ParseNameServerURL(nsURL string) (NameServer, error) {
|
||||
|
||||
// Copy copies a nameserver group object
|
||||
func (g *NameServerGroup) Copy() *NameServerGroup {
|
||||
return &NameServerGroup{
|
||||
nsGroup := &NameServerGroup{
|
||||
ID: g.ID,
|
||||
Name: g.Name,
|
||||
Description: g.Description,
|
||||
NameServers: g.NameServers,
|
||||
Groups: g.Groups,
|
||||
NameServers: make([]NameServer, len(g.NameServers)),
|
||||
Groups: make([]string, len(g.Groups)),
|
||||
Enabled: g.Enabled,
|
||||
Primary: g.Primary,
|
||||
Domains: g.Domains,
|
||||
Domains: make([]string, len(g.Domains)),
|
||||
}
|
||||
|
||||
copy(nsGroup.NameServers, g.NameServers)
|
||||
copy(nsGroup.Groups, g.Groups)
|
||||
copy(nsGroup.Domains, g.Domains)
|
||||
|
||||
return nsGroup
|
||||
}
|
||||
|
||||
// IsEqual compares one nameserver group with the other
|
||||
|
@ -3,6 +3,7 @@ package server
|
||||
import (
|
||||
"crypto/sha256"
|
||||
b64 "encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"reflect"
|
||||
@ -1348,6 +1349,11 @@ func TestAccount_Copy(t *testing.T) {
|
||||
Peers: map[string]*Peer{
|
||||
"peer1": {
|
||||
Key: "key1",
|
||||
Status: &PeerStatus{
|
||||
LastSeen: time.Now(),
|
||||
Connected: true,
|
||||
LoginExpired: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
Users: map[string]*User{
|
||||
@ -1371,27 +1377,35 @@ func TestAccount_Copy(t *testing.T) {
|
||||
Groups: map[string]*Group{
|
||||
"group1": {
|
||||
ID: "group1",
|
||||
Peers: []string{"peer1"},
|
||||
},
|
||||
},
|
||||
Rules: map[string]*Rule{
|
||||
"rule1": {
|
||||
ID: "rule1",
|
||||
Destination: []string{},
|
||||
Source: []string{},
|
||||
},
|
||||
},
|
||||
Policies: []*Policy{
|
||||
{
|
||||
ID: "policy1",
|
||||
Enabled: true,
|
||||
Rules: make([]*PolicyRule, 0),
|
||||
},
|
||||
},
|
||||
Routes: map[string]*route.Route{
|
||||
"route1": {
|
||||
ID: "route1",
|
||||
Groups: []string{"group1"},
|
||||
},
|
||||
},
|
||||
NameServerGroups: map[string]*nbdns.NameServerGroup{
|
||||
"nsGroup1": {
|
||||
ID: "nsGroup1",
|
||||
Domains: []string{},
|
||||
Groups: []string{},
|
||||
NameServers: []nbdns.NameServer{},
|
||||
},
|
||||
},
|
||||
DNSSettings: &DNSSettings{DisabledManagementGroups: []string{}},
|
||||
@ -1402,10 +1416,20 @@ func TestAccount_Copy(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
accountCopy := account.Copy()
|
||||
assert.Equal(t, account, accountCopy, "account copy returned a different value than expected")
|
||||
accBytes, err := json.Marshal(account)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
account.Peers["peer1"].Status.Connected = false // we change original object to confirm that copy wont change
|
||||
accCopyBytes, err := json.Marshal(accountCopy)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.Equal(t, string(accBytes), string(accCopyBytes), "account copy returned a different value than expected")
|
||||
}
|
||||
|
||||
// hasNilField validates pointers, maps and slices if they are nil
|
||||
// TODO: make it check nested fields too
|
||||
func hasNilField(x interface{}) error {
|
||||
rv := reflect.ValueOf(x)
|
||||
rv = rv.Elem()
|
||||
|
@ -59,12 +59,14 @@ func (g *Group) EventMeta() map[string]any {
|
||||
}
|
||||
|
||||
func (g *Group) Copy() *Group {
|
||||
return &Group{
|
||||
group := &Group{
|
||||
ID: g.ID,
|
||||
Name: g.Name,
|
||||
Issued: g.Issued,
|
||||
Peers: g.Peers[:],
|
||||
Peers: make([]string, len(g.Peers)),
|
||||
}
|
||||
copy(group.Peers, g.Peers)
|
||||
return group
|
||||
}
|
||||
|
||||
// GetGroup object of the peers
|
||||
|
@ -54,6 +54,7 @@ var baseExistingNSGroup = &nbdns.NameServerGroup{
|
||||
},
|
||||
},
|
||||
Groups: []string{"testing"},
|
||||
Domains: []string{"domain"},
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
|
@ -108,6 +108,10 @@ func (p *Peer) AddedWithSSOLogin() bool {
|
||||
|
||||
// Copy copies Peer object
|
||||
func (p *Peer) Copy() *Peer {
|
||||
peerStatus := p.Status
|
||||
if peerStatus != nil {
|
||||
peerStatus = p.Status.Copy()
|
||||
}
|
||||
return &Peer{
|
||||
ID: p.ID,
|
||||
Key: p.Key,
|
||||
@ -115,11 +119,11 @@ func (p *Peer) Copy() *Peer {
|
||||
IP: p.IP,
|
||||
Meta: p.Meta,
|
||||
Name: p.Name,
|
||||
Status: p.Status,
|
||||
DNSLabel: p.DNSLabel,
|
||||
Status: peerStatus,
|
||||
UserID: p.UserID,
|
||||
SSHKey: p.SSHKey,
|
||||
SSHEnabled: p.SSHEnabled,
|
||||
DNSLabel: p.DNSLabel,
|
||||
LoginExpirationEnabled: p.LoginExpirationEnabled,
|
||||
LastLogin: p.LastLogin,
|
||||
}
|
||||
|
@ -36,6 +36,18 @@ type PersonalAccessToken struct {
|
||||
LastUsed time.Time
|
||||
}
|
||||
|
||||
func (t *PersonalAccessToken) Copy() *PersonalAccessToken {
|
||||
return &PersonalAccessToken{
|
||||
ID: t.ID,
|
||||
Name: t.Name,
|
||||
HashedToken: t.HashedToken,
|
||||
ExpirationDate: t.ExpirationDate,
|
||||
CreatedBy: t.CreatedBy,
|
||||
CreatedAt: t.CreatedAt,
|
||||
LastUsed: t.LastUsed,
|
||||
}
|
||||
}
|
||||
|
||||
// PersonalAccessTokenGenerated holds the new PersonalAccessToken and the plain text version of it
|
||||
type PersonalAccessTokenGenerated struct {
|
||||
PlainToken string
|
||||
|
@ -95,18 +95,22 @@ type PolicyRule struct {
|
||||
|
||||
// Copy returns a copy of a policy rule
|
||||
func (pm *PolicyRule) Copy() *PolicyRule {
|
||||
return &PolicyRule{
|
||||
rule := &PolicyRule{
|
||||
ID: pm.ID,
|
||||
Name: pm.Name,
|
||||
Description: pm.Description,
|
||||
Enabled: pm.Enabled,
|
||||
Action: pm.Action,
|
||||
Destinations: pm.Destinations[:],
|
||||
Sources: pm.Sources[:],
|
||||
Destinations: make([]string, len(pm.Destinations)),
|
||||
Sources: make([]string, len(pm.Sources)),
|
||||
Bidirectional: pm.Bidirectional,
|
||||
Protocol: pm.Protocol,
|
||||
Ports: pm.Ports[:],
|
||||
Ports: make([]string, len(pm.Ports)),
|
||||
}
|
||||
copy(rule.Destinations, pm.Destinations)
|
||||
copy(rule.Sources, pm.Sources)
|
||||
copy(rule.Ports, pm.Ports)
|
||||
return rule
|
||||
}
|
||||
|
||||
// ToRule converts the PolicyRule to a legacy representation of the Rule (for backwards compatibility)
|
||||
@ -147,9 +151,10 @@ func (p *Policy) Copy() *Policy {
|
||||
Name: p.Name,
|
||||
Description: p.Description,
|
||||
Enabled: p.Enabled,
|
||||
Rules: make([]*PolicyRule, len(p.Rules)),
|
||||
}
|
||||
for _, r := range p.Rules {
|
||||
c.Rules = append(c.Rules, r.Copy())
|
||||
for i, r := range p.Rules {
|
||||
c.Rules[i] = r.Copy()
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
@ -45,15 +45,18 @@ type Rule struct {
|
||||
}
|
||||
|
||||
func (r *Rule) Copy() *Rule {
|
||||
return &Rule{
|
||||
rule := &Rule{
|
||||
ID: r.ID,
|
||||
Name: r.Name,
|
||||
Description: r.Description,
|
||||
Disabled: r.Disabled,
|
||||
Source: r.Source[:],
|
||||
Destination: r.Destination[:],
|
||||
Source: make([]string, len(r.Source)),
|
||||
Destination: make([]string, len(r.Destination)),
|
||||
Flow: r.Flow,
|
||||
}
|
||||
copy(rule.Source, r.Source)
|
||||
copy(rule.Destination, r.Destination)
|
||||
return rule
|
||||
}
|
||||
|
||||
// EventMeta returns activity event meta related to this rule
|
||||
|
@ -90,8 +90,8 @@ type SetupKey struct {
|
||||
|
||||
// Copy copies SetupKey to a new object
|
||||
func (key *SetupKey) Copy() *SetupKey {
|
||||
autoGroups := make([]string, 0)
|
||||
autoGroups = append(autoGroups, key.AutoGroups...)
|
||||
autoGroups := make([]string, len(key.AutoGroups))
|
||||
copy(autoGroups, key.AutoGroups)
|
||||
if key.UpdatedAt.IsZero() {
|
||||
key.UpdatedAt = key.CreatedAt
|
||||
}
|
||||
|
@ -120,9 +120,7 @@ func (u *User) Copy() *User {
|
||||
copy(autoGroups, u.AutoGroups)
|
||||
pats := make(map[string]*PersonalAccessToken, len(u.PATs))
|
||||
for k, v := range u.PATs {
|
||||
patCopy := new(PersonalAccessToken)
|
||||
*patCopy = *v
|
||||
pats[k] = patCopy
|
||||
pats[k] = v.Copy()
|
||||
}
|
||||
return &User{
|
||||
Id: u.Id,
|
||||
|
@ -1,8 +1,9 @@
|
||||
package route
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"net/netip"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
)
|
||||
|
||||
// Windows has some limitation regarding metric size that differ from Unix-like systems.
|
||||
@ -83,7 +84,7 @@ func (r *Route) EventMeta() map[string]any {
|
||||
|
||||
// Copy copies a route object
|
||||
func (r *Route) Copy() *Route {
|
||||
return &Route{
|
||||
route := &Route{
|
||||
ID: r.ID,
|
||||
Description: r.Description,
|
||||
NetID: r.NetID,
|
||||
@ -93,8 +94,10 @@ func (r *Route) Copy() *Route {
|
||||
Metric: r.Metric,
|
||||
Masquerade: r.Masquerade,
|
||||
Enabled: r.Enabled,
|
||||
Groups: r.Groups,
|
||||
Groups: make([]string, len(r.Groups)),
|
||||
}
|
||||
copy(route.Groups, r.Groups)
|
||||
return route
|
||||
}
|
||||
|
||||
// IsEqual compares one route with the other
|
||||
|
Loading…
Reference in New Issue
Block a user