Merge branch 'feature/validate-group-association' into feature/validate-group-association-debug

This commit is contained in:
bcmmbaga 2024-10-22 17:47:46 +03:00
commit 2f15708d54
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547
6 changed files with 219 additions and 160 deletions

View File

@ -49,7 +49,7 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Test - name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./... run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./...
test_client_on_docker: test_client_on_docker:
runs-on: ubuntu-20.04 runs-on: ubuntu-20.04

View File

@ -312,7 +312,6 @@ func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Accou
}, },
}, },
NetworkMap: &NetworkMap{}, NetworkMap: &NetworkMap{},
Checks: []*posture.Checks{},
}) })
am.peersUpdateManager.CloseChannel(ctx, peer.ID) am.peersUpdateManager.CloseChannel(ctx, peer.ID)
am.StoreEvent(ctx, userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain())) am.StoreEvent(ctx, userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain()))
@ -1008,7 +1007,7 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
postureChecks := am.getPeerPostureChecks(account, p) postureChecks := am.getPeerPostureChecks(account, p)
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache) update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache)
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap, Checks: postureChecks}) am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
}(peer) }(peer)
} }

View File

@ -3,11 +3,11 @@ package server
import ( import (
"context" "context"
"fmt" "fmt"
"runtime/debug"
"sync" "sync"
"time" "time"
"github.com/netbirdio/netbird/management/server/differs" "github.com/netbirdio/netbird/management/server/differs"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/r3labs/diff/v3" "github.com/r3labs/diff/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -20,7 +20,6 @@ const channelBufferSize = 100
type UpdateMessage struct { type UpdateMessage struct {
Update *proto.SyncResponse Update *proto.SyncResponse
NetworkMap *NetworkMap NetworkMap *NetworkMap
Checks []*posture.Checks
} }
type PeersUpdateManager struct { type PeersUpdateManager struct {
@ -209,7 +208,7 @@ func (p *PeersUpdateManager) handlePeerMessageUpdate(ctx context.Context, peerID
p.channelsMux.RUnlock() p.channelsMux.RUnlock()
if lastSentUpdate != nil { if lastSentUpdate != nil {
updated, err := isNewPeerUpdateMessage(lastSentUpdate, update) updated, err := isNewPeerUpdateMessage(ctx, lastSentUpdate, update)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error checking for SyncResponse updates: %v", err) log.WithContext(ctx).Errorf("error checking for SyncResponse updates: %v", err)
return false return false
@ -224,7 +223,14 @@ func (p *PeersUpdateManager) handlePeerMessageUpdate(ctx context.Context, peerID
} }
// isNewPeerUpdateMessage checks if the given current update message is a new update that should be sent. // isNewPeerUpdateMessage checks if the given current update message is a new update that should be sent.
func isNewPeerUpdateMessage(lastSentUpdate, currUpdateToSend *UpdateMessage) (bool, error) { func isNewPeerUpdateMessage(ctx context.Context, lastSentUpdate, currUpdateToSend *UpdateMessage) (isNew bool, err error) {
defer func() {
if r := recover(); r != nil {
log.WithContext(ctx).Panicf("comparing peer update messages. Trace: %s", debug.Stack())
}
isNew, err = true, nil
}()
if lastSentUpdate.Update.NetworkMap.GetSerial() > currUpdateToSend.Update.NetworkMap.GetSerial() { if lastSentUpdate.Update.NetworkMap.GetSerial() > currUpdateToSend.Update.NetworkMap.GetSerial() {
log.Tracef("new network map serial: %d not greater than last sent: %d, skip sending update", log.Tracef("new network map serial: %d not greater than last sent: %d, skip sending update",
lastSentUpdate.Update.NetworkMap.GetSerial(), lastSentUpdate.Update.NetworkMap.GetSerial(),
@ -242,7 +248,10 @@ func isNewPeerUpdateMessage(lastSentUpdate, currUpdateToSend *UpdateMessage) (bo
return false, fmt.Errorf("failed to create differ: %v", err) return false, fmt.Errorf("failed to create differ: %v", err)
} }
changelog, err := differ.Diff(lastSentUpdate.Checks, currUpdateToSend.Checks) lastSentFiles := getChecksFiles(lastSentUpdate.Update.Checks)
currFiles := getChecksFiles(currUpdateToSend.Update.Checks)
changelog, err := differ.Diff(lastSentFiles, currFiles)
if err != nil { if err != nil {
return false, fmt.Errorf("failed to diff checks: %v", err) return false, fmt.Errorf("failed to diff checks: %v", err)
} }
@ -256,3 +265,12 @@ func isNewPeerUpdateMessage(lastSentUpdate, currUpdateToSend *UpdateMessage) (bo
} }
return len(changelog) > 0, nil return len(changelog) > 0, nil
} }
// getChecksFiles returns a list of files from the given checks.
func getChecksFiles(checks []*proto.Checks) []string {
files := make([]string, 0, len(checks))
for _, check := range checks {
files = append(files, check.GetFiles()...)
}
return files
}

View File

@ -124,14 +124,12 @@ func TestHandlePeerMessageUpdate(t *testing.T) {
NetworkMap: &proto.NetworkMap{Serial: 1}, NetworkMap: &proto.NetworkMap{Serial: 1},
}, },
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}}, NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
Checks: []*posture.Checks{},
}, },
newUpdate: &UpdateMessage{ newUpdate: &UpdateMessage{
Update: &proto.SyncResponse{ Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{Serial: 1}, NetworkMap: &proto.NetworkMap{Serial: 1},
}, },
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}}, NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
Checks: []*posture.Checks{},
}, },
expectedResult: false, expectedResult: false,
}, },
@ -143,14 +141,12 @@ func TestHandlePeerMessageUpdate(t *testing.T) {
NetworkMap: &proto.NetworkMap{Serial: 1}, NetworkMap: &proto.NetworkMap{Serial: 1},
}, },
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}}, NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
Checks: []*posture.Checks{},
}, },
newUpdate: &UpdateMessage{ newUpdate: &UpdateMessage{
Update: &proto.SyncResponse{ Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{Serial: 2}, NetworkMap: &proto.NetworkMap{Serial: 2},
}, },
NetworkMap: &NetworkMap{Network: &Network{Serial: 2}}, NetworkMap: &NetworkMap{Network: &Network{Serial: 2}},
Checks: []*posture.Checks{{ID: "check1"}},
}, },
expectedResult: true, expectedResult: true,
}, },
@ -193,7 +189,7 @@ func TestIsNewPeerUpdateMessage(t *testing.T) {
newUpdateMessage1 := createMockUpdateMessage(t) newUpdateMessage1 := createMockUpdateMessage(t)
newUpdateMessage2 := createMockUpdateMessage(t) newUpdateMessage2 := createMockUpdateMessage(t)
message, err := isNewPeerUpdateMessage(newUpdateMessage1, newUpdateMessage2) message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, message) assert.False(t, message)
}) })
@ -204,7 +200,7 @@ func TestIsNewPeerUpdateMessage(t *testing.T) {
newUpdateMessage2.Update.NetworkMap.Serial++ newUpdateMessage2.Update.NetworkMap.Serial++
message, err := isNewPeerUpdateMessage(newUpdateMessage1, newUpdateMessage2) message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
assert.NoError(t, err) assert.NoError(t, err)
assert.False(t, message) assert.False(t, message)
}) })
@ -216,7 +212,7 @@ func TestIsNewPeerUpdateMessage(t *testing.T) {
newUpdateMessage2.NetworkMap.Routes[0].Network = netip.MustParsePrefix("1.1.1.1/32") newUpdateMessage2.NetworkMap.Routes[0].Network = netip.MustParsePrefix("1.1.1.1/32")
newUpdateMessage2.Update.NetworkMap.Serial++ newUpdateMessage2.Update.NetworkMap.Serial++
message, err := isNewPeerUpdateMessage(newUpdateMessage1, newUpdateMessage2) message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, message) assert.True(t, message)
@ -229,7 +225,7 @@ func TestIsNewPeerUpdateMessage(t *testing.T) {
newUpdateMessage2.NetworkMap.Routes[0].Groups = []string{"randomGroup1"} newUpdateMessage2.NetworkMap.Routes[0].Groups = []string{"randomGroup1"}
newUpdateMessage2.Update.NetworkMap.Serial++ newUpdateMessage2.Update.NetworkMap.Serial++
message, err := isNewPeerUpdateMessage(newUpdateMessage1, newUpdateMessage2) message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, message) assert.True(t, message)
}) })
@ -248,26 +244,63 @@ func TestIsNewPeerUpdateMessage(t *testing.T) {
newUpdateMessage2.NetworkMap.Peers = append(newUpdateMessage2.NetworkMap.Peers, newPeer) newUpdateMessage2.NetworkMap.Peers = append(newUpdateMessage2.NetworkMap.Peers, newPeer)
newUpdateMessage2.Update.NetworkMap.Serial++ newUpdateMessage2.Update.NetworkMap.Serial++
message, err := isNewPeerUpdateMessage(newUpdateMessage1, newUpdateMessage2) message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, message) assert.True(t, message)
}) })
t.Run("Updating posture checks", func(t *testing.T) { t.Run("Updating process check", func(t *testing.T) {
newUpdateMessage1 := createMockUpdateMessage(t) newUpdateMessage1 := createMockUpdateMessage(t)
newUpdateMessage2 := createMockUpdateMessage(t)
newCheck := &posture.Checks{ newUpdateMessage2 := createMockUpdateMessage(t)
newUpdateMessage2.Update.NetworkMap.Serial++
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
assert.NoError(t, err)
assert.False(t, message)
newUpdateMessage3 := createMockUpdateMessage(t)
newUpdateMessage3.Update.Checks = []*proto.Checks{}
newUpdateMessage3.Update.NetworkMap.Serial++
message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage3)
assert.NoError(t, err)
assert.True(t, message)
newUpdateMessage4 := createMockUpdateMessage(t)
check := &posture.Checks{
Checks: posture.ChecksDefinition{ Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{ ProcessCheck: &posture.ProcessCheck{
MinVersion: "10.0", Processes: []posture.Process{
{
LinuxPath: "/usr/local/netbird",
MacPath: "/usr/bin/netbird",
},
},
}, },
}, },
} }
newUpdateMessage2.Checks = append(newUpdateMessage2.Checks, newCheck) newUpdateMessage4.Update.Checks = []*proto.Checks{toProtocolCheck(check)}
newUpdateMessage2.Update.NetworkMap.Serial++ newUpdateMessage4.Update.NetworkMap.Serial++
message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage4)
assert.NoError(t, err)
assert.True(t, message)
message, err := isNewPeerUpdateMessage(newUpdateMessage1, newUpdateMessage2) newUpdateMessage5 := createMockUpdateMessage(t)
check = &posture.Checks{
Checks: posture.ChecksDefinition{
ProcessCheck: &posture.ProcessCheck{
Processes: []posture.Process{
{
LinuxPath: "/usr/bin/netbird",
WindowsPath: "C:\\Program Files\\netbird\\netbird.exe",
MacPath: "/usr/local/netbird",
},
},
},
},
}
newUpdateMessage5.Update.Checks = []*proto.Checks{toProtocolCheck(check)}
newUpdateMessage5.Update.NetworkMap.Serial++
message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage5)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, message) assert.True(t, message)
}) })
@ -283,7 +316,7 @@ func TestIsNewPeerUpdateMessage(t *testing.T) {
) )
newUpdateMessage2.Update.NetworkMap.Serial++ newUpdateMessage2.Update.NetworkMap.Serial++
message, err := isNewPeerUpdateMessage(newUpdateMessage1, newUpdateMessage2) message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, message) assert.True(t, message)
}) })
@ -295,7 +328,7 @@ func TestIsNewPeerUpdateMessage(t *testing.T) {
newUpdateMessage2.NetworkMap.Peers[0].IP = net.ParseIP("192.168.1.10") newUpdateMessage2.NetworkMap.Peers[0].IP = net.ParseIP("192.168.1.10")
newUpdateMessage2.Update.NetworkMap.Serial++ newUpdateMessage2.Update.NetworkMap.Serial++
message, err := isNewPeerUpdateMessage(newUpdateMessage1, newUpdateMessage2) message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, message) assert.True(t, message)
}) })
@ -307,7 +340,7 @@ func TestIsNewPeerUpdateMessage(t *testing.T) {
newUpdateMessage2.NetworkMap.FirewallRules[0].Port = "443" newUpdateMessage2.NetworkMap.FirewallRules[0].Port = "443"
newUpdateMessage2.Update.NetworkMap.Serial++ newUpdateMessage2.Update.NetworkMap.Serial++
message, err := isNewPeerUpdateMessage(newUpdateMessage1, newUpdateMessage2) message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, message) assert.True(t, message)
}) })
@ -326,7 +359,7 @@ func TestIsNewPeerUpdateMessage(t *testing.T) {
newUpdateMessage2.NetworkMap.FirewallRules = append(newUpdateMessage2.NetworkMap.FirewallRules, newRule) newUpdateMessage2.NetworkMap.FirewallRules = append(newUpdateMessage2.NetworkMap.FirewallRules, newRule)
newUpdateMessage2.Update.NetworkMap.Serial++ newUpdateMessage2.Update.NetworkMap.Serial++
message, err := isNewPeerUpdateMessage(newUpdateMessage1, newUpdateMessage2) message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, message) assert.True(t, message)
}) })
@ -338,7 +371,7 @@ func TestIsNewPeerUpdateMessage(t *testing.T) {
newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].NameServers = make([]nbdns.NameServer, 0) newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].NameServers = make([]nbdns.NameServer, 0)
newUpdateMessage2.Update.NetworkMap.Serial++ newUpdateMessage2.Update.NetworkMap.Serial++
message, err := isNewPeerUpdateMessage(newUpdateMessage1, newUpdateMessage2) message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, message) assert.True(t, message)
}) })
@ -350,7 +383,7 @@ func TestIsNewPeerUpdateMessage(t *testing.T) {
newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].IP = netip.MustParseAddr("8.8.4.4") newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].IP = netip.MustParseAddr("8.8.4.4")
newUpdateMessage2.Update.NetworkMap.Serial++ newUpdateMessage2.Update.NetworkMap.Serial++
message, err := isNewPeerUpdateMessage(newUpdateMessage1, newUpdateMessage2) message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, message) assert.True(t, message)
}) })
@ -362,7 +395,7 @@ func TestIsNewPeerUpdateMessage(t *testing.T) {
newUpdateMessage2.NetworkMap.DNSConfig.CustomZones[0].Records[0].RData = "100.64.0.2" newUpdateMessage2.NetworkMap.DNSConfig.CustomZones[0].Records[0].RData = "100.64.0.2"
newUpdateMessage2.Update.NetworkMap.Serial++ newUpdateMessage2.Update.NetworkMap.Serial++
message, err := isNewPeerUpdateMessage(newUpdateMessage1, newUpdateMessage2) message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, message) assert.True(t, message)
}) })
@ -487,7 +520,13 @@ func createMockUpdateMessage(t *testing.T) *UpdateMessage {
{ {
Checks: posture.ChecksDefinition{ Checks: posture.ChecksDefinition{
ProcessCheck: &posture.ProcessCheck{ ProcessCheck: &posture.ProcessCheck{
Processes: []posture.Process{{LinuxPath: "/usr/bin/netbird"}}, Processes: []posture.Process{
{
LinuxPath: "/usr/bin/netbird",
WindowsPath: "C:\\Program Files\\netbird\\netbird.exe",
MacPath: "/usr/bin/netbird",
},
},
}, },
}, },
}, },
@ -507,6 +546,5 @@ func createMockUpdateMessage(t *testing.T) *UpdateMessage {
return &UpdateMessage{ return &UpdateMessage{
Update: toSyncResponse(context.Background(), config, peer, turnToken, relayToken, networkMap, dnsName, checks, dnsCache), Update: toSyncResponse(context.Background(), config, peer, turnToken, relayToken, networkMap, dnsName, checks, dnsCache),
NetworkMap: networkMap, NetworkMap: networkMap,
Checks: checks,
} }
} }

