From ce7de03d6eccd6769b50dd6ac6efbddbeaa098ec Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Wed, 9 Oct 2024 23:49:41 +0300 Subject: [PATCH] use generic differ for netip.Addr and netip.Prefix Signed-off-by: bcmmbaga --- management/server/differs/nameserver.go | 74 ------- management/server/differs/netip.go | 82 ++++++++ management/server/differs/route.go | 82 -------- management/server/updatechannel.go | 14 +- management/server/updatechannel_test.go | 258 +++++++++++++----------- 5 files changed, 235 insertions(+), 275 deletions(-) delete mode 100644 management/server/differs/nameserver.go create mode 100644 management/server/differs/netip.go delete mode 100644 management/server/differs/route.go diff --git a/management/server/differs/nameserver.go b/management/server/differs/nameserver.go deleted file mode 100644 index fdae9830e..000000000 --- a/management/server/differs/nameserver.go +++ /dev/null @@ -1,74 +0,0 @@ -package differs - -import ( - "fmt" - "reflect" - - nbdns "github.com/netbirdio/netbird/dns" - "github.com/r3labs/diff" -) - -type NameServerComparator struct{} - -func NewNameServerComparator() *NameServerComparator { - return &NameServerComparator{} -} - -func (d *NameServerComparator) Match(a, b reflect.Value) bool { - return diff.AreType(a, b, reflect.TypeOf(nbdns.NameServer{})) || - diff.AreType(a, b, reflect.TypeOf([]nbdns.NameServer{})) -} - -func (d *NameServerComparator) Diff(cl *diff.Changelog, path []string, a, b reflect.Value) error { - if err := handleInvalidKind(cl, path, a, b); err != nil { - return err - } - - if a.Kind() == reflect.Slice && b.Kind() == reflect.Slice { - return handleSliceKind(d, cl, path, a, b) - } - - ns1, ok1 := a.Interface().(nbdns.NameServer) - ns2, ok2 := b.Interface().(nbdns.NameServer) - if !ok1 || !ok2 { - return fmt.Errorf("invalid type for NameServer") - } - - if ns1.IP.String() != ns2.IP.String() { - cl.Add(diff.UPDATE, append(path, "IP"), ns1.IP.String(), ns2.IP.String()) - } - if ns1.NSType != ns2.NSType { - cl.Add(diff.UPDATE, append(path, "NSType"), ns1.NSType, ns2.NSType) - } - if ns1.Port != ns2.Port { - cl.Add(diff.UPDATE, append(path, "Port"), ns1.Port, ns2.Port) - } - - return nil -} - -func handleInvalidKind(cl *diff.Changelog, path []string, a, b reflect.Value) error { - if a.Kind() == reflect.Invalid { - cl.Add(diff.CREATE, path, nil, b.Interface()) - return fmt.Errorf("invalid kind") - } - if b.Kind() == reflect.Invalid { - cl.Add(diff.DELETE, path, a.Interface(), nil) - return fmt.Errorf("invalid kind") - } - return nil -} - -func handleSliceKind(comparator diff.ValueDiffer, cl *diff.Changelog, path []string, a, b reflect.Value) error { - if a.Len() != b.Len() { - cl.Add(diff.UPDATE, append(path, "length"), a.Len(), b.Len()) - return nil - } - - for i := 0; i < min(a.Len(), b.Len()); i++ { - if err := comparator.Diff(cl, append(path, fmt.Sprintf("[%d]", i)), a.Index(i), b.Index(i)); err != nil { - return err - } - } - return nil -} diff --git a/management/server/differs/netip.go b/management/server/differs/netip.go new file mode 100644 index 000000000..de4aa334c --- /dev/null +++ b/management/server/differs/netip.go @@ -0,0 +1,82 @@ +package differs + +import ( + "fmt" + "net/netip" + "reflect" + + "github.com/r3labs/diff/v3" +) + +// NetIPAddr is a custom differ for netip.Addr +type NetIPAddr struct { + DiffFunc func(path []string, a, b reflect.Value, p interface{}) error +} + +func (differ NetIPAddr) Match(a, b reflect.Value) bool { + return diff.AreType(a, b, reflect.TypeOf(netip.Addr{})) +} + +func (differ NetIPAddr) Diff(_ diff.DiffType, _ diff.DiffFunc, cl *diff.Changelog, path []string, a, b reflect.Value, _ interface{}) error { + if a.Kind() == reflect.Invalid { + cl.Add(diff.CREATE, path, nil, b.Interface()) + return nil + } + + if b.Kind() == reflect.Invalid { + cl.Add(diff.DELETE, path, a.Interface(), nil) + return nil + } + + fromAddr, ok1 := a.Interface().(netip.Addr) + toAddr, ok2 := b.Interface().(netip.Addr) + if !ok1 || !ok2 { + return fmt.Errorf("invalid type for netip.Addr") + } + + if fromAddr.String() != toAddr.String() { + cl.Add(diff.UPDATE, path, fromAddr.String(), toAddr.String()) + } + + return nil +} + +func (differ NetIPAddr) InsertParentDiffer(dfunc func(path []string, a, b reflect.Value, p interface{}) error) { + differ.DiffFunc = dfunc //nolint +} + +// NetIPPrefix is a custom differ for netip.Prefix +type NetIPPrefix struct { + DiffFunc func(path []string, a, b reflect.Value, p interface{}) error +} + +func (differ NetIPPrefix) Match(a, b reflect.Value) bool { + return diff.AreType(a, b, reflect.TypeOf(netip.Prefix{})) +} + +func (differ NetIPPrefix) Diff(_ diff.DiffType, _ diff.DiffFunc, cl *diff.Changelog, path []string, a, b reflect.Value, _ interface{}) error { + if a.Kind() == reflect.Invalid { + cl.Add(diff.CREATE, path, nil, b.Interface()) + return nil + } + if b.Kind() == reflect.Invalid { + cl.Add(diff.DELETE, path, a.Interface(), nil) + return nil + } + + fromPrefix, ok1 := a.Interface().(netip.Prefix) + toPrefix, ok2 := b.Interface().(netip.Prefix) + if !ok1 || !ok2 { + return fmt.Errorf("invalid type for netip.Addr") + } + + if fromPrefix.String() != toPrefix.String() { + cl.Add(diff.UPDATE, path, fromPrefix.String(), toPrefix.String()) + } + + return nil +} + +func (differ NetIPPrefix) InsertParentDiffer(dfunc func(path []string, a, b reflect.Value, p interface{}) error) { + differ.DiffFunc = dfunc //nolint +} diff --git a/management/server/differs/route.go b/management/server/differs/route.go deleted file mode 100644 index 95eb08882..000000000 --- a/management/server/differs/route.go +++ /dev/null @@ -1,82 +0,0 @@ -package differs - -import ( - "fmt" - "reflect" - "slices" - - nbroute "github.com/netbirdio/netbird/route" - "github.com/r3labs/diff" -) - -type RouteComparator struct{} - -func NewRouteComparator() *RouteComparator { - return &RouteComparator{} -} - -func (d *RouteComparator) Match(a, b reflect.Value) bool { - return diff.AreType(a, b, reflect.TypeOf(&nbroute.Route{})) || - diff.AreType(a, b, reflect.TypeOf([]*nbroute.Route{})) -} - -func (d *RouteComparator) Diff(cl *diff.Changelog, path []string, a, b reflect.Value) error { - if err := handleInvalidKind(cl, path, a, b); err != nil { - return err - } - - if a.Kind() == reflect.Slice && b.Kind() == reflect.Slice { - return handleSliceKind(d, cl, path, a, b) - } - - route1, ok1 := a.Interface().(*nbroute.Route) - route2, ok2 := b.Interface().(*nbroute.Route) - if !ok1 || !ok2 { - return fmt.Errorf("invalid type for Route") - } - - if route1.ID != route2.ID { - cl.Add(diff.UPDATE, append(path, "ID"), route1.ID, route2.ID) - } - if route1.AccountID != route2.AccountID { - cl.Add(diff.UPDATE, append(path, "AccountID"), route1.AccountID, route2.AccountID) - } - if route1.Network.String() != route2.Network.String() { - cl.Add(diff.UPDATE, append(path, "Network"), route1.Network.String(), route2.Network.String()) - } - if !slices.Equal(route1.Domains, route2.Domains) { - cl.Add(diff.UPDATE, append(path, "Domains"), route1.Domains, route2.Domains) - } - if route1.KeepRoute != route2.KeepRoute { - cl.Add(diff.UPDATE, append(path, "KeepRoute"), route1.KeepRoute, route2.KeepRoute) - } - if route1.NetID != route2.NetID { - cl.Add(diff.UPDATE, append(path, "NetID"), route1.NetID, route2.NetID) - } - if route1.Description != route2.Description { - cl.Add(diff.UPDATE, append(path, "Description"), route1.Description, route2.Description) - } - if route1.Peer != route2.Peer { - cl.Add(diff.UPDATE, append(path, "Peer"), route1.Peer, route2.Peer) - } - if !slices.Equal(route1.PeerGroups, route2.PeerGroups) { - cl.Add(diff.UPDATE, append(path, "PeerGroups"), route1.PeerGroups, route2.PeerGroups) - } - if route1.NetworkType != route2.NetworkType { - cl.Add(diff.UPDATE, append(path, "NetworkType"), route1.NetworkType, route2.NetworkType) - } - if route1.Masquerade != route2.Masquerade { - cl.Add(diff.UPDATE, append(path, "Masquerade"), route1.Masquerade, route2.Masquerade) - } - if route1.Metric != route2.Metric { - cl.Add(diff.UPDATE, append(path, "Metric"), route1.Metric, route2.Metric) - } - if route1.Enabled != route2.Enabled { - cl.Add(diff.UPDATE, append(path, "Enabled"), route1.Enabled, route2.Enabled) - } - if !slices.Equal(route1.Groups, route2.Groups) { - cl.Add(diff.UPDATE, append(path, "Groups"), route1.Groups, route2.Groups) - } - - return nil -} diff --git a/management/server/updatechannel.go b/management/server/updatechannel.go index 29abfa865..c34ee977b 100644 --- a/management/server/updatechannel.go +++ b/management/server/updatechannel.go @@ -6,6 +6,7 @@ import ( "sync" "time" + "github.com/netbirdio/netbird/management/server/differs" "github.com/netbirdio/netbird/management/server/posture" "github.com/r3labs/diff/v3" log "github.com/sirupsen/logrus" @@ -228,7 +229,16 @@ func isNewPeerUpdateMessage(lastSentUpdate, currUpdateToSend *UpdateMessage) (bo return false, nil } - changelog, err := diff.Diff(lastSentUpdate.Checks, currUpdateToSend.Checks) + differ, err := diff.NewDiffer( + diff.DisableStructValues(), + diff.CustomValueDiffers(&differs.NetIPAddr{}), + diff.CustomValueDiffers(&differs.NetIPPrefix{}), + ) + if err != nil { + return false, fmt.Errorf("failed to create differ: %v", err) + } + + changelog, err := differ.Diff(lastSentUpdate.Checks, currUpdateToSend.Checks) if err != nil { return false, fmt.Errorf("failed to diff checks: %v", err) } @@ -236,7 +246,7 @@ func isNewPeerUpdateMessage(lastSentUpdate, currUpdateToSend *UpdateMessage) (bo return true, nil } - changelog, err = diff.Diff(lastSentUpdate.NetworkMap, currUpdateToSend.NetworkMap) + changelog, err = differ.Diff(lastSentUpdate.NetworkMap, currUpdateToSend.NetworkMap) if err != nil { return false, fmt.Errorf("failed to diff network map: %v", err) } diff --git a/management/server/updatechannel_test.go b/management/server/updatechannel_test.go index 461b9f9f8..1ff781dda 100644 --- a/management/server/updatechannel_test.go +++ b/management/server/updatechannel_test.go @@ -8,9 +8,12 @@ import ( "time" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/proto" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" + nbroute "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/util" "github.com/stretchr/testify/assert" ) @@ -368,121 +371,142 @@ func TestIsNewPeerUpdateMessage(t *testing.T) { func createMockUpdateMessage(t *testing.T) *UpdateMessage { t.Helper() - //_, ipNet, err := net.ParseCIDR("192.168.1.0/24") - //if err != nil { - // t.Fatal(err) - //} - //domainList, err := domain.FromStringList([]string{"example.com"}) - //if err != nil { - // t.Fatal(err) - //} - // - //config := &Config{ - // Signal: &Host{ - // Proto: "https", - // URI: "signal.uri", - // Username: "", - // Password: "", - // }, - // Stuns: []*Host{{URI: "stun.uri", Proto: UDP}}, - // TURNConfig: &TURNConfig{ - // Turns: []*Host{{URI: "turn.uri", Proto: UDP, Username: "turn-user", Password: "turn-pass"}}, - // }, - //} - //peer := &nbpeer.Peer{ - // IP: net.ParseIP("192.168.1.1"), - // SSHEnabled: true, - // Key: "peer-key", - // DNSLabel: "peer1", - // SSHKey: "peer1-ssh-key", - //} - // - ////NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *TURNConfig, relayCfg *Relay) - ////turnCredentials := &TURNCredentials{ - //// Username: "turn-user", - //// Password: "turn-pass", - ////} - // - //networkMap := &NetworkMap{ - // Network: &Network{Net: *ipNet, Serial: 1000}, - // Peers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.2"), Key: "peer2-key", DNSLabel: "peer2", SSHEnabled: true, SSHKey: "peer2-ssh-key"}}, - // OfflinePeers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.3"), Key: "peer3-key", DNSLabel: "peer3", SSHEnabled: true, SSHKey: "peer3-ssh-key"}}, - // Routes: []*nbroute.Route{ - // { - // ID: "route1", - // Network: netip.MustParsePrefix("10.0.0.0/24"), - // KeepRoute: true, - // NetID: "route1", - // Peer: "peer1", - // NetworkType: 1, - // Masquerade: true, - // Metric: 9999, - // Enabled: true, - // Groups: []string{"test1", "test2"}, - // }, - // { - // ID: "route2", - // Domains: domainList, - // KeepRoute: true, - // NetID: "route2", - // Peer: "peer1", - // NetworkType: 1, - // Masquerade: true, - // Metric: 9999, - // Enabled: true, - // Groups: []string{"test1", "test2"}, - // }, - // }, - // DNSConfig: nbdns.Config{ - // ServiceEnable: true, - // NameServerGroups: []*nbdns.NameServerGroup{ - // { - // NameServers: []nbdns.NameServer{{ - // IP: netip.MustParseAddr("8.8.8.8"), - // NSType: nbdns.UDPNameServerType, - // Port: nbdns.DefaultDNSPort, - // }}, - // Primary: true, - // Domains: []string{"example.com"}, - // Enabled: true, - // SearchDomainsEnabled: true, - // }, - // { - // ID: "ns1", - // NameServers: []nbdns.NameServer{{ - // IP: netip.MustParseAddr("1.1.1.1"), - // NSType: nbdns.UDPNameServerType, - // Port: nbdns.DefaultDNSPort, - // }}, - // Groups: []string{"group1"}, - // Primary: true, - // Domains: []string{"example.com"}, - // Enabled: true, - // SearchDomainsEnabled: true, - // }, - // }, - // CustomZones: []nbdns.CustomZone{{Domain: "example.com", Records: []nbdns.SimpleRecord{{Name: "example.com", Type: 1, Class: "IN", TTL: 60, RData: "100.64.0.1"}}}}, - // }, - // FirewallRules: []*FirewallRule{ - // {PeerIP: "192.168.1.2", Direction: firewallRuleDirectionIN, Action: string(PolicyTrafficActionAccept), Protocol: string(PolicyRuleProtocolTCP), Port: "80"}, - // }, - //} - //dnsName := "example.com" - //checks := []*posture.Checks{ - // { - // Checks: posture.ChecksDefinition{ - // ProcessCheck: &posture.ProcessCheck{ - // Processes: []posture.Process{{LinuxPath: "/usr/bin/netbird"}}, - // }, - // }, - // }, - //} - //dnsCache := &DNSConfigCache{} - // - //return &UpdateMessage{ - // //Update: toSyncResponse(context.Background(), config, peer, turnCredentials, networkMap, dnsName, checks, dnsCache), - // NetworkMap: networkMap, - // Checks: checks, - //} - return nil + + _, ipNet, err := net.ParseCIDR("192.168.1.0/24") + if err != nil { + t.Fatal(err) + } + domainList, err := domain.FromStringList([]string{"example.com"}) + if err != nil { + t.Fatal(err) + } + + config := &Config{ + Signal: &Host{ + Proto: "https", + URI: "signal.uri", + Username: "", + Password: "", + }, + Stuns: []*Host{{URI: "stun.uri", Proto: UDP}}, + TURNConfig: &TURNConfig{ + Turns: []*Host{{URI: "turn.uri", Proto: UDP, Username: "turn-user", Password: "turn-pass"}}, + }, + } + peer := &nbpeer.Peer{ + IP: net.ParseIP("192.168.1.1"), + SSHEnabled: true, + Key: "peer-key", + DNSLabel: "peer1", + SSHKey: "peer1-ssh-key", + } + + secretManager := NewTimeBasedAuthSecretsManager( + NewPeersUpdateManager(nil), + &TURNConfig{ + TimeBasedCredentials: false, + CredentialsTTL: util.Duration{ + Duration: defaultDuration, + }, + Secret: "secret", + Turns: []*Host{TurnTestHost}, + }, + &Relay{ + Addresses: []string{"localhost:0"}, + CredentialsTTL: util.Duration{Duration: time.Hour}, + Secret: "secret", + }, + ) + + networkMap := &NetworkMap{ + Network: &Network{Net: *ipNet, Serial: 1000}, + Peers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.2"), Key: "peer2-key", DNSLabel: "peer2", SSHEnabled: true, SSHKey: "peer2-ssh-key"}}, + OfflinePeers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.3"), Key: "peer3-key", DNSLabel: "peer3", SSHEnabled: true, SSHKey: "peer3-ssh-key"}}, + Routes: []*nbroute.Route{ + { + ID: "route1", + Network: netip.MustParsePrefix("10.0.0.0/24"), + KeepRoute: true, + NetID: "route1", + Peer: "peer1", + NetworkType: 1, + Masquerade: true, + Metric: 9999, + Enabled: true, + Groups: []string{"test1", "test2"}, + }, + { + ID: "route2", + Domains: domainList, + KeepRoute: true, + NetID: "route2", + Peer: "peer1", + NetworkType: 1, + Masquerade: true, + Metric: 9999, + Enabled: true, + Groups: []string{"test1", "test2"}, + }, + }, + DNSConfig: nbdns.Config{ + ServiceEnable: true, + NameServerGroups: []*nbdns.NameServerGroup{ + { + NameServers: []nbdns.NameServer{{ + IP: netip.MustParseAddr("8.8.8.8"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }}, + Primary: true, + Domains: []string{"example.com"}, + Enabled: true, + SearchDomainsEnabled: true, + }, + { + ID: "ns1", + NameServers: []nbdns.NameServer{{ + IP: netip.MustParseAddr("1.1.1.1"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }}, + Groups: []string{"group1"}, + Primary: true, + Domains: []string{"example.com"}, + Enabled: true, + SearchDomainsEnabled: true, + }, + }, + CustomZones: []nbdns.CustomZone{{Domain: "example.com", Records: []nbdns.SimpleRecord{{Name: "example.com", Type: 1, Class: "IN", TTL: 60, RData: "100.64.0.1"}}}}, + }, + FirewallRules: []*FirewallRule{ + {PeerIP: "192.168.1.2", Direction: firewallRuleDirectionIN, Action: string(PolicyTrafficActionAccept), Protocol: string(PolicyRuleProtocolTCP), Port: "80"}, + }, + } + dnsName := "example.com" + checks := []*posture.Checks{ + { + Checks: posture.ChecksDefinition{ + ProcessCheck: &posture.ProcessCheck{ + Processes: []posture.Process{{LinuxPath: "/usr/bin/netbird"}}, + }, + }, + }, + } + dnsCache := &DNSConfigCache{} + + turnToken, err := secretManager.GenerateTurnToken() + if err != nil { + t.Fatal(err) + } + + relayToken, err := secretManager.GenerateRelayToken() + if err != nil { + t.Fatal(err) + } + + return &UpdateMessage{ + Update: toSyncResponse(context.Background(), config, peer, turnToken, relayToken, networkMap, dnsName, checks, dnsCache), + NetworkMap: networkMap, + Checks: checks, + } }