mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-10 15:48:29 +02:00
Merge branch 'main' into test/add-api-component-test
This commit is contained in:
41
.github/workflows/golang-test-linux.yml
vendored
41
.github/workflows/golang-test-linux.yml
vendored
@ -52,6 +52,47 @@ jobs:
|
|||||||
- name: Test
|
- name: Test
|
||||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./...
|
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./...
|
||||||
|
|
||||||
|
benchmark:
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
arch: [ '386','amd64' ]
|
||||||
|
store: [ 'sqlite', 'postgres' ]
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- name: Install Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: "1.23.x"
|
||||||
|
|
||||||
|
|
||||||
|
- name: Cache Go modules
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: ~/go/pkg/mod
|
||||||
|
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-go-
|
||||||
|
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||||
|
|
||||||
|
- name: Install 32-bit libpcap
|
||||||
|
if: matrix.arch == '386'
|
||||||
|
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||||
|
|
||||||
|
- name: Install modules
|
||||||
|
run: go mod tidy
|
||||||
|
|
||||||
|
- name: check git status
|
||||||
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
|
- name: Test
|
||||||
|
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./...
|
||||||
|
|
||||||
test_client_on_docker:
|
test_client_on_docker:
|
||||||
runs-on: ubuntu-20.04
|
runs-on: ubuntu-20.04
|
||||||
steps:
|
steps:
|
||||||
|
@ -1,9 +1,11 @@
|
|||||||
package nftables
|
package nftables
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"os/exec"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -225,3 +227,105 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func runIptablesSave(t *testing.T) (string, string) {
|
||||||
|
t.Helper()
|
||||||
|
var stdout, stderr bytes.Buffer
|
||||||
|
cmd := exec.Command("iptables-save")
|
||||||
|
cmd.Stdout = &stdout
|
||||||
|
cmd.Stderr = &stderr
|
||||||
|
|
||||||
|
err := cmd.Run()
|
||||||
|
require.NoError(t, err, "iptables-save failed to run")
|
||||||
|
|
||||||
|
return stdout.String(), stderr.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyIptablesOutput(t *testing.T, stdout, stderr string) {
|
||||||
|
t.Helper()
|
||||||
|
// Check for any incompatibility warnings
|
||||||
|
require.NotContains(t,
|
||||||
|
stderr,
|
||||||
|
"incompatible",
|
||||||
|
"iptables-save produced compatibility warning. Full stderr: %s",
|
||||||
|
stderr,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Verify standard tables are present
|
||||||
|
expectedTables := []string{
|
||||||
|
"*filter",
|
||||||
|
"*nat",
|
||||||
|
"*mangle",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, table := range expectedTables {
|
||||||
|
require.Contains(t,
|
||||||
|
stdout,
|
||||||
|
table,
|
||||||
|
"iptables-save output missing expected table: %s\nFull stdout: %s",
|
||||||
|
table,
|
||||||
|
stdout,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||||
|
if check() != NFTABLES {
|
||||||
|
t.Skip("nftables not supported on this system")
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := exec.LookPath("iptables-save"); err != nil {
|
||||||
|
t.Skipf("iptables-save not available on this system: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// First ensure iptables-nft tables exist by running iptables-save
|
||||||
|
stdout, stderr := runIptablesSave(t)
|
||||||
|
verifyIptablesOutput(t, stdout, stderr)
|
||||||
|
|
||||||
|
manager, err := Create(ifaceMock)
|
||||||
|
require.NoError(t, err, "failed to create manager")
|
||||||
|
require.NoError(t, manager.Init(nil))
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err := manager.Reset(nil)
|
||||||
|
require.NoError(t, err, "failed to reset manager state")
|
||||||
|
|
||||||
|
// Verify iptables output after reset
|
||||||
|
stdout, stderr := runIptablesSave(t)
|
||||||
|
verifyIptablesOutput(t, stdout, stderr)
|
||||||
|
})
|
||||||
|
|
||||||
|
ip := net.ParseIP("100.96.0.1")
|
||||||
|
_, err = manager.AddPeerFiltering(
|
||||||
|
ip,
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
&fw.Port{Values: []int{80}},
|
||||||
|
fw.RuleDirectionIN,
|
||||||
|
fw.ActionAccept,
|
||||||
|
"",
|
||||||
|
"test rule",
|
||||||
|
)
|
||||||
|
require.NoError(t, err, "failed to add peer filtering rule")
|
||||||
|
|
||||||
|
_, err = manager.AddRouteFiltering(
|
||||||
|
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
|
||||||
|
netip.MustParsePrefix("10.1.0.0/24"),
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
&fw.Port{Values: []int{443}},
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err, "failed to add route filtering rule")
|
||||||
|
|
||||||
|
pair := fw.RouterPair{
|
||||||
|
Source: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
Destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||||
|
Masquerade: true,
|
||||||
|
}
|
||||||
|
err = manager.AddNatRule(pair)
|
||||||
|
require.NoError(t, err, "failed to add NAT rule")
|
||||||
|
|
||||||
|
stdout, stderr = runIptablesSave(t)
|
||||||
|
verifyIptablesOutput(t, stdout, stderr)
|
||||||
|
}
|
||||||
|
@ -239,7 +239,7 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
|||||||
// SetLegacyManagement doesn't need to be implemented for this manager
|
// SetLegacyManagement doesn't need to be implemented for this manager
|
||||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||||
if m.nativeFirewall == nil {
|
if m.nativeFirewall == nil {
|
||||||
return errRouteNotSupported
|
return nil
|
||||||
}
|
}
|
||||||
return m.nativeFirewall.SetLegacyManagement(isLegacy)
|
return m.nativeFirewall.SetLegacyManagement(isLegacy)
|
||||||
}
|
}
|
||||||
|
12
client/iface/bind/control_android.go
Normal file
12
client/iface/bind/control_android.go
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
wireguard "golang.zx2c4.com/wireguard/conn"
|
||||||
|
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// ControlFns is not thread safe and should only be modified during init.
|
||||||
|
*wireguard.ControlFns = append(*wireguard.ControlFns, nbnet.ControlProtectSocket)
|
||||||
|
}
|
@ -55,7 +55,7 @@ type ruleParams struct {
|
|||||||
|
|
||||||
// isLegacy determines whether to use the legacy routing setup
|
// isLegacy determines whether to use the legacy routing setup
|
||||||
func isLegacy() bool {
|
func isLegacy() bool {
|
||||||
return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || os.Getenv(nbnet.EnvSkipSocketMark) == "true"
|
return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || nbnet.SkipSocketMark()
|
||||||
}
|
}
|
||||||
|
|
||||||
// setIsLegacy sets the legacy routing setup
|
// setIsLegacy sets the legacy routing setup
|
||||||
@ -92,17 +92,6 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
|
|||||||
return r.setupRefCounter(initAddresses, stateManager)
|
return r.setupRefCounter(initAddresses, stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = addRoutingTableName(); err != nil {
|
|
||||||
log.Errorf("Error adding routing table name: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
originalValues, err := sysctl.Setup(r.wgInterface)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Error setting up sysctl: %v", err)
|
|
||||||
sysctlFailed = true
|
|
||||||
}
|
|
||||||
originalSysctl = originalValues
|
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil {
|
if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil {
|
||||||
@ -123,6 +112,17 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err = addRoutingTableName(); err != nil {
|
||||||
|
log.Errorf("Error adding routing table name: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
originalValues, err := sysctl.Setup(r.wgInterface)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Error setting up sysctl: %v", err)
|
||||||
|
sysctlFailed = true
|
||||||
|
}
|
||||||
|
originalSysctl = originalValues
|
||||||
|
|
||||||
return nil, nil, nil
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -450,7 +450,7 @@ func addRule(params ruleParams) error {
|
|||||||
rule.Invert = params.invert
|
rule.Invert = params.invert
|
||||||
rule.SuppressPrefixlen = params.suppressPrefix
|
rule.SuppressPrefixlen = params.suppressPrefix
|
||||||
|
|
||||||
if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) {
|
||||||
return fmt.Errorf("add routing rule: %w", err)
|
return fmt.Errorf("add routing rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -467,7 +467,7 @@ func removeRule(params ruleParams) error {
|
|||||||
rule.Priority = params.priority
|
rule.Priority = params.priority
|
||||||
rule.SuppressPrefixlen = params.suppressPrefix
|
rule.SuppressPrefixlen = params.suppressPrefix
|
||||||
|
|
||||||
if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) {
|
||||||
return fmt.Errorf("remove routing rule: %w", err)
|
return fmt.Errorf("remove routing rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
2
go.mod
2
go.mod
@ -236,7 +236,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
|
|||||||
|
|
||||||
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
|
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
|
||||||
|
|
||||||
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73
|
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9
|
||||||
|
|
||||||
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
|
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
|
||||||
|
|
||||||
|
4
go.sum
4
go.sum
@ -527,8 +527,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax
|
|||||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28=
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
|
||||||
github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73 h1:jayg97LH/jJlvpIHVxueTfa+tfQ+FY8fy2sIhCwkz0g=
|
github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9 h1:Pu/7EukijT09ynHUOzQYW7cC3M/BKU8O4qyN/TvTGoY=
|
||||||
github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
|
||||||
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
|
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
|
||||||
github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4=
|
github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4=
|
||||||
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
|
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
|
||||||
|
@ -113,7 +113,7 @@ type AccountManager interface {
|
|||||||
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
|
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
|
||||||
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
|
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
|
||||||
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error)
|
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error)
|
||||||
SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error
|
SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error)
|
||||||
DeletePolicy(ctx context.Context, accountID, policyID, userID string) error
|
DeletePolicy(ctx context.Context, accountID, policyID, userID string) error
|
||||||
ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error)
|
ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error)
|
||||||
GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
|
GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
|
||||||
@ -139,7 +139,7 @@ type AccountManager interface {
|
|||||||
HasConnectedChannel(peerID string) bool
|
HasConnectedChannel(peerID string) bool
|
||||||
GetExternalCacheManager() ExternalCacheManager
|
GetExternalCacheManager() ExternalCacheManager
|
||||||
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
|
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
|
||||||
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error
|
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error)
|
||||||
DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error
|
DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error
|
||||||
ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
|
ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
|
||||||
GetIdpManager() idp.Manager
|
GetIdpManager() idp.Manager
|
||||||
|
@ -6,13 +6,17 @@ import (
|
|||||||
b64 "encoding/base64"
|
b64 "encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt"
|
"github.com/golang-jwt/jwt"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
@ -1196,8 +1200,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
policy := Policy{
|
_, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||||
ID: "policy",
|
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Rules: []*PolicyRule{
|
Rules: []*PolicyRule{
|
||||||
{
|
{
|
||||||
@ -1208,8 +1211,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
|||||||
Action: PolicyTrafficActionAccept,
|
Action: PolicyTrafficActionAccept,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
})
|
||||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||||
@ -1278,19 +1280,6 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
|||||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||||
|
|
||||||
policy := Policy{
|
|
||||||
Enabled: true,
|
|
||||||
Rules: []*PolicyRule{
|
|
||||||
{
|
|
||||||
Enabled: true,
|
|
||||||
Sources: []string{"groupA"},
|
|
||||||
Destinations: []string{"groupA"},
|
|
||||||
Bidirectional: true,
|
|
||||||
Action: PolicyTrafficActionAccept,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
wg := sync.WaitGroup{}
|
wg := sync.WaitGroup{}
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
@ -1303,7 +1292,19 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
|
_, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{
|
||||||
|
Enabled: true,
|
||||||
|
Sources: []string{"groupA"},
|
||||||
|
Destinations: []string{"groupA"},
|
||||||
|
Bidirectional: true,
|
||||||
|
Action: PolicyTrafficActionAccept,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
t.Errorf("delete default rule: %v", err)
|
t.Errorf("delete default rule: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -1324,7 +1325,7 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
policy := Policy{
|
_, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Rules: []*PolicyRule{
|
Rules: []*PolicyRule{
|
||||||
{
|
{
|
||||||
@ -1335,9 +1336,8 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
|||||||
Action: PolicyTrafficActionAccept,
|
Action: PolicyTrafficActionAccept,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
})
|
||||||
|
if err != nil {
|
||||||
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
|
|
||||||
t.Errorf("save policy: %v", err)
|
t.Errorf("save policy: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -1379,7 +1379,12 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
|||||||
|
|
||||||
require.NoError(t, err, "failed to save group")
|
require.NoError(t, err, "failed to save group")
|
||||||
|
|
||||||
policy := Policy{
|
if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
|
||||||
|
t.Errorf("delete default rule: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
policy, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Rules: []*PolicyRule{
|
Rules: []*PolicyRule{
|
||||||
{
|
{
|
||||||
@ -1390,14 +1395,8 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
|||||||
Action: PolicyTrafficActionAccept,
|
Action: PolicyTrafficActionAccept,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
})
|
||||||
|
if err != nil {
|
||||||
if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
|
|
||||||
t.Errorf("delete default rule: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
|
|
||||||
t.Errorf("save policy: %v", err)
|
t.Errorf("save policy: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -2945,3 +2944,218 @@ func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage)
|
|||||||
t.Error("Timed out waiting for update message")
|
t.Error("Timed out waiting for update message")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func BenchmarkSyncAndMarkPeer(b *testing.B) {
|
||||||
|
benchCases := []struct {
|
||||||
|
name string
|
||||||
|
peers int
|
||||||
|
groups int
|
||||||
|
// We need different expectations for CI/CD and local runs because of the different performance characteristics
|
||||||
|
minMsPerOpLocal float64
|
||||||
|
maxMsPerOpLocal float64
|
||||||
|
minMsPerOpCICD float64
|
||||||
|
maxMsPerOpCICD float64
|
||||||
|
}{
|
||||||
|
{"Small", 50, 5, 1, 3, 4, 10},
|
||||||
|
{"Medium", 500, 100, 7, 13, 10, 60},
|
||||||
|
{"Large", 5000, 200, 65, 80, 60, 170},
|
||||||
|
{"Small single", 50, 10, 1, 3, 4, 60},
|
||||||
|
{"Medium single", 500, 10, 7, 13, 10, 26},
|
||||||
|
{"Large 5", 5000, 15, 65, 80, 60, 170},
|
||||||
|
}
|
||||||
|
|
||||||
|
log.SetOutput(io.Discard)
|
||||||
|
defer log.SetOutput(os.Stderr)
|
||||||
|
for _, bc := range benchCases {
|
||||||
|
b.Run(bc.name, func(b *testing.B) {
|
||||||
|
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Failed to setup test account manager: %v", err)
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
account, err := manager.Store.GetAccount(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Failed to get account: %v", err)
|
||||||
|
}
|
||||||
|
peerChannels := make(map[string]chan *UpdateMessage)
|
||||||
|
for peerID := range account.Peers {
|
||||||
|
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
|
||||||
|
}
|
||||||
|
manager.peersUpdateManager.peerChannels = peerChannels
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
start := time.Now()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1})
|
||||||
|
assert.NoError(b, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
duration := time.Since(start)
|
||||||
|
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
|
||||||
|
b.ReportMetric(msPerOp, "ms/op")
|
||||||
|
|
||||||
|
minExpected := bc.minMsPerOpLocal
|
||||||
|
maxExpected := bc.maxMsPerOpLocal
|
||||||
|
if os.Getenv("CI") == "true" {
|
||||||
|
minExpected = bc.minMsPerOpCICD
|
||||||
|
maxExpected = bc.maxMsPerOpCICD
|
||||||
|
}
|
||||||
|
|
||||||
|
if msPerOp < minExpected {
|
||||||
|
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected)
|
||||||
|
}
|
||||||
|
|
||||||
|
if msPerOp > maxExpected {
|
||||||
|
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLoginPeer_ExistingPeer(b *testing.B) {
|
||||||
|
benchCases := []struct {
|
||||||
|
name string
|
||||||
|
peers int
|
||||||
|
groups int
|
||||||
|
// We need different expectations for CI/CD and local runs because of the different performance characteristics
|
||||||
|
minMsPerOpLocal float64
|
||||||
|
maxMsPerOpLocal float64
|
||||||
|
minMsPerOpCICD float64
|
||||||
|
maxMsPerOpCICD float64
|
||||||
|
}{
|
||||||
|
{"Small", 50, 5, 102, 110, 102, 120},
|
||||||
|
{"Medium", 500, 100, 105, 140, 105, 170},
|
||||||
|
{"Large", 5000, 200, 160, 200, 160, 270},
|
||||||
|
{"Small single", 50, 10, 102, 110, 102, 120},
|
||||||
|
{"Medium single", 500, 10, 105, 140, 105, 170},
|
||||||
|
{"Large 5", 5000, 15, 160, 200, 160, 270},
|
||||||
|
}
|
||||||
|
|
||||||
|
log.SetOutput(io.Discard)
|
||||||
|
defer log.SetOutput(os.Stderr)
|
||||||
|
for _, bc := range benchCases {
|
||||||
|
b.Run(bc.name, func(b *testing.B) {
|
||||||
|
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Failed to setup test account manager: %v", err)
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
account, err := manager.Store.GetAccount(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Failed to get account: %v", err)
|
||||||
|
}
|
||||||
|
peerChannels := make(map[string]chan *UpdateMessage)
|
||||||
|
for peerID := range account.Peers {
|
||||||
|
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
|
||||||
|
}
|
||||||
|
manager.peersUpdateManager.peerChannels = peerChannels
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
start := time.Now()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _, _, err := manager.LoginPeer(context.Background(), PeerLogin{
|
||||||
|
WireGuardPubKey: account.Peers["peer-1"].Key,
|
||||||
|
SSHKey: "someKey",
|
||||||
|
Meta: nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)},
|
||||||
|
UserID: "regular_user",
|
||||||
|
SetupKey: "",
|
||||||
|
ConnectionIP: net.IP{1, 1, 1, 1},
|
||||||
|
})
|
||||||
|
assert.NoError(b, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
duration := time.Since(start)
|
||||||
|
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
|
||||||
|
b.ReportMetric(msPerOp, "ms/op")
|
||||||
|
|
||||||
|
minExpected := bc.minMsPerOpLocal
|
||||||
|
maxExpected := bc.maxMsPerOpLocal
|
||||||
|
if os.Getenv("CI") == "true" {
|
||||||
|
minExpected = bc.minMsPerOpCICD
|
||||||
|
maxExpected = bc.maxMsPerOpCICD
|
||||||
|
}
|
||||||
|
|
||||||
|
if msPerOp < minExpected {
|
||||||
|
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected)
|
||||||
|
}
|
||||||
|
|
||||||
|
if msPerOp > maxExpected {
|
||||||
|
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLoginPeer_NewPeer(b *testing.B) {
|
||||||
|
benchCases := []struct {
|
||||||
|
name string
|
||||||
|
peers int
|
||||||
|
groups int
|
||||||
|
// We need different expectations for CI/CD and local runs because of the different performance characteristics
|
||||||
|
minMsPerOpLocal float64
|
||||||
|
maxMsPerOpLocal float64
|
||||||
|
minMsPerOpCICD float64
|
||||||
|
maxMsPerOpCICD float64
|
||||||
|
}{
|
||||||
|
{"Small", 50, 5, 107, 120, 107, 140},
|
||||||
|
{"Medium", 500, 100, 105, 140, 105, 170},
|
||||||
|
{"Large", 5000, 200, 180, 220, 180, 320},
|
||||||
|
{"Small single", 50, 10, 107, 120, 105, 140},
|
||||||
|
{"Medium single", 500, 10, 105, 140, 105, 170},
|
||||||
|
{"Large 5", 5000, 15, 180, 220, 180, 320},
|
||||||
|
}
|
||||||
|
|
||||||
|
log.SetOutput(io.Discard)
|
||||||
|
defer log.SetOutput(os.Stderr)
|
||||||
|
for _, bc := range benchCases {
|
||||||
|
b.Run(bc.name, func(b *testing.B) {
|
||||||
|
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Failed to setup test account manager: %v", err)
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
account, err := manager.Store.GetAccount(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Failed to get account: %v", err)
|
||||||
|
}
|
||||||
|
peerChannels := make(map[string]chan *UpdateMessage)
|
||||||
|
for peerID := range account.Peers {
|
||||||
|
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
|
||||||
|
}
|
||||||
|
manager.peersUpdateManager.peerChannels = peerChannels
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
start := time.Now()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _, _, err := manager.LoginPeer(context.Background(), PeerLogin{
|
||||||
|
WireGuardPubKey: "some-new-key" + strconv.Itoa(i),
|
||||||
|
SSHKey: "someKey",
|
||||||
|
Meta: nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)},
|
||||||
|
UserID: "regular_user",
|
||||||
|
SetupKey: "",
|
||||||
|
ConnectionIP: net.IP{1, 1, 1, 1},
|
||||||
|
})
|
||||||
|
assert.NoError(b, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
duration := time.Since(start)
|
||||||
|
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
|
||||||
|
b.ReportMetric(msPerOp, "ms/op")
|
||||||
|
|
||||||
|
minExpected := bc.minMsPerOpLocal
|
||||||
|
maxExpected := bc.maxMsPerOpLocal
|
||||||
|
if os.Getenv("CI") == "true" {
|
||||||
|
minExpected = bc.minMsPerOpCICD
|
||||||
|
maxExpected = bc.maxMsPerOpCICD
|
||||||
|
}
|
||||||
|
|
||||||
|
if msPerOp < minExpected {
|
||||||
|
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected)
|
||||||
|
}
|
||||||
|
|
||||||
|
if msPerOp > maxExpected {
|
||||||
|
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -3,6 +3,7 @@ package server
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@ -85,8 +86,12 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
if user.AccountID != accountID {
|
||||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view DNS settings")
|
return nil, status.NewUserNotPartOfAccountError()
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.IsRegularUser() {
|
||||||
|
return nil, status.NewAdminPermissionError()
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
|
return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
|
||||||
@ -94,64 +99,137 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
|
|||||||
|
|
||||||
// SaveDNSSettings validates a user role and updates the account's DNS settings
|
// SaveDNSSettings validates a user role and updates the account's DNS settings
|
||||||
func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
|
func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !user.HasAdminPower() {
|
|
||||||
return status.Errorf(status.PermissionDenied, "only users with admin power are allowed to update DNS settings")
|
|
||||||
}
|
|
||||||
|
|
||||||
if dnsSettingsToSave == nil {
|
if dnsSettingsToSave == nil {
|
||||||
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
|
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(dnsSettingsToSave.DisabledManagementGroups) != 0 {
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
err = validateGroups(dnsSettingsToSave.DisabledManagementGroups, account.Groups)
|
if err != nil {
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
oldSettings := account.DNSSettings.Copy()
|
|
||||||
account.DNSSettings = dnsSettingsToSave.Copy()
|
|
||||||
|
|
||||||
addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
|
|
||||||
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
|
|
||||||
|
|
||||||
account.Network.IncSerial()
|
|
||||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, id := range addedGroups {
|
if user.AccountID != accountID {
|
||||||
group := account.GetGroup(id)
|
return status.NewUserNotPartOfAccountError()
|
||||||
meta := map[string]any{"group": group.Name, "group_id": group.ID}
|
|
||||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, id := range removedGroups {
|
if !user.HasAdminPower() {
|
||||||
group := account.GetGroup(id)
|
return status.NewAdminPermissionError()
|
||||||
meta := map[string]any{"group": group.Name, "group_id": group.ID}
|
|
||||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) {
|
var updateAccountPeers bool
|
||||||
|
var eventsToStore []func()
|
||||||
|
|
||||||
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
|
if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
oldSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthUpdate, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
|
||||||
|
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
|
||||||
|
|
||||||
|
updateAccountPeers, err = areDNSSettingChangesAffectPeers(ctx, transaction, accountID, addedGroups, removedGroups)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups)
|
||||||
|
eventsToStore = append(eventsToStore, events...)
|
||||||
|
|
||||||
|
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return transaction.SaveDNSSettings(ctx, LockingStrengthUpdate, accountID, dnsSettingsToSave)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, storeEvent := range eventsToStore {
|
||||||
|
storeEvent()
|
||||||
|
}
|
||||||
|
|
||||||
|
if updateAccountPeers {
|
||||||
am.updateAccountPeers(ctx, accountID)
|
am.updateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// prepareDNSSettingsEvents prepares a list of event functions to be stored.
|
||||||
|
func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string) []func() {
|
||||||
|
var eventsToStore []func()
|
||||||
|
|
||||||
|
modifiedGroups := slices.Concat(addedGroups, removedGroups)
|
||||||
|
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Debugf("failed to get groups for dns settings events: %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, groupID := range addedGroups {
|
||||||
|
group, ok := groups[groupID]
|
||||||
|
if !ok {
|
||||||
|
log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToDisabledManagementGroups activity", groupID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
eventsToStore = append(eventsToStore, func() {
|
||||||
|
meta := map[string]any{"group": group.Name, "group_id": group.ID}
|
||||||
|
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, groupID := range removedGroups {
|
||||||
|
group, ok := groups[groupID]
|
||||||
|
if !ok {
|
||||||
|
log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromDisabledManagementGroups activity", groupID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
eventsToStore = append(eventsToStore, func() {
|
||||||
|
meta := map[string]any{"group": group.Name, "group_id": group.ID}
|
||||||
|
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return eventsToStore
|
||||||
|
}
|
||||||
|
|
||||||
|
// areDNSSettingChangesAffectPeers checks if the DNS settings changes affect any peers.
|
||||||
|
func areDNSSettingChangesAffectPeers(ctx context.Context, transaction Store, accountID string, addedGroups, removedGroups []string) (bool, error) {
|
||||||
|
hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, addedGroups)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasPeers {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return anyGroupHasPeers(ctx, transaction, accountID, removedGroups)
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateDNSSettings validates the DNS settings.
|
||||||
|
func validateDNSSettings(ctx context.Context, transaction Store, accountID string, settings *DNSSettings) error {
|
||||||
|
if len(settings.DisabledManagementGroups) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, settings.DisabledManagementGroups)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return validateGroups(settings.DisabledManagementGroups, groups)
|
||||||
|
}
|
||||||
|
|
||||||
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
|
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
|
||||||
func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig {
|
func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig {
|
||||||
protoUpdate := &proto.DNSConfig{
|
protoUpdate := &proto.DNSConfig{
|
||||||
|
@ -566,8 +566,7 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountI
|
|||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// anyGroupHasPeers checks if any of the given groups in the account have peers.
|
func (am *DefaultAccountManager) anyGroupHasPeers(account *Account, groupIDs []string) bool {
|
||||||
func anyGroupHasPeers(account *Account, groupIDs []string) bool {
|
|
||||||
for _, groupID := range groupIDs {
|
for _, groupID := range groupIDs {
|
||||||
if group, exists := account.Groups[groupID]; exists && group.HasPeers() {
|
if group, exists := account.Groups[groupID]; exists && group.HasPeers() {
|
||||||
return true
|
return true
|
||||||
@ -575,3 +574,19 @@ func anyGroupHasPeers(account *Account, groupIDs []string) bool {
|
|||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// anyGroupHasPeers checks if any of the given groups in the account have peers.
|
||||||
|
func anyGroupHasPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) {
|
||||||
|
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, groupIDs)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, group := range groups {
|
||||||
|
if group.HasPeers() {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
@ -500,8 +500,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
// adding a group to policy
|
// adding a group to policy
|
||||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||||
ID: "policy",
|
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Rules: []*PolicyRule{
|
Rules: []*PolicyRule{
|
||||||
{
|
{
|
||||||
@ -512,7 +511,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
|||||||
Action: PolicyTrafficActionAccept,
|
Action: PolicyTrafficActionAccept,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}, false)
|
})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// Saving a group linked to policy should update account peers and send peer update
|
// Saving a group linked to policy should update account peers and send peer update
|
||||||
|
@ -6,10 +6,8 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
|
||||||
"github.com/rs/xid"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
"github.com/netbirdio/netbird/management/server/http/util"
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
@ -122,21 +120,22 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
isUpdate := policyID != ""
|
policy := &server.Policy{
|
||||||
|
|
||||||
if policyID == "" {
|
|
||||||
policyID = xid.New().String()
|
|
||||||
}
|
|
||||||
|
|
||||||
policy := server.Policy{
|
|
||||||
ID: policyID,
|
ID: policyID,
|
||||||
|
AccountID: accountID,
|
||||||
Name: req.Name,
|
Name: req.Name,
|
||||||
Enabled: req.Enabled,
|
Enabled: req.Enabled,
|
||||||
Description: req.Description,
|
Description: req.Description,
|
||||||
}
|
}
|
||||||
for _, rule := range req.Rules {
|
for _, rule := range req.Rules {
|
||||||
|
var ruleID string
|
||||||
|
if rule.Id != nil {
|
||||||
|
ruleID = *rule.Id
|
||||||
|
}
|
||||||
|
|
||||||
pr := server.PolicyRule{
|
pr := server.PolicyRule{
|
||||||
ID: policyID, // TODO: when policy can contain multiple rules, need refactor
|
ID: ruleID,
|
||||||
|
PolicyID: policyID,
|
||||||
Name: rule.Name,
|
Name: rule.Name,
|
||||||
Destinations: rule.Destinations,
|
Destinations: rule.Destinations,
|
||||||
Sources: rule.Sources,
|
Sources: rule.Sources,
|
||||||
@ -225,7 +224,8 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
|
|||||||
policy.SourcePostureChecks = *req.SourcePostureChecks
|
policy.SourcePostureChecks = *req.SourcePostureChecks
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.accountManager.SavePolicy(r.Context(), accountID, userID, &policy, isUpdate); err != nil {
|
policy, err := h.accountManager.SavePolicy(r.Context(), accountID, userID, policy)
|
||||||
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -236,7 +236,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := toPolicyResponse(allGroups, &policy)
|
resp := toPolicyResponse(allGroups, policy)
|
||||||
if len(resp.Rules) == 0 {
|
if len(resp.Rules) == 0 {
|
||||||
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
|
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
|
||||||
return
|
return
|
||||||
|
@ -38,12 +38,12 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
|
|||||||
}
|
}
|
||||||
return policy, nil
|
return policy, nil
|
||||||
},
|
},
|
||||||
SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy, _ bool) error {
|
SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) (*server.Policy, error) {
|
||||||
if !strings.HasPrefix(policy.ID, "id-") {
|
if !strings.HasPrefix(policy.ID, "id-") {
|
||||||
policy.ID = "id-was-set"
|
policy.ID = "id-was-set"
|
||||||
policy.Rules[0].ID = "id-was-set"
|
policy.Rules[0].ID = "id-was-set"
|
||||||
}
|
}
|
||||||
return nil
|
return policy, nil
|
||||||
},
|
},
|
||||||
GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) {
|
GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) {
|
||||||
return []*nbgroup.Group{{ID: "F"}, {ID: "G"}}, nil
|
return []*nbgroup.Group{{ID: "F"}, {ID: "G"}}, nil
|
||||||
|
@ -169,7 +169,8 @@ func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks); err != nil {
|
postureChecks, err = p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks)
|
||||||
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -40,15 +40,15 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
|
|||||||
}
|
}
|
||||||
return p, nil
|
return p, nil
|
||||||
},
|
},
|
||||||
SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) error {
|
SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) {
|
||||||
postureChecks.ID = "postureCheck"
|
postureChecks.ID = "postureCheck"
|
||||||
testPostureChecks[postureChecks.ID] = postureChecks
|
testPostureChecks[postureChecks.ID] = postureChecks
|
||||||
|
|
||||||
if err := postureChecks.Validate(); err != nil {
|
if err := postureChecks.Validate(); err != nil {
|
||||||
return status.Errorf(status.InvalidArgument, err.Error()) //nolint
|
return nil, status.Errorf(status.InvalidArgument, err.Error()) //nolint
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return postureChecks, nil
|
||||||
},
|
},
|
||||||
DeletePostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) error {
|
DeletePostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) error {
|
||||||
_, ok := testPostureChecks[postureChecksID]
|
_, ok := testPostureChecks[postureChecksID]
|
||||||
|
@ -49,7 +49,7 @@ type MockAccountManager struct {
|
|||||||
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
||||||
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
|
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
|
||||||
GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error)
|
GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error)
|
||||||
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error
|
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error)
|
||||||
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
|
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
|
||||||
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error)
|
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error)
|
||||||
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error)
|
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error)
|
||||||
@ -96,7 +96,7 @@ type MockAccountManager struct {
|
|||||||
HasConnectedChannelFunc func(peerID string) bool
|
HasConnectedChannelFunc func(peerID string) bool
|
||||||
GetExternalCacheManagerFunc func() server.ExternalCacheManager
|
GetExternalCacheManagerFunc func() server.ExternalCacheManager
|
||||||
GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
|
GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
|
||||||
SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error
|
SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error)
|
||||||
DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error
|
DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error
|
||||||
ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
|
ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
|
||||||
GetIdpManagerFunc func() idp.Manager
|
GetIdpManagerFunc func() idp.Manager
|
||||||
@ -386,11 +386,11 @@ func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SavePolicy mock implementation of SavePolicy from server.AccountManager interface
|
// SavePolicy mock implementation of SavePolicy from server.AccountManager interface
|
||||||
func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error {
|
func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error) {
|
||||||
if am.SavePolicyFunc != nil {
|
if am.SavePolicyFunc != nil {
|
||||||
return am.SavePolicyFunc(ctx, accountID, userID, policy, isUpdate)
|
return am.SavePolicyFunc(ctx, accountID, userID, policy)
|
||||||
}
|
}
|
||||||
return status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeletePolicy mock implementation of DeletePolicy from server.AccountManager interface
|
// DeletePolicy mock implementation of DeletePolicy from server.AccountManager interface
|
||||||
@ -730,11 +730,11 @@ func (am *MockAccountManager) GetPostureChecks(ctx context.Context, accountID, p
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SavePostureChecks mocks SavePostureChecks of the AccountManager interface
|
// SavePostureChecks mocks SavePostureChecks of the AccountManager interface
|
||||||
func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
|
func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) {
|
||||||
if am.SavePostureChecksFunc != nil {
|
if am.SavePostureChecksFunc != nil {
|
||||||
return am.SavePostureChecksFunc(ctx, accountID, userID, postureChecks)
|
return am.SavePostureChecksFunc(ctx, accountID, userID, postureChecks)
|
||||||
}
|
}
|
||||||
return status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeletePostureChecks mocks DeletePostureChecks of the AccountManager interface
|
// DeletePostureChecks mocks DeletePostureChecks of the AccountManager interface
|
||||||
|
@ -24,26 +24,34 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
if user.AccountID != accountID {
|
||||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
|
return nil, status.NewUserNotPartOfAccountError()
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupID, accountID)
|
if user.IsRegularUser() {
|
||||||
|
return nil, status.NewAdminPermissionError()
|
||||||
|
}
|
||||||
|
|
||||||
|
return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateNameServerGroup creates and saves a new nameserver group
|
// CreateNameServerGroup creates and saves a new nameserver group
|
||||||
func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) {
|
func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) {
|
||||||
|
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if user.AccountID != accountID {
|
||||||
|
return nil, status.NewUserNotPartOfAccountError()
|
||||||
|
}
|
||||||
|
|
||||||
newNSGroup := &nbdns.NameServerGroup{
|
newNSGroup := &nbdns.NameServerGroup{
|
||||||
ID: xid.New().String(),
|
ID: xid.New().String(),
|
||||||
|
AccountID: accountID,
|
||||||
Name: name,
|
Name: name,
|
||||||
Description: description,
|
Description: description,
|
||||||
NameServers: nameServerList,
|
NameServers: nameServerList,
|
||||||
@ -54,26 +62,33 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
|
|||||||
SearchDomainsEnabled: searchDomainEnabled,
|
SearchDomainsEnabled: searchDomainEnabled,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = validateNameServerGroup(false, newNSGroup, account)
|
var updateAccountPeers bool
|
||||||
|
|
||||||
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
|
if err = validateNameServerGroup(ctx, transaction, accountID, newNSGroup); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
updateAccountPeers, err = anyGroupHasPeers(ctx, transaction, accountID, newNSGroup.Groups)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, newNSGroup)
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if account.NameServerGroups == nil {
|
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
|
||||||
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup)
|
|
||||||
}
|
|
||||||
|
|
||||||
account.NameServerGroups[newNSGroup.ID] = newNSGroup
|
if updateAccountPeers {
|
||||||
|
|
||||||
account.Network.IncSerial()
|
|
||||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if anyGroupHasPeers(account, newNSGroup.Groups) {
|
|
||||||
am.updateAccountPeers(ctx, accountID)
|
am.updateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
|
|
||||||
|
|
||||||
return newNSGroup.Copy(), nil
|
return newNSGroup.Copy(), nil
|
||||||
}
|
}
|
||||||
@ -87,58 +102,95 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
|||||||
return status.Errorf(status.InvalidArgument, "nameserver group provided is nil")
|
return status.Errorf(status.InvalidArgument, "nameserver group provided is nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = validateNameServerGroup(true, nsGroupToSave, account)
|
if user.AccountID != accountID {
|
||||||
|
return status.NewUserNotPartOfAccountError()
|
||||||
|
}
|
||||||
|
|
||||||
|
var updateAccountPeers bool
|
||||||
|
|
||||||
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
|
oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupToSave.ID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
nsGroupToSave.AccountID = accountID
|
||||||
|
|
||||||
|
if err = validateNameServerGroup(ctx, transaction, accountID, nsGroupToSave); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
updateAccountPeers, err = areNameServerGroupChangesAffectPeers(ctx, transaction, nsGroupToSave, oldNSGroup)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, nsGroupToSave)
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
oldNSGroup := account.NameServerGroups[nsGroupToSave.ID]
|
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
|
||||||
account.NameServerGroups[nsGroupToSave.ID] = nsGroupToSave
|
|
||||||
|
|
||||||
account.Network.IncSerial()
|
if updateAccountPeers {
|
||||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) {
|
|
||||||
am.updateAccountPeers(ctx, accountID)
|
am.updateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteNameServerGroup deletes nameserver group with nsGroupID
|
// DeleteNameServerGroup deletes nameserver group with nsGroupID
|
||||||
func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error {
|
func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error {
|
||||||
|
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
nsGroup := account.NameServerGroups[nsGroupID]
|
if user.AccountID != accountID {
|
||||||
if nsGroup == nil {
|
return status.NewUserNotPartOfAccountError()
|
||||||
return status.Errorf(status.NotFound, "nameserver group %s wasn't found", nsGroupID)
|
|
||||||
}
|
}
|
||||||
delete(account.NameServerGroups, nsGroupID)
|
|
||||||
|
|
||||||
account.Network.IncSerial()
|
var nsGroup *nbdns.NameServerGroup
|
||||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
var updateAccountPeers bool
|
||||||
|
|
||||||
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
|
nsGroup, err = transaction.GetNameServerGroupByID(ctx, LockingStrengthUpdate, accountID, nsGroupID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
updateAccountPeers, err = anyGroupHasPeers(ctx, transaction, accountID, nsGroup.Groups)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return transaction.DeleteNameServerGroup(ctx, LockingStrengthUpdate, accountID, nsGroupID)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if anyGroupHasPeers(account, nsGroup.Groups) {
|
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
|
||||||
|
|
||||||
|
if updateAccountPeers {
|
||||||
am.updateAccountPeers(ctx, accountID)
|
am.updateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -150,44 +202,62 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
if user.AccountID != accountID {
|
||||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
|
return nil, status.NewUserNotPartOfAccountError()
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.IsRegularUser() {
|
||||||
|
return nil, status.NewAdminPermissionError()
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
|
return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error {
|
func validateNameServerGroup(ctx context.Context, transaction Store, accountID string, nameserverGroup *nbdns.NameServerGroup) error {
|
||||||
nsGroupID := ""
|
|
||||||
if existingGroup {
|
|
||||||
nsGroupID = nameserverGroup.ID
|
|
||||||
_, found := account.NameServerGroups[nsGroupID]
|
|
||||||
if !found {
|
|
||||||
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains, nameserverGroup.SearchDomainsEnabled)
|
err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains, nameserverGroup.SearchDomainsEnabled)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = validateNSGroupName(nameserverGroup.Name, nsGroupID, account.NameServerGroups)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = validateNSList(nameserverGroup.NameServers)
|
err = validateNSList(nameserverGroup.NameServers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = validateGroups(nameserverGroup.Groups, account.Groups)
|
nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
err = validateNSGroupName(nameserverGroup.Name, nameserverGroup.ID, nsServerGroups)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, nameserverGroup.Groups)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return validateGroups(nameserverGroup.Groups, groups)
|
||||||
|
}
|
||||||
|
|
||||||
|
// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
|
||||||
|
func areNameServerGroupChangesAffectPeers(ctx context.Context, transaction Store, newNSGroup, oldNSGroup *nbdns.NameServerGroup) (bool, error) {
|
||||||
|
if !newNSGroup.Enabled && !oldNSGroup.Enabled {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
hasPeers, err := anyGroupHasPeers(ctx, transaction, newNSGroup.AccountID, newNSGroup.Groups)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasPeers {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return anyGroupHasPeers(ctx, transaction, oldNSGroup.AccountID, oldNSGroup.Groups)
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bool) error {
|
func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bool) error {
|
||||||
@ -213,14 +283,14 @@ func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bo
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.NameServerGroup) error {
|
func validateNSGroupName(name, nsGroupID string, groups []*nbdns.NameServerGroup) error {
|
||||||
if utf8.RuneCountInString(name) > nbdns.MaxGroupNameChar || name == "" {
|
if utf8.RuneCountInString(name) > nbdns.MaxGroupNameChar || name == "" {
|
||||||
return status.Errorf(status.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar)
|
return status.Errorf(status.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, nsGroup := range nsGroupMap {
|
for _, nsGroup := range groups {
|
||||||
if name == nsGroup.Name && nsGroup.ID != nsGroupID {
|
if name == nsGroup.Name && nsGroup.ID != nsGroupID {
|
||||||
return status.Errorf(status.InvalidArgument, "a nameserver group with name %s already exist", name)
|
return status.Errorf(status.InvalidArgument, "nameserver group with name %s already exist", name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -228,8 +298,8 @@ func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.Na
|
|||||||
}
|
}
|
||||||
|
|
||||||
func validateNSList(list []nbdns.NameServer) error {
|
func validateNSList(list []nbdns.NameServer) error {
|
||||||
nsListLenght := len(list)
|
nsListLength := len(list)
|
||||||
if nsListLenght == 0 || nsListLenght > 3 {
|
if nsListLength == 0 || nsListLength > 3 {
|
||||||
return status.Errorf(status.InvalidArgument, "the list of nameservers should be 1 or 3, got %d", len(list))
|
return status.Errorf(status.InvalidArgument, "the list of nameservers should be 1 or 3, got %d", len(list))
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@ -244,14 +314,7 @@ func validateGroups(list []string, groups map[string]*nbgroup.Group) error {
|
|||||||
if id == "" {
|
if id == "" {
|
||||||
return status.Errorf(status.InvalidArgument, "group ID should not be empty string")
|
return status.Errorf(status.InvalidArgument, "group ID should not be empty string")
|
||||||
}
|
}
|
||||||
found := false
|
if _, found := groups[id]; !found {
|
||||||
for groupID := range groups {
|
|
||||||
if id == groupID {
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !found {
|
|
||||||
return status.Errorf(status.InvalidArgument, "group id %s not found", id)
|
return status.Errorf(status.InvalidArgument, "group id %s not found", id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -277,11 +340,3 @@ func validateDomain(domain string) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
|
|
||||||
func areNameServerGroupChangesAffectPeers(account *Account, newNSGroup, oldNSGroup *nbdns.NameServerGroup) bool {
|
|
||||||
if !newNSGroup.Enabled && !oldNSGroup.Enabled {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return anyGroupHasPeers(account, newNSGroup.Groups) || anyGroupHasPeers(account, oldNSGroup.Groups)
|
|
||||||
}
|
|
||||||
|
@ -617,7 +617,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
postureChecks := am.getPeerPostureChecks(account, newPeer)
|
postureChecks, err := am.getPeerPostureChecks(account, newPeer.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
||||||
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
|
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
|
||||||
return newPeer, networkMap, postureChecks, nil
|
return newPeer, networkMap, postureChecks, nil
|
||||||
@ -702,7 +706,11 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, fmt.Errorf("failed to get validated peers: %w", err)
|
return nil, nil, nil, fmt.Errorf("failed to get validated peers: %w", err)
|
||||||
}
|
}
|
||||||
postureChecks = am.getPeerPostureChecks(account, peer)
|
|
||||||
|
postureChecks, err = am.getPeerPostureChecks(account, peer.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
||||||
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
|
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
|
||||||
@ -876,7 +884,11 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
postureChecks = am.getPeerPostureChecks(account, peer)
|
|
||||||
|
postureChecks, err = am.getPeerPostureChecks(account, peer.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
||||||
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
|
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
|
||||||
@ -1030,7 +1042,12 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
|
|||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
defer func() { <-semaphore }()
|
defer func() { <-semaphore }()
|
||||||
|
|
||||||
postureChecks := am.getPeerPostureChecks(account, p)
|
postureChecks, err := am.getPeerPostureChecks(account, p.ID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get peer: %s posture checks: %v", p.ID, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
|
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
|
||||||
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache)
|
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache)
|
||||||
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
|
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
|
||||||
|
@ -283,14 +283,12 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
|
|||||||
var (
|
var (
|
||||||
group1 nbgroup.Group
|
group1 nbgroup.Group
|
||||||
group2 nbgroup.Group
|
group2 nbgroup.Group
|
||||||
policy Policy
|
|
||||||
)
|
)
|
||||||
|
|
||||||
group1.ID = xid.New().String()
|
group1.ID = xid.New().String()
|
||||||
group2.ID = xid.New().String()
|
group2.ID = xid.New().String()
|
||||||
group1.Name = "src"
|
group1.Name = "src"
|
||||||
group2.Name = "dst"
|
group2.Name = "dst"
|
||||||
policy.ID = xid.New().String()
|
|
||||||
group1.Peers = append(group1.Peers, peer1.ID)
|
group1.Peers = append(group1.Peers, peer1.ID)
|
||||||
group2.Peers = append(group2.Peers, peer2.ID)
|
group2.Peers = append(group2.Peers, peer2.ID)
|
||||||
|
|
||||||
@ -305,18 +303,20 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
policy.Name = "test"
|
policy := &Policy{
|
||||||
policy.Enabled = true
|
Name: "test",
|
||||||
policy.Rules = []*PolicyRule{
|
Enabled: true,
|
||||||
{
|
Rules: []*PolicyRule{
|
||||||
Enabled: true,
|
{
|
||||||
Sources: []string{group1.ID},
|
Enabled: true,
|
||||||
Destinations: []string{group2.ID},
|
Sources: []string{group1.ID},
|
||||||
Bidirectional: true,
|
Destinations: []string{group2.ID},
|
||||||
Action: PolicyTrafficActionAccept,
|
Bidirectional: true,
|
||||||
|
Action: PolicyTrafficActionAccept,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
policy, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("expecting rule to be added, got failure %v", err)
|
t.Errorf("expecting rule to be added, got failure %v", err)
|
||||||
return
|
return
|
||||||
@ -364,7 +364,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
policy.Enabled = false
|
policy.Enabled = false
|
||||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
_, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("expecting rule to be added, got failure %v", err)
|
t.Errorf("expecting rule to be added, got failure %v", err)
|
||||||
return
|
return
|
||||||
@ -833,19 +833,23 @@ func BenchmarkGetPeers(b *testing.B) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkUpdateAccountPeers(b *testing.B) {
|
func BenchmarkUpdateAccountPeers(b *testing.B) {
|
||||||
benchCases := []struct {
|
benchCases := []struct {
|
||||||
name string
|
name string
|
||||||
peers int
|
peers int
|
||||||
groups int
|
groups int
|
||||||
|
// We need different expectations for CI/CD and local runs because of the different performance characteristics
|
||||||
|
minMsPerOpLocal float64
|
||||||
|
maxMsPerOpLocal float64
|
||||||
|
minMsPerOpCICD float64
|
||||||
|
maxMsPerOpCICD float64
|
||||||
}{
|
}{
|
||||||
{"Small", 50, 5},
|
{"Small", 50, 5, 90, 120, 90, 120},
|
||||||
{"Medium", 500, 10},
|
{"Medium", 500, 100, 110, 140, 120, 200},
|
||||||
{"Large", 5000, 20},
|
{"Large", 5000, 200, 800, 1300, 2500, 3600},
|
||||||
{"Small single", 50, 1},
|
{"Small single", 50, 10, 90, 120, 90, 120},
|
||||||
{"Medium single", 500, 1},
|
{"Medium single", 500, 10, 110, 170, 120, 200},
|
||||||
{"Large 5", 5000, 5},
|
{"Large 5", 5000, 15, 1300, 1800, 5000, 6000},
|
||||||
}
|
}
|
||||||
|
|
||||||
log.SetOutput(io.Discard)
|
log.SetOutput(io.Discard)
|
||||||
@ -881,8 +885,23 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
duration := time.Since(start)
|
duration := time.Since(start)
|
||||||
b.ReportMetric(float64(duration.Nanoseconds())/float64(b.N)/1e6, "ms/op")
|
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
|
||||||
b.ReportMetric(0, "ns/op")
|
b.ReportMetric(msPerOp, "ms/op")
|
||||||
|
|
||||||
|
minExpected := bc.minMsPerOpLocal
|
||||||
|
maxExpected := bc.maxMsPerOpLocal
|
||||||
|
if os.Getenv("CI") == "true" {
|
||||||
|
minExpected = bc.minMsPerOpCICD
|
||||||
|
maxExpected = bc.maxMsPerOpCICD
|
||||||
|
}
|
||||||
|
|
||||||
|
if msPerOp < minExpected {
|
||||||
|
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected)
|
||||||
|
}
|
||||||
|
|
||||||
|
if msPerOp > maxExpected {
|
||||||
|
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1445,8 +1464,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
|||||||
|
|
||||||
// Adding peer to group linked with policy should update account peers and send peer update
|
// Adding peer to group linked with policy should update account peers and send peer update
|
||||||
t.Run("adding peer to group linked with policy", func(t *testing.T) {
|
t.Run("adding peer to group linked with policy", func(t *testing.T) {
|
||||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||||
ID: "policy",
|
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Rules: []*PolicyRule{
|
Rules: []*PolicyRule{
|
||||||
{
|
{
|
||||||
@ -1457,7 +1475,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
|||||||
Action: PolicyTrafficActionAccept,
|
Action: PolicyTrafficActionAccept,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}, false)
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
|
@ -3,13 +3,13 @@ package server
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
_ "embed"
|
_ "embed"
|
||||||
"slices"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/proto"
|
||||||
|
"github.com/rs/xid"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/proto"
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
@ -125,6 +125,7 @@ type PolicyRule struct {
|
|||||||
func (pm *PolicyRule) Copy() *PolicyRule {
|
func (pm *PolicyRule) Copy() *PolicyRule {
|
||||||
rule := &PolicyRule{
|
rule := &PolicyRule{
|
||||||
ID: pm.ID,
|
ID: pm.ID,
|
||||||
|
PolicyID: pm.PolicyID,
|
||||||
Name: pm.Name,
|
Name: pm.Name,
|
||||||
Description: pm.Description,
|
Description: pm.Description,
|
||||||
Enabled: pm.Enabled,
|
Enabled: pm.Enabled,
|
||||||
@ -171,6 +172,7 @@ type Policy struct {
|
|||||||
func (p *Policy) Copy() *Policy {
|
func (p *Policy) Copy() *Policy {
|
||||||
c := &Policy{
|
c := &Policy{
|
||||||
ID: p.ID,
|
ID: p.ID,
|
||||||
|
AccountID: p.AccountID,
|
||||||
Name: p.Name,
|
Name: p.Name,
|
||||||
Description: p.Description,
|
Description: p.Description,
|
||||||
Enabled: p.Enabled,
|
Enabled: p.Enabled,
|
||||||
@ -343,157 +345,207 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
if user.AccountID != accountID {
|
||||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
|
return nil, status.NewUserNotPartOfAccountError()
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID)
|
if user.IsRegularUser() {
|
||||||
|
return nil, status.NewAdminPermissionError()
|
||||||
|
}
|
||||||
|
|
||||||
|
return am.Store.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SavePolicy in the store
|
// SavePolicy in the store
|
||||||
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error {
|
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
updateAccountPeers, err := am.savePolicy(account, policy, isUpdate)
|
if user.AccountID != accountID {
|
||||||
|
return nil, status.NewUserNotPartOfAccountError()
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.IsRegularUser() {
|
||||||
|
return nil, status.NewAdminPermissionError()
|
||||||
|
}
|
||||||
|
|
||||||
|
var isUpdate = policy.ID != ""
|
||||||
|
var updateAccountPeers bool
|
||||||
|
var action = activity.PolicyAdded
|
||||||
|
|
||||||
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
|
if err = validatePolicy(ctx, transaction, accountID, policy); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, isUpdate)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
saveFunc := transaction.CreatePolicy
|
||||||
|
if isUpdate {
|
||||||
|
action = activity.PolicyUpdated
|
||||||
|
saveFunc = transaction.SavePolicy
|
||||||
|
}
|
||||||
|
|
||||||
|
return saveFunc(ctx, LockingStrengthUpdate, policy)
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
account.Network.IncSerial()
|
|
||||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
action := activity.PolicyAdded
|
|
||||||
if isUpdate {
|
|
||||||
action = activity.PolicyUpdated
|
|
||||||
}
|
|
||||||
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
|
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
|
||||||
|
|
||||||
if updateAccountPeers {
|
if updateAccountPeers {
|
||||||
am.updateAccountPeers(ctx, accountID)
|
am.updateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return policy, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletePolicy from the store
|
||||||
|
func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error {
|
||||||
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
|
defer unlock()
|
||||||
|
|
||||||
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.AccountID != accountID {
|
||||||
|
return status.NewUserNotPartOfAccountError()
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.IsRegularUser() {
|
||||||
|
return status.NewAdminPermissionError()
|
||||||
|
}
|
||||||
|
|
||||||
|
var policy *Policy
|
||||||
|
var updateAccountPeers bool
|
||||||
|
|
||||||
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
|
policy, err = transaction.GetPolicyByID(ctx, LockingStrengthUpdate, accountID, policyID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return transaction.DeletePolicy(ctx, LockingStrengthUpdate, accountID, policyID)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta())
|
||||||
|
|
||||||
|
if updateAccountPeers {
|
||||||
|
am.updateAccountPeers(ctx, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeletePolicy from the store
|
// ListPolicies from the store.
|
||||||
func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error {
|
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
policy, err := am.deletePolicy(account, policyID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
account.Network.IncSerial()
|
|
||||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
|
|
||||||
|
|
||||||
if anyGroupHasPeers(account, policy.ruleGroups()) {
|
|
||||||
am.updateAccountPeers(ctx, accountID)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListPolicies from the store
|
|
||||||
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) {
|
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) {
|
||||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
|
if user.AccountID != accountID {
|
||||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
|
return nil, status.NewUserNotPartOfAccountError()
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.IsRegularUser() {
|
||||||
|
return nil, status.NewAdminPermissionError()
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
|
return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) {
|
// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers.
|
||||||
policyIdx := -1
|
func arePolicyChangesAffectPeers(ctx context.Context, transaction Store, accountID string, policy *Policy, isUpdate bool) (bool, error) {
|
||||||
for i, policy := range account.Policies {
|
|
||||||
if policy.ID == policyID {
|
|
||||||
policyIdx = i
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if policyIdx < 0 {
|
|
||||||
return nil, status.Errorf(status.NotFound, "rule with ID %s doesn't exist", policyID)
|
|
||||||
}
|
|
||||||
|
|
||||||
policy := account.Policies[policyIdx]
|
|
||||||
account.Policies = append(account.Policies[:policyIdx], account.Policies[policyIdx+1:]...)
|
|
||||||
return policy, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// savePolicy saves or updates a policy in the given account.
|
|
||||||
// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy.
|
|
||||||
func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) (bool, error) {
|
|
||||||
for index, rule := range policyToSave.Rules {
|
|
||||||
rule.Sources = filterValidGroupIDs(account, rule.Sources)
|
|
||||||
rule.Destinations = filterValidGroupIDs(account, rule.Destinations)
|
|
||||||
policyToSave.Rules[index] = rule
|
|
||||||
}
|
|
||||||
|
|
||||||
if policyToSave.SourcePostureChecks != nil {
|
|
||||||
policyToSave.SourcePostureChecks = filterValidPostureChecks(account, policyToSave.SourcePostureChecks)
|
|
||||||
}
|
|
||||||
|
|
||||||
if isUpdate {
|
if isUpdate {
|
||||||
policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID })
|
existingPolicy, err := transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID)
|
||||||
if policyIdx < 0 {
|
if err != nil {
|
||||||
return false, status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID)
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
oldPolicy := account.Policies[policyIdx]
|
if !policy.Enabled && !existingPolicy.Enabled {
|
||||||
// Update the existing policy
|
|
||||||
account.Policies[policyIdx] = policyToSave
|
|
||||||
|
|
||||||
if !policyToSave.Enabled && !oldPolicy.Enabled {
|
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
updateAccountPeers := anyGroupHasPeers(account, oldPolicy.ruleGroups()) || anyGroupHasPeers(account, policyToSave.ruleGroups())
|
|
||||||
|
|
||||||
return updateAccountPeers, nil
|
hasPeers, err := anyGroupHasPeers(ctx, transaction, policy.AccountID, existingPolicy.ruleGroups())
|
||||||
}
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
// Add the new policy to the account
|
if hasPeers {
|
||||||
account.Policies = append(account.Policies, policyToSave)
|
return true, nil
|
||||||
|
|
||||||
return anyGroupHasPeers(account, policyToSave.ruleGroups()), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {
|
|
||||||
result := make([]*proto.FirewallRule, len(rules))
|
|
||||||
for i := range rules {
|
|
||||||
rule := rules[i]
|
|
||||||
|
|
||||||
result[i] = &proto.FirewallRule{
|
|
||||||
PeerIP: rule.PeerIP,
|
|
||||||
Direction: getProtoDirection(rule.Direction),
|
|
||||||
Action: getProtoAction(rule.Action),
|
|
||||||
Protocol: getProtoProtocol(rule.Protocol),
|
|
||||||
Port: rule.Port,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return result
|
|
||||||
|
return anyGroupHasPeers(ctx, transaction, policy.AccountID, policy.ruleGroups())
|
||||||
|
}
|
||||||
|
|
||||||
|
// validatePolicy validates the policy and its rules.
|
||||||
|
func validatePolicy(ctx context.Context, transaction Store, accountID string, policy *Policy) error {
|
||||||
|
if policy.ID != "" {
|
||||||
|
_, err := transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
policy.ID = xid.New().String()
|
||||||
|
policy.AccountID = accountID
|
||||||
|
}
|
||||||
|
|
||||||
|
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, policy.ruleGroups())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
postureChecks, err := transaction.GetPostureChecksByIDs(ctx, LockingStrengthShare, accountID, policy.SourcePostureChecks)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, rule := range policy.Rules {
|
||||||
|
ruleCopy := rule.Copy()
|
||||||
|
if ruleCopy.ID == "" {
|
||||||
|
ruleCopy.ID = policy.ID // TODO: when policy can contain multiple rules, need refactor
|
||||||
|
ruleCopy.PolicyID = policy.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleCopy.Sources = getValidGroupIDs(groups, ruleCopy.Sources)
|
||||||
|
ruleCopy.Destinations = getValidGroupIDs(groups, ruleCopy.Destinations)
|
||||||
|
policy.Rules[i] = ruleCopy
|
||||||
|
}
|
||||||
|
|
||||||
|
if policy.SourcePostureChecks != nil {
|
||||||
|
policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getAllPeersFromGroups for given peer ID and list of groups
|
// getAllPeersFromGroups for given peer ID and list of groups
|
||||||
@ -574,27 +626,42 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// filterValidPostureChecks filters and returns the posture check IDs from the given list
|
// getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list.
|
||||||
// that are valid within the provided account.
|
func getValidPostureCheckIDs(postureChecks map[string]*posture.Checks, postureChecksIds []string) []string {
|
||||||
func filterValidPostureChecks(account *Account, postureChecksIds []string) []string {
|
validIDs := make([]string, 0, len(postureChecksIds))
|
||||||
result := make([]string, 0, len(postureChecksIds))
|
|
||||||
for _, id := range postureChecksIds {
|
for _, id := range postureChecksIds {
|
||||||
for _, postureCheck := range account.PostureChecks {
|
if _, exists := postureChecks[id]; exists {
|
||||||
if id == postureCheck.ID {
|
validIDs = append(validIDs, id)
|
||||||
result = append(result, id)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return result
|
|
||||||
|
return validIDs
|
||||||
}
|
}
|
||||||
|
|
||||||
// filterValidGroupIDs filters a list of group IDs and returns only the ones present in the account's group map.
|
// getValidGroupIDs filters and returns only the valid group IDs from the provided list.
|
||||||
func filterValidGroupIDs(account *Account, groupIDs []string) []string {
|
func getValidGroupIDs(groups map[string]*nbgroup.Group, groupIDs []string) []string {
|
||||||
result := make([]string, 0, len(groupIDs))
|
validIDs := make([]string, 0, len(groupIDs))
|
||||||
for _, groupID := range groupIDs {
|
for _, id := range groupIDs {
|
||||||
if _, exists := account.Groups[groupID]; exists {
|
if _, exists := groups[id]; exists {
|
||||||
result = append(result, groupID)
|
validIDs = append(validIDs, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return validIDs
|
||||||
|
}
|
||||||
|
|
||||||
|
// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules.
|
||||||
|
func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {
|
||||||
|
result := make([]*proto.FirewallRule, len(rules))
|
||||||
|
for i := range rules {
|
||||||
|
rule := rules[i]
|
||||||
|
|
||||||
|
result[i] = &proto.FirewallRule{
|
||||||
|
PeerIP: rule.PeerIP,
|
||||||
|
Direction: getProtoDirection(rule.Direction),
|
||||||
|
Action: getProtoAction(rule.Action),
|
||||||
|
Protocol: getProtoProtocol(rule.Protocol),
|
||||||
|
Port: rule.Port,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
|
@ -7,7 +7,6 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rs/xid"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"golang.org/x/exp/slices"
|
"golang.org/x/exp/slices"
|
||||||
|
|
||||||
@ -859,14 +858,23 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
|||||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
var policyWithGroupRulesNoPeers *Policy
|
||||||
|
var policyWithDestinationPeersOnly *Policy
|
||||||
|
var policyWithSourceAndDestinationPeers *Policy
|
||||||
|
|
||||||
// Saving policy with rule groups with no peers should not update account's peers and not send peer update
|
// Saving policy with rule groups with no peers should not update account's peers and not send peer update
|
||||||
t.Run("saving policy with rule groups with no peers", func(t *testing.T) {
|
t.Run("saving policy with rule groups with no peers", func(t *testing.T) {
|
||||||
policy := Policy{
|
done := make(chan struct{})
|
||||||
ID: "policy-rule-groups-no-peers",
|
go func() {
|
||||||
Enabled: true,
|
peerShouldNotReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
policyWithGroupRulesNoPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||||
|
AccountID: account.Id,
|
||||||
|
Enabled: true,
|
||||||
Rules: []*PolicyRule{
|
Rules: []*PolicyRule{
|
||||||
{
|
{
|
||||||
ID: xid.New().String(),
|
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Sources: []string{"groupB"},
|
Sources: []string{"groupB"},
|
||||||
Destinations: []string{"groupC"},
|
Destinations: []string{"groupC"},
|
||||||
@ -874,15 +882,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
|||||||
Action: PolicyTrafficActionAccept,
|
Action: PolicyTrafficActionAccept,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
})
|
||||||
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
peerShouldNotReceiveUpdate(t, updMsg)
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
|
|
||||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@ -895,12 +895,17 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
|||||||
// Saving policy with source group containing peers, but destination group without peers should
|
// Saving policy with source group containing peers, but destination group without peers should
|
||||||
// update account's peers and send peer update
|
// update account's peers and send peer update
|
||||||
t.Run("saving policy where source has peers but destination does not", func(t *testing.T) {
|
t.Run("saving policy where source has peers but destination does not", func(t *testing.T) {
|
||||||
policy := Policy{
|
done := make(chan struct{})
|
||||||
ID: "policy-source-has-peers-destination-none",
|
go func() {
|
||||||
Enabled: true,
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||||
|
AccountID: account.Id,
|
||||||
|
Enabled: true,
|
||||||
Rules: []*PolicyRule{
|
Rules: []*PolicyRule{
|
||||||
{
|
{
|
||||||
ID: xid.New().String(),
|
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Sources: []string{"groupA"},
|
Sources: []string{"groupA"},
|
||||||
Destinations: []string{"groupB"},
|
Destinations: []string{"groupB"},
|
||||||
@ -909,15 +914,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
|||||||
Action: PolicyTrafficActionAccept,
|
Action: PolicyTrafficActionAccept,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
})
|
||||||
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
peerShouldReceiveUpdate(t, updMsg)
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
|
|
||||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@ -930,13 +927,18 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
|||||||
// Saving policy with destination group containing peers, but source group without peers should
|
// Saving policy with destination group containing peers, but source group without peers should
|
||||||
// update account's peers and send peer update
|
// update account's peers and send peer update
|
||||||
t.Run("saving policy where destination has peers but source does not", func(t *testing.T) {
|
t.Run("saving policy where destination has peers but source does not", func(t *testing.T) {
|
||||||
policy := Policy{
|
done := make(chan struct{})
|
||||||
ID: "policy-destination-has-peers-source-none",
|
go func() {
|
||||||
Enabled: true,
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
policyWithDestinationPeersOnly, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||||
|
AccountID: account.Id,
|
||||||
|
Enabled: true,
|
||||||
Rules: []*PolicyRule{
|
Rules: []*PolicyRule{
|
||||||
{
|
{
|
||||||
ID: xid.New().String(),
|
Enabled: true,
|
||||||
Enabled: false,
|
|
||||||
Sources: []string{"groupC"},
|
Sources: []string{"groupC"},
|
||||||
Destinations: []string{"groupD"},
|
Destinations: []string{"groupD"},
|
||||||
Bidirectional: true,
|
Bidirectional: true,
|
||||||
@ -944,15 +946,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
|||||||
Action: PolicyTrafficActionAccept,
|
Action: PolicyTrafficActionAccept,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
})
|
||||||
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
peerShouldReceiveUpdate(t, updMsg)
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
|
|
||||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@ -965,12 +959,17 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
|||||||
// Saving policy with destination and source groups containing peers should update account's peers
|
// Saving policy with destination and source groups containing peers should update account's peers
|
||||||
// and send peer update
|
// and send peer update
|
||||||
t.Run("saving policy with source and destination groups with peers", func(t *testing.T) {
|
t.Run("saving policy with source and destination groups with peers", func(t *testing.T) {
|
||||||
policy := Policy{
|
done := make(chan struct{})
|
||||||
ID: "policy-source-destination-peers",
|
go func() {
|
||||||
Enabled: true,
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||||
|
AccountID: account.Id,
|
||||||
|
Enabled: true,
|
||||||
Rules: []*PolicyRule{
|
Rules: []*PolicyRule{
|
||||||
{
|
{
|
||||||
ID: xid.New().String(),
|
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Sources: []string{"groupA"},
|
Sources: []string{"groupA"},
|
||||||
Destinations: []string{"groupD"},
|
Destinations: []string{"groupD"},
|
||||||
@ -978,15 +977,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
|||||||
Action: PolicyTrafficActionAccept,
|
Action: PolicyTrafficActionAccept,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
})
|
||||||
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
peerShouldReceiveUpdate(t, updMsg)
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
|
|
||||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@ -999,28 +990,14 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
|||||||
// Disabling policy with destination and source groups containing peers should update account's peers
|
// Disabling policy with destination and source groups containing peers should update account's peers
|
||||||
// and send peer update
|
// and send peer update
|
||||||
t.Run("disabling policy with source and destination groups with peers", func(t *testing.T) {
|
t.Run("disabling policy with source and destination groups with peers", func(t *testing.T) {
|
||||||
policy := Policy{
|
|
||||||
ID: "policy-source-destination-peers",
|
|
||||||
Enabled: false,
|
|
||||||
Rules: []*PolicyRule{
|
|
||||||
{
|
|
||||||
ID: xid.New().String(),
|
|
||||||
Enabled: true,
|
|
||||||
Sources: []string{"groupA"},
|
|
||||||
Destinations: []string{"groupD"},
|
|
||||||
Bidirectional: true,
|
|
||||||
Action: PolicyTrafficActionAccept,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
peerShouldReceiveUpdate(t, updMsg)
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
policyWithSourceAndDestinationPeers.Enabled = false
|
||||||
|
policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@ -1033,29 +1010,15 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
|||||||
// Updating disabled policy with destination and source groups containing peers should not update account's peers
|
// Updating disabled policy with destination and source groups containing peers should not update account's peers
|
||||||
// or send peer update
|
// or send peer update
|
||||||
t.Run("updating disabled policy with source and destination groups with peers", func(t *testing.T) {
|
t.Run("updating disabled policy with source and destination groups with peers", func(t *testing.T) {
|
||||||
policy := Policy{
|
|
||||||
ID: "policy-source-destination-peers",
|
|
||||||
Description: "updated description",
|
|
||||||
Enabled: false,
|
|
||||||
Rules: []*PolicyRule{
|
|
||||||
{
|
|
||||||
ID: xid.New().String(),
|
|
||||||
Enabled: true,
|
|
||||||
Sources: []string{"groupA"},
|
|
||||||
Destinations: []string{"groupA"},
|
|
||||||
Bidirectional: true,
|
|
||||||
Action: PolicyTrafficActionAccept,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
peerShouldNotReceiveUpdate(t, updMsg)
|
peerShouldNotReceiveUpdate(t, updMsg)
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
policyWithSourceAndDestinationPeers.Description = "updated description"
|
||||||
|
policyWithSourceAndDestinationPeers.Rules[0].Destinations = []string{"groupA"}
|
||||||
|
policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@ -1068,28 +1031,14 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
|||||||
// Enabling policy with destination and source groups containing peers should update account's peers
|
// Enabling policy with destination and source groups containing peers should update account's peers
|
||||||
// and send peer update
|
// and send peer update
|
||||||
t.Run("enabling policy with source and destination groups with peers", func(t *testing.T) {
|
t.Run("enabling policy with source and destination groups with peers", func(t *testing.T) {
|
||||||
policy := Policy{
|
|
||||||
ID: "policy-source-destination-peers",
|
|
||||||
Enabled: true,
|
|
||||||
Rules: []*PolicyRule{
|
|
||||||
{
|
|
||||||
ID: xid.New().String(),
|
|
||||||
Enabled: true,
|
|
||||||
Sources: []string{"groupA"},
|
|
||||||
Destinations: []string{"groupD"},
|
|
||||||
Bidirectional: true,
|
|
||||||
Action: PolicyTrafficActionAccept,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
peerShouldReceiveUpdate(t, updMsg)
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
policyWithSourceAndDestinationPeers.Enabled = true
|
||||||
|
policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@ -1101,15 +1050,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
|||||||
|
|
||||||
// Deleting policy should trigger account peers update and send peer update
|
// Deleting policy should trigger account peers update and send peer update
|
||||||
t.Run("deleting policy with source and destination groups with peers", func(t *testing.T) {
|
t.Run("deleting policy with source and destination groups with peers", func(t *testing.T) {
|
||||||
policyID := "policy-source-destination-peers"
|
|
||||||
|
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
peerShouldReceiveUpdate(t, updMsg)
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID)
|
err := manager.DeletePolicy(context.Background(), account.Id, policyWithSourceAndDestinationPeers.ID, userID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@ -1123,14 +1070,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
|||||||
// Deleting policy with destination group containing peers, but source group without peers should
|
// Deleting policy with destination group containing peers, but source group without peers should
|
||||||
// update account's peers and send peer update
|
// update account's peers and send peer update
|
||||||
t.Run("deleting policy where destination has peers but source does not", func(t *testing.T) {
|
t.Run("deleting policy where destination has peers but source does not", func(t *testing.T) {
|
||||||
policyID := "policy-destination-has-peers-source-none"
|
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
peerShouldReceiveUpdate(t, updMsg)
|
peerShouldReceiveUpdate(t, updMsg)
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID)
|
err := manager.DeletePolicy(context.Background(), account.Id, policyWithDestinationPeersOnly.ID, userID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@ -1142,14 +1088,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
|||||||
|
|
||||||
// Deleting policy with no peers in groups should not update account's peers and not send peer update
|
// Deleting policy with no peers in groups should not update account's peers and not send peer update
|
||||||
t.Run("deleting policy with no peers in groups", func(t *testing.T) {
|
t.Run("deleting policy with no peers in groups", func(t *testing.T) {
|
||||||
policyID := "policy-rule-groups-no-peers"
|
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
peerShouldNotReceiveUpdate(t, updMsg)
|
peerShouldNotReceiveUpdate(t, updMsg)
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID)
|
err := manager.DeletePolicy(context.Background(), account.Id, policyWithGroupRulesNoPeers.ID, userID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
|
@ -7,8 +7,6 @@ import (
|
|||||||
"regexp"
|
"regexp"
|
||||||
|
|
||||||
"github.com/hashicorp/go-version"
|
"github.com/hashicorp/go-version"
|
||||||
"github.com/rs/xid"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
@ -172,10 +170,6 @@ func NewChecksFromAPIPostureCheckUpdate(source api.PostureCheckUpdate, postureCh
|
|||||||
}
|
}
|
||||||
|
|
||||||
func buildPostureCheck(postureChecksID string, name string, description string, checks api.Checks) (*Checks, error) {
|
func buildPostureCheck(postureChecksID string, name string, description string, checks api.Checks) (*Checks, error) {
|
||||||
if postureChecksID == "" {
|
|
||||||
postureChecksID = xid.New().String()
|
|
||||||
}
|
|
||||||
|
|
||||||
postureChecks := Checks{
|
postureChecks := Checks{
|
||||||
ID: postureChecksID,
|
ID: postureChecksID,
|
||||||
Name: name,
|
Name: name,
|
||||||
|
@ -2,237 +2,285 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
|
"github.com/rs/xid"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
|
||||||
"github.com/netbirdio/netbird/management/server/posture"
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
errMsgPostureAdminOnly = "only users with admin power are allowed to view posture checks"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
|
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
|
||||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !user.HasAdminPower() || user.AccountID != accountID {
|
if user.AccountID != accountID {
|
||||||
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
return nil, status.NewUserNotPartOfAccountError()
|
||||||
}
|
|
||||||
|
|
||||||
return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
|
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !user.HasAdminPower() {
|
if !user.HasAdminPower() {
|
||||||
return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
return nil, status.NewAdminPermissionError()
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := postureChecks.Validate(); err != nil {
|
return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID)
|
||||||
return status.Errorf(status.InvalidArgument, err.Error()) //nolint
|
}
|
||||||
|
|
||||||
|
// SavePostureChecks saves a posture check.
|
||||||
|
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) {
|
||||||
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
|
defer unlock()
|
||||||
|
|
||||||
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
exists, uniqName := am.savePostureChecks(account, postureChecks)
|
if user.AccountID != accountID {
|
||||||
|
return nil, status.NewUserNotPartOfAccountError()
|
||||||
// we do not allow create new posture checks with non uniq name
|
|
||||||
if !exists && !uniqName {
|
|
||||||
return status.Errorf(status.PreconditionFailed, "Posture check name should be unique")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
action := activity.PostureCheckCreated
|
if !user.HasAdminPower() {
|
||||||
if exists {
|
return nil, status.NewAdminPermissionError()
|
||||||
action = activity.PostureCheckUpdated
|
|
||||||
account.Network.IncSerial()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
var updateAccountPeers bool
|
||||||
return err
|
var isUpdate = postureChecks.ID != ""
|
||||||
|
var action = activity.PostureCheckCreated
|
||||||
|
|
||||||
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
|
if err = validatePostureChecks(ctx, transaction, accountID, postureChecks); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if isUpdate {
|
||||||
|
updateAccountPeers, err = arePostureCheckChangesAffectPeers(ctx, transaction, accountID, postureChecks.ID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
action = activity.PostureCheckUpdated
|
||||||
|
}
|
||||||
|
|
||||||
|
postureChecks.AccountID = accountID
|
||||||
|
return transaction.SavePostureChecks(ctx, LockingStrengthUpdate, postureChecks)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
|
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
|
||||||
|
|
||||||
if arePostureCheckChangesAffectingPeers(account, postureChecks.ID, exists) {
|
if updateAccountPeers {
|
||||||
am.updateAccountPeers(ctx, accountID)
|
am.updateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return postureChecks, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeletePostureChecks deletes a posture check by ID.
|
||||||
func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error {
|
func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
if user.AccountID != accountID {
|
||||||
if err != nil {
|
return status.NewUserNotPartOfAccountError()
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !user.HasAdminPower() {
|
if !user.HasAdminPower() {
|
||||||
return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
return status.NewAdminPermissionError()
|
||||||
}
|
}
|
||||||
|
|
||||||
postureChecks, err := am.deletePostureChecks(account, postureChecksID)
|
var postureChecks *posture.Checks
|
||||||
|
|
||||||
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
|
postureChecks, err = transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = isPostureCheckLinkedToPolicy(ctx, transaction, postureChecksID, accountID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return transaction.DeletePostureChecks(ctx, LockingStrengthUpdate, accountID, postureChecksID)
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, activity.PostureCheckDeleted, postureChecks.EventMeta())
|
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, activity.PostureCheckDeleted, postureChecks.EventMeta())
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ListPostureChecks returns a list of posture checks.
|
||||||
func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) {
|
func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) {
|
||||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !user.HasAdminPower() || user.AccountID != accountID {
|
if user.AccountID != accountID {
|
||||||
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
return nil, status.NewUserNotPartOfAccountError()
|
||||||
|
}
|
||||||
|
|
||||||
|
if !user.HasAdminPower() {
|
||||||
|
return nil, status.NewAdminPermissionError()
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
|
return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) savePostureChecks(account *Account, postureChecks *posture.Checks) (exists, uniqName bool) {
|
|
||||||
uniqName = true
|
|
||||||
for i, p := range account.PostureChecks {
|
|
||||||
if !exists && p.ID == postureChecks.ID {
|
|
||||||
account.PostureChecks[i] = postureChecks
|
|
||||||
exists = true
|
|
||||||
}
|
|
||||||
if p.Name == postureChecks.Name {
|
|
||||||
uniqName = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !exists {
|
|
||||||
account.PostureChecks = append(account.PostureChecks, postureChecks)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (am *DefaultAccountManager) deletePostureChecks(account *Account, postureChecksID string) (*posture.Checks, error) {
|
|
||||||
postureChecksIdx := -1
|
|
||||||
for i, postureChecks := range account.PostureChecks {
|
|
||||||
if postureChecks.ID == postureChecksID {
|
|
||||||
postureChecksIdx = i
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if postureChecksIdx < 0 {
|
|
||||||
return nil, status.Errorf(status.NotFound, "posture checks with ID %s doesn't exist", postureChecksID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if posture check is linked to any policy
|
|
||||||
if isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureChecksID); isLinked {
|
|
||||||
return nil, status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", linkedPolicy.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
postureChecks := account.PostureChecks[postureChecksIdx]
|
|
||||||
account.PostureChecks = append(account.PostureChecks[:postureChecksIdx], account.PostureChecks[postureChecksIdx+1:]...)
|
|
||||||
|
|
||||||
return postureChecks, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getPeerPostureChecks returns the posture checks applied for a given peer.
|
// getPeerPostureChecks returns the posture checks applied for a given peer.
|
||||||
func (am *DefaultAccountManager) getPeerPostureChecks(account *Account, peer *nbpeer.Peer) []*posture.Checks {
|
func (am *DefaultAccountManager) getPeerPostureChecks(account *Account, peerID string) ([]*posture.Checks, error) {
|
||||||
peerPostureChecks := make(map[string]posture.Checks)
|
peerPostureChecks := make(map[string]*posture.Checks)
|
||||||
|
|
||||||
if len(account.PostureChecks) == 0 {
|
if len(account.PostureChecks) == 0 {
|
||||||
return nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, policy := range account.Policies {
|
for _, policy := range account.Policies {
|
||||||
if !policy.Enabled {
|
if !policy.Enabled || len(policy.SourcePostureChecks) == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if isPeerInPolicySourceGroups(peer.ID, account, policy) {
|
if err := addPolicyPostureChecks(account, peerID, policy, peerPostureChecks); err != nil {
|
||||||
addPolicyPostureChecks(account, policy, peerPostureChecks)
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
postureChecksList := make([]*posture.Checks, 0, len(peerPostureChecks))
|
return maps.Values(peerPostureChecks), nil
|
||||||
for _, check := range peerPostureChecks {
|
}
|
||||||
checkCopy := check
|
|
||||||
postureChecksList = append(postureChecksList, &checkCopy)
|
// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers.
|
||||||
|
func arePostureCheckChangesAffectPeers(ctx context.Context, transaction Store, accountID, postureCheckID string) (bool, error) {
|
||||||
|
policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return postureChecksList
|
for _, policy := range policies {
|
||||||
|
if slices.Contains(policy.SourcePostureChecks, postureCheckID) {
|
||||||
|
hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, policy.ruleGroups())
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasPeers {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validatePostureChecks validates the posture checks.
|
||||||
|
func validatePostureChecks(ctx context.Context, transaction Store, accountID string, postureChecks *posture.Checks) error {
|
||||||
|
if err := postureChecks.Validate(); err != nil {
|
||||||
|
return status.Errorf(status.InvalidArgument, err.Error()) //nolint
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the posture check already has an ID, verify its existence in the store.
|
||||||
|
if postureChecks.ID != "" {
|
||||||
|
if _, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecks.ID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// For new posture checks, ensure no duplicates by name.
|
||||||
|
checks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, check := range checks {
|
||||||
|
if check.Name == postureChecks.Name && check.ID != postureChecks.ID {
|
||||||
|
return status.Errorf(status.InvalidArgument, "posture checks with name %s already exists", postureChecks.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
postureChecks.ID = xid.New().String()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups.
|
||||||
|
func addPolicyPostureChecks(account *Account, peerID string, policy *Policy, peerPostureChecks map[string]*posture.Checks) error {
|
||||||
|
isInGroup, err := isPeerInPolicySourceGroups(account, peerID, policy)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isInGroup {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, sourcePostureCheckID := range policy.SourcePostureChecks {
|
||||||
|
postureCheck := account.getPostureChecks(sourcePostureCheckID)
|
||||||
|
if postureCheck == nil {
|
||||||
|
return errors.New("failed to add policy posture checks: posture checks not found")
|
||||||
|
}
|
||||||
|
peerPostureChecks[sourcePostureCheckID] = postureCheck
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups.
|
// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups.
|
||||||
func isPeerInPolicySourceGroups(peerID string, account *Account, policy *Policy) bool {
|
func isPeerInPolicySourceGroups(account *Account, peerID string, policy *Policy) (bool, error) {
|
||||||
for _, rule := range policy.Rules {
|
for _, rule := range policy.Rules {
|
||||||
if !rule.Enabled {
|
if !rule.Enabled {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, sourceGroup := range rule.Sources {
|
for _, sourceGroup := range rule.Sources {
|
||||||
group, ok := account.Groups[sourceGroup]
|
group := account.GetGroup(sourceGroup)
|
||||||
if ok && slices.Contains(group.Peers, peerID) {
|
if group == nil {
|
||||||
return true
|
return false, fmt.Errorf("failed to check peer in policy source group: group not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
if slices.Contains(group.Peers, peerID) {
|
||||||
|
return true, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func addPolicyPostureChecks(account *Account, policy *Policy, peerPostureChecks map[string]posture.Checks) {
|
|
||||||
for _, sourcePostureCheckID := range policy.SourcePostureChecks {
|
|
||||||
for _, postureCheck := range account.PostureChecks {
|
|
||||||
if postureCheck.ID == sourcePostureCheckID {
|
|
||||||
peerPostureChecks[sourcePostureCheckID] = *postureCheck
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func isPostureCheckLinkedToPolicy(account *Account, postureChecksID string) (bool, *Policy) {
|
|
||||||
for _, policy := range account.Policies {
|
|
||||||
if slices.Contains(policy.SourcePostureChecks, postureChecksID) {
|
|
||||||
return true, policy
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// arePostureCheckChangesAffectingPeers checks if the changes in posture checks are affecting peers.
|
// isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy.
|
||||||
func arePostureCheckChangesAffectingPeers(account *Account, postureCheckID string, exists bool) bool {
|
func isPostureCheckLinkedToPolicy(ctx context.Context, transaction Store, postureChecksID, accountID string) error {
|
||||||
if !exists {
|
policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
|
||||||
return false
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureCheckID)
|
for _, policy := range policies {
|
||||||
if !isLinked {
|
if slices.Contains(policy.SourcePostureChecks, postureChecksID) {
|
||||||
return false
|
return status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", policy.Name)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return anyGroupHasPeers(account, linkedPolicy.ruleGroups())
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -5,8 +5,8 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rs/xid"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/group"
|
"github.com/netbirdio/netbird/management/server/group"
|
||||||
|
|
||||||
@ -16,7 +16,6 @@ import (
|
|||||||
const (
|
const (
|
||||||
adminUserID = "adminUserID"
|
adminUserID = "adminUserID"
|
||||||
regularUserID = "regularUserID"
|
regularUserID = "regularUserID"
|
||||||
postureCheckID = "existing-id"
|
|
||||||
postureCheckName = "Existing check"
|
postureCheckName = "Existing check"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -33,7 +32,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("Generic posture check flow", func(t *testing.T) {
|
t.Run("Generic posture check flow", func(t *testing.T) {
|
||||||
// regular users can not create checks
|
// regular users can not create checks
|
||||||
err := am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{})
|
_, err = am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{})
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
|
|
||||||
// regular users cannot list check
|
// regular users cannot list check
|
||||||
@ -41,8 +40,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
|
|||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
|
|
||||||
// should be possible to create posture check with uniq name
|
// should be possible to create posture check with uniq name
|
||||||
err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
|
postureCheck, err := am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
|
||||||
ID: postureCheckID,
|
|
||||||
Name: postureCheckName,
|
Name: postureCheckName,
|
||||||
Checks: posture.ChecksDefinition{
|
Checks: posture.ChecksDefinition{
|
||||||
NBVersionCheck: &posture.NBVersionCheck{
|
NBVersionCheck: &posture.NBVersionCheck{
|
||||||
@ -58,8 +56,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
|
|||||||
assert.Len(t, checks, 1)
|
assert.Len(t, checks, 1)
|
||||||
|
|
||||||
// should not be possible to create posture check with non uniq name
|
// should not be possible to create posture check with non uniq name
|
||||||
err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
|
_, err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
|
||||||
ID: "new-id",
|
|
||||||
Name: postureCheckName,
|
Name: postureCheckName,
|
||||||
Checks: posture.ChecksDefinition{
|
Checks: posture.ChecksDefinition{
|
||||||
GeoLocationCheck: &posture.GeoLocationCheck{
|
GeoLocationCheck: &posture.GeoLocationCheck{
|
||||||
@ -74,23 +71,20 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
|
|||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
|
|
||||||
// admins can update posture checks
|
// admins can update posture checks
|
||||||
err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
|
postureCheck.Checks = posture.ChecksDefinition{
|
||||||
ID: postureCheckID,
|
NBVersionCheck: &posture.NBVersionCheck{
|
||||||
Name: postureCheckName,
|
MinVersion: "0.27.0",
|
||||||
Checks: posture.ChecksDefinition{
|
|
||||||
NBVersionCheck: &posture.NBVersionCheck{
|
|
||||||
MinVersion: "0.27.0",
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
})
|
}
|
||||||
|
_, err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheck)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// users should not be able to delete posture checks
|
// users should not be able to delete posture checks
|
||||||
err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, regularUserID)
|
err = am.DeletePostureChecks(context.Background(), account.Id, postureCheck.ID, regularUserID)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
|
|
||||||
// admin should be able to delete posture checks
|
// admin should be able to delete posture checks
|
||||||
err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, adminUserID)
|
err = am.DeletePostureChecks(context.Background(), account.Id, postureCheck.ID, adminUserID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
checks, err = am.ListPostureChecks(context.Background(), account.Id, adminUserID)
|
checks, err = am.ListPostureChecks(context.Background(), account.Id, adminUserID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
@ -150,9 +144,22 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
|||||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||||
})
|
})
|
||||||
|
|
||||||
postureCheck := posture.Checks{
|
postureCheckA := &posture.Checks{
|
||||||
ID: "postureCheck",
|
Name: "postureCheckA",
|
||||||
Name: "postureCheck",
|
AccountID: account.Id,
|
||||||
|
Checks: posture.ChecksDefinition{
|
||||||
|
ProcessCheck: &posture.ProcessCheck{
|
||||||
|
Processes: []posture.Process{
|
||||||
|
{LinuxPath: "/usr/bin/netbird", MacPath: "/usr/local/bin/netbird"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckA)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
postureCheckB := &posture.Checks{
|
||||||
|
Name: "postureCheckB",
|
||||||
AccountID: account.Id,
|
AccountID: account.Id,
|
||||||
Checks: posture.ChecksDefinition{
|
Checks: posture.ChecksDefinition{
|
||||||
NBVersionCheck: &posture.NBVersionCheck{
|
NBVersionCheck: &posture.NBVersionCheck{
|
||||||
@ -169,7 +176,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
|||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@ -187,12 +194,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
|||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
postureCheck.Checks = posture.ChecksDefinition{
|
postureCheckB.Checks = posture.ChecksDefinition{
|
||||||
NBVersionCheck: &posture.NBVersionCheck{
|
NBVersionCheck: &posture.NBVersionCheck{
|
||||||
MinVersion: "0.29.0",
|
MinVersion: "0.29.0",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
_, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@ -202,12 +209,10 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
policy := Policy{
|
policy := &Policy{
|
||||||
ID: "policyA",
|
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Rules: []*PolicyRule{
|
Rules: []*PolicyRule{
|
||||||
{
|
{
|
||||||
ID: xid.New().String(),
|
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Sources: []string{"groupA"},
|
Sources: []string{"groupA"},
|
||||||
Destinations: []string{"groupA"},
|
Destinations: []string{"groupA"},
|
||||||
@ -215,7 +220,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
|||||||
Action: PolicyTrafficActionAccept,
|
Action: PolicyTrafficActionAccept,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
SourcePostureChecks: []string{postureCheck.ID},
|
SourcePostureChecks: []string{postureCheckB.ID},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Linking posture check to policy should trigger update account peers and send peer update
|
// Linking posture check to policy should trigger update account peers and send peer update
|
||||||
@ -226,7 +231,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
|||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
policy, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@ -238,7 +243,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
|||||||
|
|
||||||
// Updating linked posture checks should update account peers and send peer update
|
// Updating linked posture checks should update account peers and send peer update
|
||||||
t.Run("updating linked to posture check with peers", func(t *testing.T) {
|
t.Run("updating linked to posture check with peers", func(t *testing.T) {
|
||||||
postureCheck.Checks = posture.ChecksDefinition{
|
postureCheckB.Checks = posture.ChecksDefinition{
|
||||||
NBVersionCheck: &posture.NBVersionCheck{
|
NBVersionCheck: &posture.NBVersionCheck{
|
||||||
MinVersion: "0.29.0",
|
MinVersion: "0.29.0",
|
||||||
},
|
},
|
||||||
@ -255,7 +260,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
|||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
_, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@ -274,8 +279,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
policy.SourcePostureChecks = []string{}
|
policy.SourcePostureChecks = []string{}
|
||||||
|
_, err := manager.SavePolicy(context.Background(), account.Id, userID, policy)
|
||||||
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@ -293,7 +297,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
|||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err := manager.DeletePostureChecks(context.Background(), account.Id, "postureCheck", userID)
|
err := manager.DeletePostureChecks(context.Background(), account.Id, postureCheckA.ID, userID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@ -303,17 +307,15 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
err = manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
_, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update
|
// Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update
|
||||||
t.Run("updating linked posture check to policy with no peers", func(t *testing.T) {
|
t.Run("updating linked posture check to policy with no peers", func(t *testing.T) {
|
||||||
policy = Policy{
|
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||||
ID: "policyB",
|
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Rules: []*PolicyRule{
|
Rules: []*PolicyRule{
|
||||||
{
|
{
|
||||||
ID: xid.New().String(),
|
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Sources: []string{"groupB"},
|
Sources: []string{"groupB"},
|
||||||
Destinations: []string{"groupC"},
|
Destinations: []string{"groupC"},
|
||||||
@ -321,9 +323,8 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
|||||||
Action: PolicyTrafficActionAccept,
|
Action: PolicyTrafficActionAccept,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
SourcePostureChecks: []string{postureCheck.ID},
|
SourcePostureChecks: []string{postureCheckB.ID},
|
||||||
}
|
})
|
||||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
@ -332,12 +333,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
|||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
postureCheck.Checks = posture.ChecksDefinition{
|
postureCheckB.Checks = posture.ChecksDefinition{
|
||||||
NBVersionCheck: &posture.NBVersionCheck{
|
NBVersionCheck: &posture.NBVersionCheck{
|
||||||
MinVersion: "0.29.0",
|
MinVersion: "0.29.0",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
_, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@ -354,12 +355,11 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
|||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID)
|
manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID)
|
||||||
})
|
})
|
||||||
policy = Policy{
|
|
||||||
ID: "policyB",
|
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Rules: []*PolicyRule{
|
Rules: []*PolicyRule{
|
||||||
{
|
{
|
||||||
ID: xid.New().String(),
|
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Sources: []string{"groupB"},
|
Sources: []string{"groupB"},
|
||||||
Destinations: []string{"groupA"},
|
Destinations: []string{"groupA"},
|
||||||
@ -367,10 +367,8 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
|||||||
Action: PolicyTrafficActionAccept,
|
Action: PolicyTrafficActionAccept,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
SourcePostureChecks: []string{postureCheck.ID},
|
SourcePostureChecks: []string{postureCheckB.ID},
|
||||||
}
|
})
|
||||||
|
|
||||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
@ -379,12 +377,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
|||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
postureCheck.Checks = posture.ChecksDefinition{
|
postureCheckB.Checks = posture.ChecksDefinition{
|
||||||
NBVersionCheck: &posture.NBVersionCheck{
|
NBVersionCheck: &posture.NBVersionCheck{
|
||||||
MinVersion: "0.29.0",
|
MinVersion: "0.29.0",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
_, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@ -397,8 +395,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
|||||||
// Updating linked client posture check to policy where source has peers but destination does not,
|
// Updating linked client posture check to policy where source has peers but destination does not,
|
||||||
// should trigger account peers update and send peer update
|
// should trigger account peers update and send peer update
|
||||||
t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) {
|
t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) {
|
||||||
policy = Policy{
|
_, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
|
||||||
ID: "policyB",
|
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Rules: []*PolicyRule{
|
Rules: []*PolicyRule{
|
||||||
{
|
{
|
||||||
@ -409,9 +406,8 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
|||||||
Action: PolicyTrafficActionAccept,
|
Action: PolicyTrafficActionAccept,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
SourcePostureChecks: []string{postureCheck.ID},
|
SourcePostureChecks: []string{postureCheckB.ID},
|
||||||
}
|
})
|
||||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
@ -420,7 +416,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
|||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
postureCheck.Checks = posture.ChecksDefinition{
|
postureCheckB.Checks = posture.ChecksDefinition{
|
||||||
ProcessCheck: &posture.ProcessCheck{
|
ProcessCheck: &posture.ProcessCheck{
|
||||||
Processes: []posture.Process{
|
Processes: []posture.Process{
|
||||||
{
|
{
|
||||||
@ -429,7 +425,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
|
_, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@ -440,80 +436,120 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestArePostureCheckChangesAffectingPeers(t *testing.T) {
|
func TestArePostureCheckChangesAffectPeers(t *testing.T) {
|
||||||
account := &Account{
|
manager, err := createManager(t)
|
||||||
Policies: []*Policy{
|
require.NoError(t, err, "failed to create account manager")
|
||||||
{
|
|
||||||
ID: "policyA",
|
account, err := initTestPostureChecksAccount(manager)
|
||||||
Rules: []*PolicyRule{
|
require.NoError(t, err, "failed to init testing account")
|
||||||
{
|
|
||||||
Enabled: true,
|
groupA := &group.Group{
|
||||||
Sources: []string{"groupA"},
|
ID: "groupA",
|
||||||
Destinations: []string{"groupA"},
|
AccountID: account.Id,
|
||||||
},
|
Peers: []string{"peer1"},
|
||||||
},
|
|
||||||
SourcePostureChecks: []string{"checkA"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Groups: map[string]*group.Group{
|
|
||||||
"groupA": {
|
|
||||||
ID: "groupA",
|
|
||||||
Peers: []string{"peer1"},
|
|
||||||
},
|
|
||||||
"groupB": {
|
|
||||||
ID: "groupB",
|
|
||||||
Peers: []string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
PostureChecks: []*posture.Checks{
|
|
||||||
{
|
|
||||||
ID: "checkA",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "checkB",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
groupB := &group.Group{
|
||||||
|
ID: "groupB",
|
||||||
|
AccountID: account.Id,
|
||||||
|
Peers: []string{},
|
||||||
|
}
|
||||||
|
err = manager.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*group.Group{groupA, groupB})
|
||||||
|
require.NoError(t, err, "failed to save groups")
|
||||||
|
|
||||||
|
postureCheckA := &posture.Checks{
|
||||||
|
Name: "checkA",
|
||||||
|
AccountID: account.Id,
|
||||||
|
Checks: posture.ChecksDefinition{
|
||||||
|
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckA)
|
||||||
|
require.NoError(t, err, "failed to save postureCheckA")
|
||||||
|
|
||||||
|
postureCheckB := &posture.Checks{
|
||||||
|
Name: "checkB",
|
||||||
|
AccountID: account.Id,
|
||||||
|
Checks: posture.ChecksDefinition{
|
||||||
|
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckB)
|
||||||
|
require.NoError(t, err, "failed to save postureCheckB")
|
||||||
|
|
||||||
|
policy := &Policy{
|
||||||
|
AccountID: account.Id,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{
|
||||||
|
Enabled: true,
|
||||||
|
Sources: []string{"groupA"},
|
||||||
|
Destinations: []string{"groupA"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
SourcePostureChecks: []string{postureCheckA.ID},
|
||||||
|
}
|
||||||
|
|
||||||
|
policy, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy)
|
||||||
|
require.NoError(t, err, "failed to save policy")
|
||||||
|
|
||||||
t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) {
|
t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) {
|
||||||
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
|
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
assert.True(t, result)
|
assert.True(t, result)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("posture check exists but is not linked to any policy", func(t *testing.T) {
|
t.Run("posture check exists but is not linked to any policy", func(t *testing.T) {
|
||||||
result := arePostureCheckChangesAffectingPeers(account, "checkB", true)
|
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckB.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
assert.False(t, result)
|
assert.False(t, result)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("posture check does not exist", func(t *testing.T) {
|
t.Run("posture check does not exist", func(t *testing.T) {
|
||||||
result := arePostureCheckChangesAffectingPeers(account, "unknown", false)
|
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, "unknown")
|
||||||
|
require.NoError(t, err)
|
||||||
assert.False(t, result)
|
assert.False(t, result)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) {
|
t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) {
|
||||||
account.Policies[0].Rules[0].Sources = []string{"groupB"}
|
policy.Rules[0].Sources = []string{"groupB"}
|
||||||
account.Policies[0].Rules[0].Destinations = []string{"groupA"}
|
policy.Rules[0].Destinations = []string{"groupA"}
|
||||||
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
|
_, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy)
|
||||||
|
require.NoError(t, err, "failed to update policy")
|
||||||
|
|
||||||
|
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
assert.True(t, result)
|
assert.True(t, result)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) {
|
t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) {
|
||||||
account.Policies[0].Rules[0].Sources = []string{"groupA"}
|
policy.Rules[0].Sources = []string{"groupA"}
|
||||||
account.Policies[0].Rules[0].Destinations = []string{"groupB"}
|
policy.Rules[0].Destinations = []string{"groupB"}
|
||||||
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
|
_, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy)
|
||||||
|
require.NoError(t, err, "failed to update policy")
|
||||||
|
|
||||||
|
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
assert.True(t, result)
|
assert.True(t, result)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) {
|
t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) {
|
||||||
account.Policies[0].Rules[0].Sources = []string{"nonExistentGroup"}
|
groupA.Peers = []string{}
|
||||||
account.Policies[0].Rules[0].Destinations = []string{"nonExistentGroup"}
|
err = manager.Store.SaveGroup(context.Background(), LockingStrengthUpdate, groupA)
|
||||||
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
|
require.NoError(t, err, "failed to save groups")
|
||||||
|
|
||||||
|
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
assert.False(t, result)
|
assert.False(t, result)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) {
|
t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) {
|
||||||
account.Groups["groupA"].Peers = []string{}
|
policy.Rules[0].Sources = []string{"nonExistentGroup"}
|
||||||
result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
|
policy.Rules[0].Destinations = []string{"nonExistentGroup"}
|
||||||
|
_, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy)
|
||||||
|
require.NoError(t, err, "failed to update policy")
|
||||||
|
|
||||||
|
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
assert.False(t, result)
|
assert.False(t, result)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -237,7 +237,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if isRouteChangeAffectPeers(account, &newRoute) {
|
if am.isRouteChangeAffectPeers(account, &newRoute) {
|
||||||
am.updateAccountPeers(ctx, accountID)
|
am.updateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -323,7 +323,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) {
|
if am.isRouteChangeAffectPeers(account, oldRoute) || am.isRouteChangeAffectPeers(account, routeToSave) {
|
||||||
am.updateAccountPeers(ctx, accountID)
|
am.updateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -355,7 +355,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
|
|||||||
|
|
||||||
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
|
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
|
||||||
|
|
||||||
if isRouteChangeAffectPeers(account, routy) {
|
if am.isRouteChangeAffectPeers(account, routy) {
|
||||||
am.updateAccountPeers(ctx, accountID)
|
am.updateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -651,6 +651,6 @@ func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo {
|
|||||||
|
|
||||||
// isRouteChangeAffectPeers checks if a given route affects peers by determining
|
// isRouteChangeAffectPeers checks if a given route affects peers by determining
|
||||||
// if it has a routing peer, distribution, or peer groups that include peers
|
// if it has a routing peer, distribution, or peer groups that include peers
|
||||||
func isRouteChangeAffectPeers(account *Account, route *route.Route) bool {
|
func (am *DefaultAccountManager) isRouteChangeAffectPeers(account *Account, route *route.Route) bool {
|
||||||
return anyGroupHasPeers(account, route.Groups) || anyGroupHasPeers(account, route.PeerGroups) || route.Peer != ""
|
return am.anyGroupHasPeers(account, route.Groups) || am.anyGroupHasPeers(account, route.PeerGroups) || route.Peer != ""
|
||||||
}
|
}
|
||||||
|
@ -1214,12 +1214,11 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
|
|||||||
|
|
||||||
defaultRule := rules[0]
|
defaultRule := rules[0]
|
||||||
newPolicy := defaultRule.Copy()
|
newPolicy := defaultRule.Copy()
|
||||||
newPolicy.ID = xid.New().String()
|
|
||||||
newPolicy.Name = "peer1 only"
|
newPolicy.Name = "peer1 only"
|
||||||
newPolicy.Rules[0].Sources = []string{newGroup.ID}
|
newPolicy.Rules[0].Sources = []string{newGroup.ID}
|
||||||
newPolicy.Rules[0].Destinations = []string{newGroup.ID}
|
newPolicy.Rules[0].Destinations = []string{newGroup.ID}
|
||||||
|
|
||||||
err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy, false)
|
_, err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID)
|
err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID)
|
||||||
|
@ -406,8 +406,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
|
|||||||
})
|
})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
policy := Policy{
|
policy := &Policy{
|
||||||
ID: "policy",
|
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Rules: []*PolicyRule{
|
Rules: []*PolicyRule{
|
||||||
{
|
{
|
||||||
@ -419,7 +418,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
_, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||||
|
@ -1162,9 +1162,10 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki
|
|||||||
First(&accountDNSSettings, idQueryCondition, accountID)
|
First(&accountDNSSettings, idQueryCondition, accountID)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.Errorf(status.NotFound, "dns settings not found")
|
return nil, status.NewAccountNotFoundError(accountID)
|
||||||
}
|
}
|
||||||
return nil, status.Errorf(status.Internal, "failed to get dns settings from store: %v", result.Error)
|
log.WithContext(ctx).Errorf("failed to get dns settings from store: %v", result.Error)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get dns settings from store")
|
||||||
}
|
}
|
||||||
return &accountDNSSettings.DNSSettings, nil
|
return &accountDNSSettings.DNSSettings, nil
|
||||||
}
|
}
|
||||||
@ -1243,8 +1244,8 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren
|
|||||||
var groups []*nbgroup.Group
|
var groups []*nbgroup.Group
|
||||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs)
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to get groups by ID's from the store: %s", result.Error)
|
log.WithContext(ctx).Errorf("failed to get groups by ID's from store: %s", result.Error)
|
||||||
return nil, status.Errorf(status.Internal, "failed to get groups by ID's from the store")
|
return nil, status.Errorf(status.Internal, "failed to get groups by ID's from store")
|
||||||
}
|
}
|
||||||
|
|
||||||
groupsMap := make(map[string]*nbgroup.Group)
|
groupsMap := make(map[string]*nbgroup.Group)
|
||||||
@ -1295,22 +1296,139 @@ func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, a
|
|||||||
|
|
||||||
// GetAccountPolicies retrieves policies for an account.
|
// GetAccountPolicies retrieves policies for an account.
|
||||||
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
|
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
|
||||||
return getRecords[*Policy](s.db.Preload(clause.Associations), lockStrength, accountID)
|
var policies []*Policy
|
||||||
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
Preload(clause.Associations).Find(&policies, accountIDCondition, accountID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get policies from the store: %s", result.Error)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get policies from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return policies, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPolicyByID retrieves a policy by its ID and account ID.
|
// GetPolicyByID retrieves a policy by its ID and account ID.
|
||||||
func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) {
|
func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error) {
|
||||||
return getRecordByID[Policy](s.db.Preload(clause.Associations), lockStrength, policyID, accountID)
|
var policy *Policy
|
||||||
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations).
|
||||||
|
First(&policy, accountAndIDQueryCondition, accountID, policyID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.NewPolicyNotFoundError(policyID)
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).Errorf("failed to get policy from store: %s", err)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get policy from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return policy, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error {
|
||||||
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Create(policy)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to create policy in store: %s", result.Error)
|
||||||
|
return status.Errorf(status.Internal, "failed to create policy in store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SavePolicy saves a policy to the database.
|
||||||
|
func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error {
|
||||||
|
result := s.db.Session(&gorm.Session{FullSaveAssociations: true}).
|
||||||
|
Clauses(clause.Locking{Strength: string(lockStrength)}).Save(policy)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to save policy to the store: %s", err)
|
||||||
|
return status.Errorf(status.Internal, "failed to save policy to store")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error {
|
||||||
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
Delete(&Policy{}, accountAndIDQueryCondition, accountID, policyID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to delete policy from store: %s", err)
|
||||||
|
return status.Errorf(status.Internal, "failed to delete policy from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return status.NewPolicyNotFoundError(policyID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountPostureChecks retrieves posture checks for an account.
|
// GetAccountPostureChecks retrieves posture checks for an account.
|
||||||
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) {
|
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) {
|
||||||
return getRecords[*posture.Checks](s.db, lockStrength, accountID)
|
var postureChecks []*posture.Checks
|
||||||
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&postureChecks, accountIDCondition, accountID)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get posture checks from store: %s", result.Error)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get posture checks from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return postureChecks, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPostureChecksByID retrieves posture checks by their ID and account ID.
|
// GetPostureChecksByID retrieves posture checks by their ID and account ID.
|
||||||
func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) {
|
func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) (*posture.Checks, error) {
|
||||||
return getRecordByID[posture.Checks](s.db, lockStrength, postureCheckID, accountID)
|
var postureCheck *posture.Checks
|
||||||
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
First(&postureCheck, accountAndIDQueryCondition, accountID, postureChecksID)
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.NewPostureChecksNotFoundError(postureChecksID)
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).Errorf("failed to get posture check from store: %s", result.Error)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get posture check from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return postureCheck, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPostureChecksByIDs retrieves posture checks by their IDs and account ID.
|
||||||
|
func (s *SqlStore) GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error) {
|
||||||
|
var postureChecks []*posture.Checks
|
||||||
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&postureChecks, accountAndIDsQueryCondition, accountID, postureChecksIDs)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get posture checks by ID's from store: %s", result.Error)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get posture checks by ID's from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
postureChecksMap := make(map[string]*posture.Checks)
|
||||||
|
for _, postureCheck := range postureChecks {
|
||||||
|
postureChecksMap[postureCheck.ID] = postureCheck
|
||||||
|
}
|
||||||
|
|
||||||
|
return postureChecksMap, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SavePostureChecks saves a posture checks to the database.
|
||||||
|
func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error {
|
||||||
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to save posture checks to store: %s", result.Error)
|
||||||
|
return status.Errorf(status.Internal, "failed to save posture checks to store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletePostureChecks deletes a posture checks from the database.
|
||||||
|
func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error {
|
||||||
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
Delete(&posture.Checks{}, accountAndIDQueryCondition, accountID, postureChecksID)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to delete posture checks from store: %s", result.Error)
|
||||||
|
return status.Errorf(status.Internal, "failed to delete posture checks from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return status.NewPostureChecksNotFoundError(postureChecksID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountRoutes retrieves network routes for an account.
|
// GetAccountRoutes retrieves network routes for an account.
|
||||||
@ -1380,12 +1498,55 @@ func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStren
|
|||||||
|
|
||||||
// GetAccountNameServerGroups retrieves name server groups for an account.
|
// GetAccountNameServerGroups retrieves name server groups for an account.
|
||||||
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
|
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
|
||||||
return getRecords[*nbdns.NameServerGroup](s.db, lockStrength, accountID)
|
var nsGroups []*nbdns.NameServerGroup
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&nsGroups, accountIDCondition, accountID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get name server groups from the store: %s", err)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get name server groups from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nsGroups, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetNameServerGroupByID retrieves a name server group by its ID and account ID.
|
// GetNameServerGroupByID retrieves a name server group by its ID and account ID.
|
||||||
func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nsGroupID string, accountID string) (*nbdns.NameServerGroup, error) {
|
func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
||||||
return getRecordByID[nbdns.NameServerGroup](s.db, lockStrength, nsGroupID, accountID)
|
var nsGroup *nbdns.NameServerGroup
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
First(&nsGroup, accountAndIDQueryCondition, accountID, nsGroupID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.NewNameServerGroupNotFoundError(nsGroupID)
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).Errorf("failed to get name server group from the store: %s", err)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get name server group from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nsGroup, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveNameServerGroup saves a name server group to the database.
|
||||||
|
func (s *SqlStore) SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *nbdns.NameServerGroup) error {
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(nameServerGroup)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to save name server group to the store: %s", err)
|
||||||
|
return status.Errorf(status.Internal, "failed to save name server group to store")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteNameServerGroup deletes a name server group from the database.
|
||||||
|
func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nsGroupID string) error {
|
||||||
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&nbdns.NameServerGroup{}, accountAndIDQueryCondition, accountID, nsGroupID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to delete name server group from the store: %s", err)
|
||||||
|
return status.Errorf(status.Internal, "failed to delete name server group from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return status.NewNameServerGroupNotFoundError(nsGroupID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getRecords retrieves records from the database based on the account ID.
|
// getRecords retrieves records from the database based on the account ID.
|
||||||
@ -1420,3 +1581,19 @@ func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, a
|
|||||||
}
|
}
|
||||||
return &record, nil
|
return &record, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SaveDNSSettings saves the DNS settings to the store.
|
||||||
|
func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error {
|
||||||
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
|
||||||
|
Where(idQueryCondition, accountID).Updates(&AccountDNSSettings{DNSSettings: *settings})
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to save dns settings to store: %v", result.Error)
|
||||||
|
return status.Errorf(status.Internal, "failed to save dns settings to store")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return status.NewAccountNotFoundError(accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -16,6 +16,7 @@ import (
|
|||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
@ -1564,3 +1565,489 @@ func TestSqlStore_GetPeersByIDs(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_GetPostureChecksByID(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
postureChecksID string
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "retrieve existing posture checks",
|
||||||
|
postureChecksID: "csplshq7qv948l48f7t0",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "retrieve non-existing posture checks",
|
||||||
|
postureChecksID: "non-existing",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "retrieve with empty posture checks ID",
|
||||||
|
postureChecksID: "",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
postureChecks, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, tt.postureChecksID)
|
||||||
|
if tt.expectError {
|
||||||
|
require.Error(t, err)
|
||||||
|
sErr, ok := status.FromError(err)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, sErr.Type(), status.NotFound)
|
||||||
|
require.Nil(t, postureChecks)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, postureChecks)
|
||||||
|
require.Equal(t, tt.postureChecksID, postureChecks.ID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_GetPostureChecksByIDs(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
postureCheckIDs []string
|
||||||
|
expectedCount int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "retrieve existing posture checks by existing IDs",
|
||||||
|
postureCheckIDs: []string{"csplshq7qv948l48f7t0", "cspnllq7qv95uq1r4k90"},
|
||||||
|
expectedCount: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty posture check IDs list",
|
||||||
|
postureCheckIDs: []string{},
|
||||||
|
expectedCount: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-existing posture check IDs",
|
||||||
|
postureCheckIDs: []string{"nonexistent1", "nonexistent2"},
|
||||||
|
expectedCount: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed existing and non-existing posture check IDs",
|
||||||
|
postureCheckIDs: []string{"cspnllq7qv95uq1r4k90", "nonexistent"},
|
||||||
|
expectedCount: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
groups, err := store.GetPostureChecksByIDs(context.Background(), LockingStrengthShare, accountID, tt.postureCheckIDs)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, groups, tt.expectedCount)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_SavePostureChecks(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
postureChecks := &posture.Checks{
|
||||||
|
ID: "posture-checks-id",
|
||||||
|
AccountID: accountID,
|
||||||
|
Checks: posture.ChecksDefinition{
|
||||||
|
NBVersionCheck: &posture.NBVersionCheck{
|
||||||
|
MinVersion: "0.31.0",
|
||||||
|
},
|
||||||
|
OSVersionCheck: &posture.OSVersionCheck{
|
||||||
|
Ios: &posture.MinVersionCheck{
|
||||||
|
MinVersion: "13.0.1",
|
||||||
|
},
|
||||||
|
Linux: &posture.MinKernelVersionCheck{
|
||||||
|
MinKernelVersion: "5.3.3-dev",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
GeoLocationCheck: &posture.GeoLocationCheck{
|
||||||
|
Locations: []posture.Location{
|
||||||
|
{
|
||||||
|
CountryCode: "DE",
|
||||||
|
CityName: "Berlin",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Action: posture.CheckActionAllow,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err = store.SavePostureChecks(context.Background(), LockingStrengthUpdate, postureChecks)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
savePostureChecks, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, "posture-checks-id")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, savePostureChecks, postureChecks)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_DeletePostureChecks(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
postureChecksID string
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "delete existing posture checks",
|
||||||
|
postureChecksID: "csplshq7qv948l48f7t0",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "delete non-existing posture checks",
|
||||||
|
postureChecksID: "non-existing-posture-checks-id",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "delete with empty posture checks ID",
|
||||||
|
postureChecksID: "",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err = store.DeletePostureChecks(context.Background(), LockingStrengthUpdate, accountID, tt.postureChecksID)
|
||||||
|
if tt.expectError {
|
||||||
|
require.Error(t, err)
|
||||||
|
sErr, ok := status.FromError(err)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, sErr.Type(), status.NotFound)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
group, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, tt.postureChecksID)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Nil(t, group)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_GetPolicyByID(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
policyID string
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "retrieve existing policy",
|
||||||
|
policyID: "cs1tnh0hhcjnqoiuebf0",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "retrieve non-existing policy checks",
|
||||||
|
policyID: "non-existing",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "retrieve with empty policy ID",
|
||||||
|
policyID: "",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
policy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, tt.policyID)
|
||||||
|
if tt.expectError {
|
||||||
|
require.Error(t, err)
|
||||||
|
sErr, ok := status.FromError(err)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, sErr.Type(), status.NotFound)
|
||||||
|
require.Nil(t, policy)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, policy)
|
||||||
|
require.Equal(t, tt.policyID, policy.ID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_CreatePolicy(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
policy := &Policy{
|
||||||
|
ID: "policy-id",
|
||||||
|
AccountID: accountID,
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{
|
||||||
|
Enabled: true,
|
||||||
|
Sources: []string{"groupA"},
|
||||||
|
Destinations: []string{"groupC"},
|
||||||
|
Bidirectional: true,
|
||||||
|
Action: PolicyTrafficActionAccept,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err = store.CreatePolicy(context.Background(), LockingStrengthUpdate, policy)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
savePolicy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policy.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, savePolicy, policy)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_SavePolicy(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
policyID := "cs1tnh0hhcjnqoiuebf0"
|
||||||
|
|
||||||
|
policy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policyID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
policy.Enabled = false
|
||||||
|
policy.Description = "policy"
|
||||||
|
policy.Rules[0].Sources = []string{"group"}
|
||||||
|
policy.Rules[0].Ports = []string{"80", "443"}
|
||||||
|
err = store.SavePolicy(context.Background(), LockingStrengthUpdate, policy)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
savePolicy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policy.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, savePolicy, policy)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_DeletePolicy(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
policyID := "cs1tnh0hhcjnqoiuebf0"
|
||||||
|
|
||||||
|
err = store.DeletePolicy(context.Background(), LockingStrengthShare, accountID, policyID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
policy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policyID)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Nil(t, policy)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_GetDNSSettings(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
accountID string
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "retrieve existing account dns settings",
|
||||||
|
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "retrieve non-existing account dns settings",
|
||||||
|
accountID: "non-existing",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "retrieve dns settings with empty account ID",
|
||||||
|
accountID: "",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
dnsSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, tt.accountID)
|
||||||
|
if tt.expectError {
|
||||||
|
require.Error(t, err)
|
||||||
|
sErr, ok := status.FromError(err)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, sErr.Type(), status.NotFound)
|
||||||
|
require.Nil(t, dnsSettings)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, dnsSettings)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_SaveDNSSettings(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
dnsSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, accountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
dnsSettings.DisabledManagementGroups = []string{"groupA", "groupB"}
|
||||||
|
err = store.SaveDNSSettings(context.Background(), LockingStrengthUpdate, accountID, dnsSettings)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
saveDNSSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, accountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, saveDNSSettings, dnsSettings)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_GetAccountNameServerGroups(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
accountID string
|
||||||
|
expectedCount int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "retrieve name server groups by existing account ID",
|
||||||
|
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||||
|
expectedCount: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-existing account ID",
|
||||||
|
accountID: "nonexistent",
|
||||||
|
expectedCount: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty account ID",
|
||||||
|
accountID: "",
|
||||||
|
expectedCount: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
peers, err := store.GetAccountNameServerGroups(context.Background(), LockingStrengthShare, tt.accountID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, peers, tt.expectedCount)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_GetNameServerByID(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
nsGroupID string
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "retrieve existing nameserver group",
|
||||||
|
nsGroupID: "csqdelq7qv97ncu7d9t0",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "retrieve non-existing nameserver group",
|
||||||
|
nsGroupID: "non-existing",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "retrieve with empty nameserver group ID",
|
||||||
|
nsGroupID: "",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
nsGroup, err := store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, tt.nsGroupID)
|
||||||
|
if tt.expectError {
|
||||||
|
require.Error(t, err)
|
||||||
|
sErr, ok := status.FromError(err)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, sErr.Type(), status.NotFound)
|
||||||
|
require.Nil(t, nsGroup)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, nsGroup)
|
||||||
|
require.Equal(t, tt.nsGroupID, nsGroup.ID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_SaveNameServerGroup(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
|
||||||
|
nsGroup := &nbdns.NameServerGroup{
|
||||||
|
ID: "ns-group-id",
|
||||||
|
AccountID: accountID,
|
||||||
|
Name: "NS Group",
|
||||||
|
NameServers: []nbdns.NameServer{
|
||||||
|
{
|
||||||
|
IP: netip.MustParseAddr("8.8.8.8"),
|
||||||
|
NSType: 1,
|
||||||
|
Port: 53,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Groups: []string{"groupA"},
|
||||||
|
Primary: true,
|
||||||
|
Enabled: true,
|
||||||
|
SearchDomainsEnabled: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, nsGroup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
saveNSGroup, err := store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, nsGroup.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, saveNSGroup, nsGroup)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSqlStore_DeleteNameServerGroup(t *testing.T) {
|
||||||
|
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||||
|
nsGroupID := "csqdelq7qv97ncu7d9t0"
|
||||||
|
|
||||||
|
err = store.DeleteNameServerGroup(context.Background(), LockingStrengthShare, accountID, nsGroupID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
nsGroup, err := store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, nsGroupID)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Nil(t, nsGroup)
|
||||||
|
}
|
||||||
|
@ -139,3 +139,18 @@ func NewGetAccountError(err error) error {
|
|||||||
func NewGroupNotFoundError(groupID string) error {
|
func NewGroupNotFoundError(groupID string) error {
|
||||||
return Errorf(NotFound, "group: %s not found", groupID)
|
return Errorf(NotFound, "group: %s not found", groupID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewPostureChecksNotFoundError creates a new Error with NotFound type for a missing posture checks
|
||||||
|
func NewPostureChecksNotFoundError(postureChecksID string) error {
|
||||||
|
return Errorf(NotFound, "posture checks: %s not found", postureChecksID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPolicyNotFoundError creates a new Error with NotFound type for a missing policy
|
||||||
|
func NewPolicyNotFoundError(policyID string) error {
|
||||||
|
return Errorf(NotFound, "policy: %s not found", policyID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewNameServerGroupNotFoundError creates a new Error with NotFound type for a missing name server group
|
||||||
|
func NewNameServerGroupNotFoundError(nsGroupID string) error {
|
||||||
|
return Errorf(NotFound, "nameserver group: %s not found", nsGroupID)
|
||||||
|
}
|
||||||
|
@ -59,6 +59,7 @@ type Store interface {
|
|||||||
SaveAccount(ctx context.Context, account *Account) error
|
SaveAccount(ctx context.Context, account *Account) error
|
||||||
DeleteAccount(ctx context.Context, account *Account) error
|
DeleteAccount(ctx context.Context, account *Account) error
|
||||||
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
|
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
|
||||||
|
SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error
|
||||||
|
|
||||||
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
||||||
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
|
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
|
||||||
@ -80,11 +81,17 @@ type Store interface {
|
|||||||
DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error
|
DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error
|
||||||
|
|
||||||
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
|
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
|
||||||
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
|
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error)
|
||||||
|
CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error
|
||||||
|
SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error
|
||||||
|
DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error
|
||||||
|
|
||||||
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||||
GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)
|
GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)
|
||||||
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error)
|
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error)
|
||||||
|
GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error)
|
||||||
|
SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error
|
||||||
|
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error
|
||||||
|
|
||||||
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
|
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
|
||||||
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
|
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
|
||||||
@ -110,6 +117,8 @@ type Store interface {
|
|||||||
|
|
||||||
GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
|
GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
|
||||||
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
|
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
|
||||||
|
SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *dns.NameServerGroup) error
|
||||||
|
DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) error
|
||||||
|
|
||||||
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
|
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
|
||||||
IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error
|
IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error
|
||||||
|
@ -34,4 +34,7 @@ INSERT INTO personal_access_tokens VALUES('9dj38s35-63fb-11ec-90d6-0242ac120003'
|
|||||||
INSERT INTO "groups" VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,'');
|
INSERT INTO "groups" VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,'');
|
||||||
INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup1','api','[]',0,'');
|
INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup1','api','[]',0,'');
|
||||||
INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup2','api','[]',0,'');
|
INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup2','api','[]',0,'');
|
||||||
|
INSERT INTO posture_checks VALUES('csplshq7qv948l48f7t0','NetBird Version > 0.32.0','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"NBVersionCheck":{"MinVersion":"0.31.0"}}');
|
||||||
|
INSERT INTO posture_checks VALUES('cspnllq7qv95uq1r4k90','Allow Berlin and Deny local network 172.16.1.0/24','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"GeoLocationCheck":{"Locations":[{"CountryCode":"DE","CityName":"Berlin"}],"Action":"allow"},"PeerNetworkRangeCheck":{"Action":"deny","Ranges":["172.16.1.0/24"]}}');
|
||||||
|
INSERT INTO name_server_groups VALUES('csqdelq7qv97ncu7d9t0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Google DNS','Google DNS Servers','[{"IP":"8.8.8.8","NSType":1,"Port":53},{"IP":"8.8.4.4","NSType":1,"Port":53}]','["cfefqs706sqkneg59g2g"]',1,'[]',1,0);
|
||||||
INSERT INTO installations VALUES(1,'');
|
INSERT INTO installations VALUES(1,'');
|
||||||
|
@ -1280,8 +1280,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
|
|||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
policy := Policy{
|
policy := &Policy{
|
||||||
ID: "policy",
|
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Rules: []*PolicyRule{
|
Rules: []*PolicyRule{
|
||||||
{
|
{
|
||||||
@ -1293,7 +1292,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
|
_, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||||
|
31
util/net/conn.go
Normal file
31
util/net/conn.go
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
//go:build !ios
|
||||||
|
|
||||||
|
package net
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Conn wraps a net.Conn to override the Close method
|
||||||
|
type Conn struct {
|
||||||
|
net.Conn
|
||||||
|
ID ConnectionID
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection
|
||||||
|
func (c *Conn) Close() error {
|
||||||
|
err := c.Conn.Close()
|
||||||
|
|
||||||
|
dialerCloseHooksMutex.RLock()
|
||||||
|
defer dialerCloseHooksMutex.RUnlock()
|
||||||
|
|
||||||
|
for _, hook := range dialerCloseHooks {
|
||||||
|
if err := hook(c.ID, &c.Conn); err != nil {
|
||||||
|
log.Errorf("Error executing dialer close hook: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
58
util/net/dial.go
Normal file
58
util/net/dial.go
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
//go:build !ios
|
||||||
|
|
||||||
|
package net
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
|
||||||
|
if CustomRoutingDisabled() {
|
||||||
|
return net.DialUDP(network, laddr, raddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
dialer := NewDialer()
|
||||||
|
dialer.LocalAddr = laddr
|
||||||
|
|
||||||
|
conn, err := dialer.Dial(network, raddr.String())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn)
|
||||||
|
if !ok {
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
log.Errorf("Failed to close connection: %v", err)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
return udpConn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
|
||||||
|
if CustomRoutingDisabled() {
|
||||||
|
return net.DialTCP(network, laddr, raddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
dialer := NewDialer()
|
||||||
|
dialer.LocalAddr = laddr
|
||||||
|
|
||||||
|
conn, err := dialer.Dial(network, raddr.String())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn)
|
||||||
|
if !ok {
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
log.Errorf("Failed to close connection: %v", err)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tcpConn, nil
|
||||||
|
}
|
@ -1,25 +0,0 @@
|
|||||||
package net
|
|
||||||
|
|
||||||
import (
|
|
||||||
"syscall"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (d *Dialer) init() {
|
|
||||||
d.Dialer.Control = func(_, _ string, c syscall.RawConn) error {
|
|
||||||
err := c.Control(func(fd uintptr) {
|
|
||||||
androidProtectSocketLock.Lock()
|
|
||||||
f := androidProtectSocket
|
|
||||||
androidProtectSocketLock.Unlock()
|
|
||||||
if f == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ok := f(int32(fd))
|
|
||||||
if !ok {
|
|
||||||
log.Errorf("failed to protect socket: %d", fd)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
@ -81,28 +81,6 @@ func (d *Dialer) Dial(network, address string) (net.Conn, error) {
|
|||||||
return d.DialContext(context.Background(), network, address)
|
return d.DialContext(context.Background(), network, address)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Conn wraps a net.Conn to override the Close method
|
|
||||||
type Conn struct {
|
|
||||||
net.Conn
|
|
||||||
ID ConnectionID
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection
|
|
||||||
func (c *Conn) Close() error {
|
|
||||||
err := c.Conn.Close()
|
|
||||||
|
|
||||||
dialerCloseHooksMutex.RLock()
|
|
||||||
defer dialerCloseHooksMutex.RUnlock()
|
|
||||||
|
|
||||||
for _, hook := range dialerCloseHooks {
|
|
||||||
if err := hook(c.ID, &c.Conn); err != nil {
|
|
||||||
log.Errorf("Error executing dialer close hook: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func callDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error {
|
func callDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error {
|
||||||
host, _, err := net.SplitHostPort(address)
|
host, _, err := net.SplitHostPort(address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -127,51 +105,3 @@ func callDialerHooks(ctx context.Context, connID ConnectionID, address string, r
|
|||||||
|
|
||||||
return result.ErrorOrNil()
|
return result.ErrorOrNil()
|
||||||
}
|
}
|
||||||
|
|
||||||
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
|
|
||||||
if CustomRoutingDisabled() {
|
|
||||||
return net.DialUDP(network, laddr, raddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
dialer := NewDialer()
|
|
||||||
dialer.LocalAddr = laddr
|
|
||||||
|
|
||||||
conn, err := dialer.Dial(network, raddr.String())
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn)
|
|
||||||
if !ok {
|
|
||||||
if err := conn.Close(); err != nil {
|
|
||||||
log.Errorf("Failed to close connection: %v", err)
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
return udpConn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
|
|
||||||
if CustomRoutingDisabled() {
|
|
||||||
return net.DialTCP(network, laddr, raddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
dialer := NewDialer()
|
|
||||||
dialer.LocalAddr = laddr
|
|
||||||
|
|
||||||
conn, err := dialer.Dial(network, raddr.String())
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn)
|
|
||||||
if !ok {
|
|
||||||
if err := conn.Close(); err != nil {
|
|
||||||
log.Errorf("Failed to close connection: %v", err)
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
return tcpConn, nil
|
|
||||||
}
|
|
5
util/net/dialer_init_android.go
Normal file
5
util/net/dialer_init_android.go
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
package net
|
||||||
|
|
||||||
|
func (d *Dialer) init() {
|
||||||
|
d.Dialer.Control = ControlProtectSocket
|
||||||
|
}
|
@ -7,6 +7,6 @@ import "syscall"
|
|||||||
// init configures the net.Dialer Control function to set the fwmark on the socket
|
// init configures the net.Dialer Control function to set the fwmark on the socket
|
||||||
func (d *Dialer) init() {
|
func (d *Dialer) init() {
|
||||||
d.Dialer.Control = func(_, _ string, c syscall.RawConn) error {
|
d.Dialer.Control = func(_, _ string, c syscall.RawConn) error {
|
||||||
return SetRawSocketMark(c)
|
return setRawSocketMark(c)
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -3,4 +3,5 @@
|
|||||||
package net
|
package net
|
||||||
|
|
||||||
func (d *Dialer) init() {
|
func (d *Dialer) init() {
|
||||||
|
// implemented on Linux and Android only
|
||||||
}
|
}
|
29
util/net/env.go
Normal file
29
util/net/env.go
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
package net
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
|
||||||
|
envSkipSocketMark = "NB_SKIP_SOCKET_MARK"
|
||||||
|
)
|
||||||
|
|
||||||
|
func CustomRoutingDisabled() bool {
|
||||||
|
if netstack.IsEnabled() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return os.Getenv(envDisableCustomRouting) == "true"
|
||||||
|
}
|
||||||
|
|
||||||
|
func SkipSocketMark() bool {
|
||||||
|
if skipSocketMark := os.Getenv(envSkipSocketMark); skipSocketMark == "true" {
|
||||||
|
log.Infof("%s is set to true, skipping SO_MARK", envSkipSocketMark)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
37
util/net/listen.go
Normal file
37
util/net/listen.go
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
//go:build !ios
|
||||||
|
|
||||||
|
package net
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/pion/transport/v3"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ListenUDP listens on the network address and returns a transport.UDPConn
|
||||||
|
// which includes support for write and close hooks.
|
||||||
|
func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) {
|
||||||
|
if CustomRoutingDisabled() {
|
||||||
|
return net.ListenUDP(network, laddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("listen UDP: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
packetConn := conn.(*PacketConn)
|
||||||
|
udpConn, ok := packetConn.PacketConn.(*net.UDPConn)
|
||||||
|
if !ok {
|
||||||
|
if err := packetConn.Close(); err != nil {
|
||||||
|
log.Errorf("Failed to close connection: %v", err)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil
|
||||||
|
}
|
@ -1,26 +0,0 @@
|
|||||||
package net
|
|
||||||
|
|
||||||
import (
|
|
||||||
"syscall"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
// init configures the net.ListenerConfig Control function to set the fwmark on the socket
|
|
||||||
func (l *ListenerConfig) init() {
|
|
||||||
l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error {
|
|
||||||
err := c.Control(func(fd uintptr) {
|
|
||||||
androidProtectSocketLock.Lock()
|
|
||||||
f := androidProtectSocket
|
|
||||||
androidProtectSocketLock.Unlock()
|
|
||||||
if f == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ok := f(int32(fd))
|
|
||||||
if !ok {
|
|
||||||
log.Errorf("failed to protect listener socket: %d", fd)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
6
util/net/listener_init_android.go
Normal file
6
util/net/listener_init_android.go
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
package net
|
||||||
|
|
||||||
|
// init configures the net.ListenerConfig Control function to set the fwmark on the socket
|
||||||
|
func (l *ListenerConfig) init() {
|
||||||
|
l.ListenConfig.Control = ControlProtectSocket
|
||||||
|
}
|
@ -9,6 +9,6 @@ import (
|
|||||||
// init configures the net.ListenerConfig Control function to set the fwmark on the socket
|
// init configures the net.ListenerConfig Control function to set the fwmark on the socket
|
||||||
func (l *ListenerConfig) init() {
|
func (l *ListenerConfig) init() {
|
||||||
l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error {
|
l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error {
|
||||||
return SetRawSocketMark(c)
|
return setRawSocketMark(c)
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -3,4 +3,5 @@
|
|||||||
package net
|
package net
|
||||||
|
|
||||||
func (l *ListenerConfig) init() {
|
func (l *ListenerConfig) init() {
|
||||||
|
// implemented on Linux and Android only
|
||||||
}
|
}
|
@ -8,7 +8,6 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/pion/transport/v3"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -146,27 +145,3 @@ func closeConn(id ConnectionID, conn net.PacketConn) error {
|
|||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListenUDP listens on the network address and returns a transport.UDPConn
|
|
||||||
// which includes support for write and close hooks.
|
|
||||||
func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) {
|
|
||||||
if CustomRoutingDisabled() {
|
|
||||||
return net.ListenUDP(network, laddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String())
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("listen UDP: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
packetConn := conn.(*PacketConn)
|
|
||||||
udpConn, ok := packetConn.PacketConn.(*net.UDPConn)
|
|
||||||
if !ok {
|
|
||||||
if err := packetConn.Close(); err != nil {
|
|
||||||
log.Errorf("Failed to close connection: %v", err)
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil
|
|
||||||
}
|
|
@ -2,9 +2,6 @@ package net
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
"os"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
@ -16,8 +13,6 @@ const (
|
|||||||
PreroutingFwmarkRedirected = 0x1BD01
|
PreroutingFwmarkRedirected = 0x1BD01
|
||||||
PreroutingFwmarkMasquerade = 0x1BD11
|
PreroutingFwmarkMasquerade = 0x1BD11
|
||||||
PreroutingFwmarkMasqueradeReturn = 0x1BD12
|
PreroutingFwmarkMasqueradeReturn = 0x1BD12
|
||||||
|
|
||||||
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConnectionID provides a globally unique identifier for network connections.
|
// ConnectionID provides a globally unique identifier for network connections.
|
||||||
@ -31,10 +26,3 @@ type RemoveHookFunc func(connID ConnectionID) error
|
|||||||
func GenerateConnID() ConnectionID {
|
func GenerateConnID() ConnectionID {
|
||||||
return ConnectionID(uuid.NewString())
|
return ConnectionID(uuid.NewString())
|
||||||
}
|
}
|
||||||
|
|
||||||
func CustomRoutingDisabled() bool {
|
|
||||||
if netstack.IsEnabled() {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return os.Getenv(envDisableCustomRouting) == "true"
|
|
||||||
}
|
|
||||||
|
@ -4,29 +4,42 @@ package net
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
const EnvSkipSocketMark = "NB_SKIP_SOCKET_MARK"
|
|
||||||
|
|
||||||
// SetSocketMark sets the SO_MARK option on the given socket connection
|
// SetSocketMark sets the SO_MARK option on the given socket connection
|
||||||
func SetSocketMark(conn syscall.Conn) error {
|
func SetSocketMark(conn syscall.Conn) error {
|
||||||
|
if isSocketMarkDisabled() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
sysconn, err := conn.SyscallConn()
|
sysconn, err := conn.SyscallConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("get raw conn: %w", err)
|
return fmt.Errorf("get raw conn: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return SetRawSocketMark(sysconn)
|
return setRawSocketMark(sysconn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func SetRawSocketMark(conn syscall.RawConn) error {
|
// SetSocketOpt sets the SO_MARK option on the given file descriptor
|
||||||
|
func SetSocketOpt(fd int) error {
|
||||||
|
if isSocketMarkDisabled() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return setSocketOptInt(fd)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setRawSocketMark(conn syscall.RawConn) error {
|
||||||
var setErr error
|
var setErr error
|
||||||
|
|
||||||
err := conn.Control(func(fd uintptr) {
|
err := conn.Control(func(fd uintptr) {
|
||||||
setErr = SetSocketOpt(int(fd))
|
if isSocketMarkDisabled() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
setErr = setSocketOptInt(int(fd))
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("control: %w", err)
|
return fmt.Errorf("control: %w", err)
|
||||||
@ -39,17 +52,18 @@ func SetRawSocketMark(conn syscall.RawConn) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func SetSocketOpt(fd int) error {
|
func setSocketOptInt(fd int) error {
|
||||||
if CustomRoutingDisabled() {
|
|
||||||
log.Infof("Custom routing is disabled, skipping SO_MARK")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for the new environment variable
|
|
||||||
if skipSocketMark := os.Getenv(EnvSkipSocketMark); skipSocketMark == "true" {
|
|
||||||
log.Info("NB_SKIP_SOCKET_MARK is set to true, skipping SO_MARK")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark)
|
return syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isSocketMarkDisabled() bool {
|
||||||
|
if CustomRoutingDisabled() {
|
||||||
|
log.Infof("Custom routing is disabled, skipping SO_MARK")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if SkipSocketMark() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
@ -1,14 +1,42 @@
|
|||||||
package net
|
package net
|
||||||
|
|
||||||
import "sync"
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"syscall"
|
||||||
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
androidProtectSocketLock sync.Mutex
|
androidProtectSocketLock sync.Mutex
|
||||||
androidProtectSocket func(fd int32) bool
|
androidProtectSocket func(fd int32) bool
|
||||||
)
|
)
|
||||||
|
|
||||||
func SetAndroidProtectSocketFn(f func(fd int32) bool) {
|
func SetAndroidProtectSocketFn(fn func(fd int32) bool) {
|
||||||
androidProtectSocketLock.Lock()
|
androidProtectSocketLock.Lock()
|
||||||
androidProtectSocket = f
|
androidProtectSocket = fn
|
||||||
androidProtectSocketLock.Unlock()
|
androidProtectSocketLock.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ControlProtectSocket is a Control function that sets the fwmark on the socket
|
||||||
|
func ControlProtectSocket(_, _ string, c syscall.RawConn) error {
|
||||||
|
var aErr error
|
||||||
|
err := c.Control(func(fd uintptr) {
|
||||||
|
androidProtectSocketLock.Lock()
|
||||||
|
defer androidProtectSocketLock.Unlock()
|
||||||
|
|
||||||
|
if androidProtectSocket == nil {
|
||||||
|
aErr = fmt.Errorf("socket protection function not set")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !androidProtectSocket(int32(fd)) {
|
||||||
|
aErr = fmt.Errorf("failed to protect socket via Android")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return aErr
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user