View File

@ -1,126 +0,0 @@
package util_test
import (
"crypto/md5"
"encoding/hex"
"io"
"os"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"github.com/netbirdio/netbird/util"
)
var _ = Describe("Client", func() {
var (
tmpDir string
)
type TestConfig struct {
SomeMap map[string]string
SomeArray []string
SomeField int
}
BeforeEach(func() {
var err error
tmpDir, err = os.MkdirTemp("", "wiretrustee_util_test_tmp_*")
Expect(err).NotTo(HaveOccurred())
})
AfterEach(func() {
err := os.RemoveAll(tmpDir)
Expect(err).NotTo(HaveOccurred())
})
Describe("Config", func() {
Context("in JSON format", func() {
It("should be written and read successfully", func() {
m := make(map[string]string)
m["key1"] = "value1"
m["key2"] = "value2"
arr := []string{"value1", "value2"}
written := &TestConfig{
SomeMap: m,
SomeArray: arr,
SomeField: 99,
}
err := util.WriteJson(tmpDir+"/testconfig.json", written)
Expect(err).NotTo(HaveOccurred())
read, err := util.ReadJson(tmpDir+"/testconfig.json", &TestConfig{})
Expect(err).NotTo(HaveOccurred())
Expect(read).NotTo(BeNil())
Expect(read.(*TestConfig).SomeMap["key1"]).To(BeEquivalentTo(written.SomeMap["key1"]))
Expect(read.(*TestConfig).SomeMap["key2"]).To(BeEquivalentTo(written.SomeMap["key2"]))
Expect(read.(*TestConfig).SomeArray).To(ContainElements(arr))
Expect(read.(*TestConfig).SomeField).To(BeEquivalentTo(written.SomeField))
})
})
})
Describe("Copying file contents", func() {
Context("from one file to another", func() {
It("should be successful", func() {
src := tmpDir + "/copytest_src"
dst := tmpDir + "/copytest_dst"
err := util.WriteJson(src, []string{"1", "2", "3"})
Expect(err).NotTo(HaveOccurred())
err = util.CopyFileContents(src, dst)
Expect(err).NotTo(HaveOccurred())
hashSrc := md5.New()
hashDst := md5.New()
srcFile, err := os.Open(src)
Expect(err).NotTo(HaveOccurred())
dstFile, err := os.Open(dst)
Expect(err).NotTo(HaveOccurred())
_, err = io.Copy(hashSrc, srcFile)
Expect(err).NotTo(HaveOccurred())
_, err = io.Copy(hashDst, dstFile)
Expect(err).NotTo(HaveOccurred())
err = srcFile.Close()
Expect(err).NotTo(HaveOccurred())
err = dstFile.Close()
Expect(err).NotTo(HaveOccurred())
Expect(hex.EncodeToString(hashSrc.Sum(nil)[:16])).To(BeEquivalentTo(hex.EncodeToString(hashDst.Sum(nil)[:16])))
})
})
})
Describe("Handle config file without full path", func() {
Context("config file handling", func() {
It("should be successful", func() {
written := &TestConfig{
SomeField: 123,
}
cfgFile := "test_cfg.json"
defer os.Remove(cfgFile)
err := util.WriteJson(cfgFile, written)
Expect(err).NotTo(HaveOccurred())
read, err := util.ReadJson(cfgFile, &TestConfig{})
Expect(err).NotTo(HaveOccurred())
Expect(read).NotTo(BeNil())
})
})
})
})

View File

@ -1,12 +1,142 @@
package util package util
import ( import (
"crypto/md5"
"encoding/hex"
"io"
"os" "os"
"reflect" "reflect"
"strings" "strings"
"testing" "testing"
"github.com/stretchr/testify/require"
) )
type TestConfig struct {
SomeMap map[string]string
SomeArray []string
SomeField int
}
func TestConfigJSON(t *testing.T) {
tests := []struct {
name string
config *TestConfig
expectedError bool
}{
{
name: "Valid JSON config",
config: &TestConfig{
SomeMap: map[string]string{"key1": "value1", "key2": "value2"},
SomeArray: []string{"value1", "value2"},
SomeField: 99,
},
expectedError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tmpDir := t.TempDir()
err := WriteJson(tmpDir+"/testconfig.json", tt.config)
require.NoError(t, err)
read, err := ReadJson(tmpDir+"/testconfig.json", &TestConfig{})
require.NoError(t, err)
require.NotNil(t, read)
require.Equal(t, tt.config.SomeMap["key1"], read.(*TestConfig).SomeMap["key1"])
require.Equal(t, tt.config.SomeMap["key2"], read.(*TestConfig).SomeMap["key2"])
require.ElementsMatch(t, tt.config.SomeArray, read.(*TestConfig).SomeArray)
require.Equal(t, tt.config.SomeField, read.(*TestConfig).SomeField)
})
}
}
func TestCopyFileContents(t *testing.T) {
tests := []struct {
name string
srcContent []string
expectedError bool
}{
{
name: "Copy file contents successfully",
srcContent: []string{"1", "2", "3"},
expectedError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tmpDir := t.TempDir()
src := tmpDir + "/copytest_src"
dst := tmpDir + "/copytest_dst"
err := WriteJson(src, tt.srcContent)
require.NoError(t, err)
err = CopyFileContents(src, dst)
require.NoError(t, err)
hashSrc := md5.New()
hashDst := md5.New()
srcFile, err := os.Open(src)
require.NoError(t, err)
defer func() {
_ = srcFile.Close()
}()
dstFile, err := os.Open(dst)
require.NoError(t, err)
defer func() {
_ = dstFile.Close()
}()
_, err = io.Copy(hashSrc, srcFile)
require.NoError(t, err)
_, err = io.Copy(hashDst, dstFile)
require.NoError(t, err)
require.Equal(t, hex.EncodeToString(hashSrc.Sum(nil)[:16]), hex.EncodeToString(hashDst.Sum(nil)[:16]))
})
}
}
func TestHandleConfigFileWithoutFullPath(t *testing.T) {
tests := []struct {
name string
config *TestConfig
expectedError bool
}{
{
name: "Handle config file without full path",
config: &TestConfig{
SomeField: 123,
},
expectedError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfgFile := "test_cfg.json"
defer func() {
_ = os.Remove(cfgFile)
}()
err := WriteJson(cfgFile, tt.config)
require.NoError(t, err)
read, err := ReadJson(cfgFile, &TestConfig{})
require.NoError(t, err)
require.NotNil(t, read)
})
}
}
func TestReadJsonWithEnvSub(t *testing.T) { func TestReadJsonWithEnvSub(t *testing.T) {
type Config struct { type Config struct {
CertFile string `json:"CertFile"` CertFile string `json:"CertFile"`