mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-15 09:42:47 +02:00
Merge branch 'groups-get-account-refactoring' into posturechecks-get-account-refactoring
This commit is contained in:
@ -83,9 +83,11 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// persist early to ensure cleanup of chains
|
// persist early to ensure cleanup of chains
|
||||||
|
go func() {
|
||||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||||
log.Errorf("failed to persist state: %v", err)
|
log.Errorf("failed to persist state: %v", err)
|
||||||
}
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -99,9 +99,11 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// persist early
|
// persist early
|
||||||
|
go func() {
|
||||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||||
log.Errorf("failed to persist state: %v", err)
|
log.Errorf("failed to persist state: %v", err)
|
||||||
}
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -164,7 +164,7 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
err = util.WriteJsonWithRestrictedPermission(input.ConfigPath, cfg)
|
err = util.WriteJsonWithRestrictedPermission(context.Background(), input.ConfigPath, cfg)
|
||||||
return cfg, err
|
return cfg, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -185,7 +185,7 @@ func CreateInMemoryConfig(input ConfigInput) (*Config, error) {
|
|||||||
|
|
||||||
// WriteOutConfig write put the prepared config to the given path
|
// WriteOutConfig write put the prepared config to the given path
|
||||||
func WriteOutConfig(path string, config *Config) error {
|
func WriteOutConfig(path string, config *Config) error {
|
||||||
return util.WriteJson(path, config)
|
return util.WriteJson(context.Background(), path, config)
|
||||||
}
|
}
|
||||||
|
|
||||||
// createNewConfig creates a new config generating a new Wireguard key and saving to file
|
// createNewConfig creates a new config generating a new Wireguard key and saving to file
|
||||||
@ -215,7 +215,7 @@ func update(input ConfigInput) (*Config, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if updated {
|
if updated {
|
||||||
if err := util.WriteJson(input.ConfigPath, config); err != nil {
|
if err := util.WriteJson(context.Background(), input.ConfigPath, config); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -326,9 +326,13 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
|||||||
// persist dns state right away
|
// persist dns state right away
|
||||||
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
|
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
// don't block
|
||||||
|
go func() {
|
||||||
if err := s.stateManager.PersistState(ctx); err != nil {
|
if err := s.stateManager.PersistState(ctx); err != nil {
|
||||||
log.Errorf("Failed to persist dns state: %v", err)
|
log.Errorf("Failed to persist dns state: %v", err)
|
||||||
}
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
if s.searchDomainNotifier != nil {
|
if s.searchDomainNotifier != nil {
|
||||||
s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains())
|
s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains())
|
||||||
|
@ -11,6 +11,7 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
"runtime"
|
"runtime"
|
||||||
"slices"
|
"slices"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
@ -38,7 +39,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
|
||||||
|
|
||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
@ -641,6 +641,10 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||||
|
if e.wgInterface == nil {
|
||||||
|
return errors.New("wireguard interface is not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
if e.wgInterface.Address().String() != conf.Address {
|
if e.wgInterface.Address().String() != conf.Address {
|
||||||
oldAddr := e.wgInterface.Address().String()
|
oldAddr := e.wgInterface.Address().String()
|
||||||
log.Debugf("updating peer address from %s to %s", oldAddr, conf.Address)
|
log.Debugf("updating peer address from %s to %s", oldAddr, conf.Address)
|
||||||
@ -1481,6 +1485,17 @@ func (e *Engine) stopDNSServer() {
|
|||||||
|
|
||||||
// isChecksEqual checks if two slices of checks are equal.
|
// isChecksEqual checks if two slices of checks are equal.
|
||||||
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
|
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
|
||||||
|
for _, check := range checks {
|
||||||
|
sort.Slice(check.Files, func(i, j int) bool {
|
||||||
|
return check.Files[i] < check.Files[j]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
for _, oCheck := range oChecks {
|
||||||
|
sort.Slice(oCheck.Files, func(i, j int) bool {
|
||||||
|
return oCheck.Files[i] < oCheck.Files[j]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
|
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
|
||||||
return slices.Equal(checks.Files, oChecks.Files)
|
return slices.Equal(checks.Files, oChecks.Files)
|
||||||
})
|
})
|
||||||
|
@ -1006,6 +1006,99 @@ func Test_ParseNATExternalIPMappings(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_CheckFilesEqual(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
inputChecks1 []*mgmtProto.Checks
|
||||||
|
inputChecks2 []*mgmtProto.Checks
|
||||||
|
expectedBool bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Equal Files In Equal Order Should Return True",
|
||||||
|
inputChecks1: []*mgmtProto.Checks{
|
||||||
|
{
|
||||||
|
Files: []string{
|
||||||
|
"testfile1",
|
||||||
|
"testfile2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
inputChecks2: []*mgmtProto.Checks{
|
||||||
|
{
|
||||||
|
Files: []string{
|
||||||
|
"testfile1",
|
||||||
|
"testfile2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedBool: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Equal Files In Reverse Order Should Return True",
|
||||||
|
inputChecks1: []*mgmtProto.Checks{
|
||||||
|
{
|
||||||
|
Files: []string{
|
||||||
|
"testfile1",
|
||||||
|
"testfile2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
inputChecks2: []*mgmtProto.Checks{
|
||||||
|
{
|
||||||
|
Files: []string{
|
||||||
|
"testfile2",
|
||||||
|
"testfile1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedBool: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Unequal Files Should Return False",
|
||||||
|
inputChecks1: []*mgmtProto.Checks{
|
||||||
|
{
|
||||||
|
Files: []string{
|
||||||
|
"testfile1",
|
||||||
|
"testfile2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
inputChecks2: []*mgmtProto.Checks{
|
||||||
|
{
|
||||||
|
Files: []string{
|
||||||
|
"testfile1",
|
||||||
|
"testfile3",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedBool: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Compared With Empty Should Return False",
|
||||||
|
inputChecks1: []*mgmtProto.Checks{
|
||||||
|
{
|
||||||
|
Files: []string{
|
||||||
|
"testfile1",
|
||||||
|
"testfile2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
inputChecks2: []*mgmtProto.Checks{
|
||||||
|
{
|
||||||
|
Files: []string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedBool: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, testCase := range testCases {
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
result := isChecksEqual(testCase.inputChecks1, testCase.inputChecks2)
|
||||||
|
assert.Equal(t, testCase.expectedBool, result, "result should match expected bool")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
|
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
|
||||||
key, err := wgtypes.GeneratePrivateKey()
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@ -227,6 +228,64 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
currentRoute: "route1",
|
currentRoute: "route1",
|
||||||
expectedRouteID: "route1",
|
expectedRouteID: "route1",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "relayed routes with latency 0 should maintain previous choice",
|
||||||
|
statuses: map[route.ID]routerPeerStatus{
|
||||||
|
"route1": {
|
||||||
|
connected: true,
|
||||||
|
relayed: true,
|
||||||
|
latency: 0 * time.Millisecond,
|
||||||
|
},
|
||||||
|
"route2": {
|
||||||
|
connected: true,
|
||||||
|
relayed: true,
|
||||||
|
latency: 0 * time.Millisecond,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
existingRoutes: map[route.ID]*route.Route{
|
||||||
|
"route1": {
|
||||||
|
ID: "route1",
|
||||||
|
Metric: route.MaxMetric,
|
||||||
|
Peer: "peer1",
|
||||||
|
},
|
||||||
|
"route2": {
|
||||||
|
ID: "route2",
|
||||||
|
Metric: route.MaxMetric,
|
||||||
|
Peer: "peer2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
currentRoute: "route1",
|
||||||
|
expectedRouteID: "route1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "p2p routes with latency 0 should maintain previous choice",
|
||||||
|
statuses: map[route.ID]routerPeerStatus{
|
||||||
|
"route1": {
|
||||||
|
connected: true,
|
||||||
|
relayed: false,
|
||||||
|
latency: 0 * time.Millisecond,
|
||||||
|
},
|
||||||
|
"route2": {
|
||||||
|
connected: true,
|
||||||
|
relayed: false,
|
||||||
|
latency: 0 * time.Millisecond,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
existingRoutes: map[route.ID]*route.Route{
|
||||||
|
"route1": {
|
||||||
|
ID: "route1",
|
||||||
|
Metric: route.MaxMetric,
|
||||||
|
Peer: "peer1",
|
||||||
|
},
|
||||||
|
"route2": {
|
||||||
|
ID: "route2",
|
||||||
|
Metric: route.MaxMetric,
|
||||||
|
Peer: "peer2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
currentRoute: "route1",
|
||||||
|
expectedRouteID: "route1",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "current route with bad score should be changed to route with better score",
|
name: "current route with bad score should be changed to route with better score",
|
||||||
statuses: map[route.ID]routerPeerStatus{
|
statuses: map[route.ID]routerPeerStatus{
|
||||||
@ -287,6 +346,45 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// fill the test data with random routes
|
||||||
|
for _, tc := range testCases {
|
||||||
|
for i := 0; i < 50; i++ {
|
||||||
|
dummyRoute := &route.Route{
|
||||||
|
ID: route.ID(fmt.Sprintf("dummy_p1_%d", i)),
|
||||||
|
Metric: route.MinMetric,
|
||||||
|
Peer: fmt.Sprintf("dummy_p1_%d", i),
|
||||||
|
}
|
||||||
|
tc.existingRoutes[dummyRoute.ID] = dummyRoute
|
||||||
|
}
|
||||||
|
for i := 0; i < 50; i++ {
|
||||||
|
dummyRoute := &route.Route{
|
||||||
|
ID: route.ID(fmt.Sprintf("dummy_p2_%d", i)),
|
||||||
|
Metric: route.MinMetric,
|
||||||
|
Peer: fmt.Sprintf("dummy_p1_%d", i),
|
||||||
|
}
|
||||||
|
tc.existingRoutes[dummyRoute.ID] = dummyRoute
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < 50; i++ {
|
||||||
|
id := route.ID(fmt.Sprintf("dummy_p1_%d", i))
|
||||||
|
dummyStatus := routerPeerStatus{
|
||||||
|
connected: false,
|
||||||
|
relayed: true,
|
||||||
|
latency: 0,
|
||||||
|
}
|
||||||
|
tc.statuses[id] = dummyStatus
|
||||||
|
}
|
||||||
|
for i := 0; i < 50; i++ {
|
||||||
|
id := route.ID(fmt.Sprintf("dummy_p2_%d", i))
|
||||||
|
dummyStatus := routerPeerStatus{
|
||||||
|
connected: false,
|
||||||
|
relayed: true,
|
||||||
|
latency: 0,
|
||||||
|
}
|
||||||
|
tc.statuses[id] = dummyStatus
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
currentRoute := &route.Route{
|
currentRoute := &route.Route{
|
||||||
|
@ -16,6 +16,7 @@ import (
|
|||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
// State interface defines the methods that all state types must implement
|
// State interface defines the methods that all state types must implement
|
||||||
@ -178,25 +179,14 @@ func (m *Manager) PersistState(ctx context.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
done := make(chan error, 1)
|
done := make(chan error, 1)
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
go func() {
|
go func() {
|
||||||
data, err := json.MarshalIndent(m.states, "", " ")
|
done <- util.WriteJsonWithRestrictedPermission(ctx, m.filePath, m.states)
|
||||||
if err != nil {
|
|
||||||
done <- fmt.Errorf("marshal states: %w", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// nolint:gosec
|
|
||||||
if err := os.WriteFile(m.filePath, data, 0640); err != nil {
|
|
||||||
done <- fmt.Errorf("write state file: %w", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
done <- nil
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@ -208,7 +198,7 @@ func (m *Manager) PersistState(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("persisted shutdown states: %v", maps.Keys(m.dirty))
|
log.Debugf("persisted shutdown states: %v, took %v", maps.Keys(m.dirty), time.Since(start))
|
||||||
|
|
||||||
clear(m.dirty)
|
clear(m.dirty)
|
||||||
|
|
||||||
|
@ -4,32 +4,20 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetDefaultStatePath returns the path to the state file based on the operating system
|
// GetDefaultStatePath returns the path to the state file based on the operating system
|
||||||
// It returns an empty string if the path cannot be determined. It also creates the directory if it does not exist.
|
// It returns an empty string if the path cannot be determined.
|
||||||
func GetDefaultStatePath() string {
|
func GetDefaultStatePath() string {
|
||||||
var path string
|
|
||||||
|
|
||||||
switch runtime.GOOS {
|
switch runtime.GOOS {
|
||||||
case "windows":
|
case "windows":
|
||||||
path = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json")
|
return filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json")
|
||||||
case "darwin", "linux":
|
case "darwin", "linux":
|
||||||
path = "/var/lib/netbird/state.json"
|
return "/var/lib/netbird/state.json"
|
||||||
case "freebsd", "openbsd", "netbsd", "dragonfly":
|
case "freebsd", "openbsd", "netbsd", "dragonfly":
|
||||||
path = "/var/db/netbird/state.json"
|
return "/var/db/netbird/state.json"
|
||||||
// ios/android don't need state
|
|
||||||
default:
|
|
||||||
return ""
|
|
||||||
}
|
}
|
||||||
|
|
||||||
dir := filepath.Dir(path)
|
|
||||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
|
||||||
log.Errorf("Error creating directory %s: %v. Continuing without state support.", dir, err)
|
|
||||||
return ""
|
return ""
|
||||||
}
|
|
||||||
|
|
||||||
return path
|
|
||||||
}
|
}
|
||||||
|
@ -1186,6 +1186,15 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error {
|
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error {
|
||||||
|
|
||||||
|
if newSettings.PeerInactivityExpirationEnabled {
|
||||||
|
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
|
||||||
|
oldSettings.PeerInactivityExpiration = newSettings.PeerInactivityExpiration
|
||||||
|
|
||||||
|
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil)
|
||||||
|
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled {
|
if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled {
|
||||||
event := activity.AccountPeerInactivityExpirationEnabled
|
event := activity.AccountPeerInactivityExpirationEnabled
|
||||||
if !newSettings.PeerInactivityExpirationEnabled {
|
if !newSettings.PeerInactivityExpirationEnabled {
|
||||||
@ -1196,10 +1205,6 @@ func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.
|
|||||||
}
|
}
|
||||||
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
|
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
|
|
||||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil)
|
|
||||||
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -2323,7 +2328,7 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account
|
|||||||
|
|
||||||
err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, account)
|
err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
|
log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -2339,6 +2344,9 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st
|
|||||||
unlock := am.Store.AcquireReadLockByUID(ctx, accountID)
|
unlock := am.Store.AcquireReadLockByUID(ctx, accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
|
unlockPeer := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
|
||||||
|
defer unlockPeer()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
account, err := am.Store.GetAccount(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -223,7 +223,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
|
|||||||
// It is recommended to call it with locking FileStore.mux
|
// It is recommended to call it with locking FileStore.mux
|
||||||
func (s *FileStore) persist(ctx context.Context, file string) error {
|
func (s *FileStore) persist(ctx context.Context, file string) error {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
err := util.WriteJson(file, s)
|
err := util.WriteJson(context.Background(), file, s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -6,11 +6,12 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
|
||||||
"github.com/netbirdio/netbird/route"
|
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
@ -27,11 +28,6 @@ func (e *GroupLinkError) Error() string {
|
|||||||
|
|
||||||
// CheckGroupPermissions validates if a user has the necessary permissions to view groups
|
// CheckGroupPermissions validates if a user has the necessary permissions to view groups
|
||||||
func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error {
|
func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error {
|
||||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -41,7 +37,7 @@ func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, acco
|
|||||||
return status.NewUserNotPartOfAccountError()
|
return status.NewUserNotPartOfAccountError()
|
||||||
}
|
}
|
||||||
|
|
||||||
if user.IsRegularUser() && settings.RegularUsersViewBlocked {
|
if user.IsRegularUser() {
|
||||||
return status.NewAdminPermissionError()
|
return status.NewAdminPermissionError()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -215,48 +211,9 @@ func difference(a, b []string) []string {
|
|||||||
|
|
||||||
// DeleteGroup object of the peers.
|
// DeleteGroup object of the peers.
|
||||||
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error {
|
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error {
|
||||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
if err != nil {
|
defer unlock()
|
||||||
return err
|
return am.DeleteGroups(ctx, accountID, userID, []string{groupID})
|
||||||
}
|
|
||||||
|
|
||||||
if user.AccountID != accountID {
|
|
||||||
return status.NewUserNotPartOfAccountError()
|
|
||||||
}
|
|
||||||
|
|
||||||
if user.IsRegularUser() {
|
|
||||||
return status.NewAdminPermissionError()
|
|
||||||
}
|
|
||||||
|
|
||||||
var group *nbgroup.Group
|
|
||||||
|
|
||||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
|
||||||
group, err = transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if group.IsGroupAll() {
|
|
||||||
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = validateDeleteGroup(ctx, transaction, group, userID); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return transaction.DeleteGroup(ctx, LockingStrengthUpdate, accountID, groupID)
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
am.StoreEvent(ctx, userID, groupID, accountID, activity.GroupDeleted, group.EventMeta())
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteGroups deletes groups from an account.
|
// DeleteGroups deletes groups from an account.
|
||||||
@ -285,13 +242,14 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
|
|||||||
|
|
||||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
for _, groupID := range groupIDs {
|
for _, groupID := range groupIDs {
|
||||||
group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
|
group, err := transaction.GetGroupByID(ctx, LockingStrengthUpdate, accountID, groupID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
allErrors = errors.Join(allErrors, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := validateDeleteGroup(ctx, transaction, group, userID); err != nil {
|
if err := validateDeleteGroup(ctx, transaction, group, userID); err != nil {
|
||||||
allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err))
|
allErrors = errors.Join(allErrors, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -318,12 +276,15 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
|
|||||||
|
|
||||||
// GroupAddPeer appends peer to the group
|
// GroupAddPeer appends peer to the group
|
||||||
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
|
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||||
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
|
defer unlock()
|
||||||
|
|
||||||
var group *nbgroup.Group
|
var group *nbgroup.Group
|
||||||
var updateAccountPeers bool
|
var updateAccountPeers bool
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
|
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -356,12 +317,15 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
|
|||||||
|
|
||||||
// GroupDeletePeer removes peer from the group
|
// GroupDeletePeer removes peer from the group
|
||||||
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
|
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||||
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
|
defer unlock()
|
||||||
|
|
||||||
var group *nbgroup.Group
|
var group *nbgroup.Group
|
||||||
var updateAccountPeers bool
|
var updateAccountPeers bool
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
|
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -430,13 +394,17 @@ func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup.
|
|||||||
if group.Issued == nbgroup.GroupIssuedIntegration {
|
if group.Issued == nbgroup.GroupIssuedIntegration {
|
||||||
executingUser, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
executingUser, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return status.Errorf(status.NotFound, "user not found")
|
return err
|
||||||
}
|
}
|
||||||
if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser {
|
if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser {
|
||||||
return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group")
|
return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if group.IsGroupAll() {
|
||||||
|
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
|
||||||
|
}
|
||||||
|
|
||||||
if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked {
|
if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked {
|
||||||
return &GroupLinkError{"route", string(linkedRoute.NetID)}
|
return &GroupLinkError{"route", string(linkedRoute.NetID)}
|
||||||
}
|
}
|
||||||
|
@ -208,7 +208,7 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "delete non-existent group",
|
name: "delete non-existent group",
|
||||||
groupIDs: []string{"non-existent-group"},
|
groupIDs: []string{"non-existent-group"},
|
||||||
expectedDeleted: []string{"non-existent-group"},
|
expectedReasons: []string{"group: non-existent-group not found"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "delete multiple groups with mixed results",
|
name: "delete multiple groups with mixed results",
|
||||||
|
@ -521,19 +521,6 @@ components:
|
|||||||
SetupKeyRequest:
|
SetupKeyRequest:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
name:
|
|
||||||
description: Setup Key name
|
|
||||||
type: string
|
|
||||||
example: Default key
|
|
||||||
type:
|
|
||||||
description: Setup key type, one-off for single time usage and reusable
|
|
||||||
type: string
|
|
||||||
example: reusable
|
|
||||||
expires_in:
|
|
||||||
description: Expiration time in seconds, 0 will mean the key never expires
|
|
||||||
type: integer
|
|
||||||
minimum: 0
|
|
||||||
example: 86400
|
|
||||||
revoked:
|
revoked:
|
||||||
description: Setup key revocation status
|
description: Setup key revocation status
|
||||||
type: boolean
|
type: boolean
|
||||||
@ -544,21 +531,9 @@ components:
|
|||||||
items:
|
items:
|
||||||
type: string
|
type: string
|
||||||
example: "ch8i4ug6lnn4g9hqv7m0"
|
example: "ch8i4ug6lnn4g9hqv7m0"
|
||||||
usage_limit:
|
|
||||||
description: A number of times this key can be used. The value of 0 indicates the unlimited usage.
|
|
||||||
type: integer
|
|
||||||
example: 0
|
|
||||||
ephemeral:
|
|
||||||
description: Indicate that the peer will be ephemeral or not
|
|
||||||
type: boolean
|
|
||||||
example: true
|
|
||||||
required:
|
required:
|
||||||
- name
|
|
||||||
- type
|
|
||||||
- expires_in
|
|
||||||
- revoked
|
- revoked
|
||||||
- auto_groups
|
- auto_groups
|
||||||
- usage_limit
|
|
||||||
CreateSetupKeyRequest:
|
CreateSetupKeyRequest:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
|
@ -1098,23 +1098,8 @@ type SetupKeyRequest struct {
|
|||||||
// AutoGroups List of group IDs to auto-assign to peers registered with this key
|
// AutoGroups List of group IDs to auto-assign to peers registered with this key
|
||||||
AutoGroups []string `json:"auto_groups"`
|
AutoGroups []string `json:"auto_groups"`
|
||||||
|
|
||||||
// Ephemeral Indicate that the peer will be ephemeral or not
|
|
||||||
Ephemeral *bool `json:"ephemeral,omitempty"`
|
|
||||||
|
|
||||||
// ExpiresIn Expiration time in seconds, 0 will mean the key never expires
|
|
||||||
ExpiresIn int `json:"expires_in"`
|
|
||||||
|
|
||||||
// Name Setup Key name
|
|
||||||
Name string `json:"name"`
|
|
||||||
|
|
||||||
// Revoked Setup key revocation status
|
// Revoked Setup key revocation status
|
||||||
Revoked bool `json:"revoked"`
|
Revoked bool `json:"revoked"`
|
||||||
|
|
||||||
// Type Setup key type, one-off for single time usage and reusable
|
|
||||||
Type string `json:"type"`
|
|
||||||
|
|
||||||
// UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage.
|
|
||||||
UsageLimit int `json:"usage_limit"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// User defines model for User.
|
// User defines model for User.
|
||||||
|
@ -184,14 +184,26 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
dnsDomain := h.accountManager.GetDNSDomain()
|
dnsDomain := h.accountManager.GetDNSDomain()
|
||||||
|
|
||||||
respBody := make([]*api.PeerBatch, 0, len(account.Peers))
|
peers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
|
||||||
for _, peer := range account.Peers {
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), err, w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
groupsMap := map[string]*nbgroup.Group{}
|
||||||
|
groups, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
|
||||||
|
for _, group := range groups {
|
||||||
|
groupsMap[group.ID] = group
|
||||||
|
}
|
||||||
|
|
||||||
|
respBody := make([]*api.PeerBatch, 0, len(peers))
|
||||||
|
for _, peer := range peers {
|
||||||
peerToReturn, err := h.checkPeerStatus(peer)
|
peerToReturn, err := h.checkPeerStatus(peer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
|
groupMinimumInfo := toGroupsInfo(groupsMap, peer.ID)
|
||||||
|
|
||||||
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0))
|
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0))
|
||||||
}
|
}
|
||||||
@ -304,7 +316,7 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee
|
|||||||
}
|
}
|
||||||
|
|
||||||
func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum {
|
func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum {
|
||||||
var groupsInfo []api.GroupMinimum
|
groupsInfo := []api.GroupMinimum{}
|
||||||
groupsChecked := make(map[string]struct{})
|
groupsChecked := make(map[string]struct{})
|
||||||
for _, group := range groups {
|
for _, group := range groups {
|
||||||
_, ok := groupsChecked[group.ID]
|
_, ok := groupsChecked[group.ID]
|
||||||
|
@ -137,11 +137,6 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Name == "" {
|
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key name field is invalid: %s", req.Name), w)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.AutoGroups == nil {
|
if req.AutoGroups == nil {
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key AutoGroups field is invalid"), w)
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key AutoGroups field is invalid"), w)
|
||||||
return
|
return
|
||||||
@ -150,7 +145,6 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
|
|||||||
newKey := &server.SetupKey{}
|
newKey := &server.SetupKey{}
|
||||||
newKey.AutoGroups = req.AutoGroups
|
newKey.AutoGroups = req.AutoGroups
|
||||||
newKey.Revoked = req.Revoked
|
newKey.Revoked = req.Revoked
|
||||||
newKey.Name = req.Name
|
|
||||||
newKey.Id = keyID
|
newKey.Id = keyID
|
||||||
|
|
||||||
newKey, err = h.accountManager.SaveSetupKey(r.Context(), accountID, newKey, userID)
|
newKey, err = h.accountManager.SaveSetupKey(r.Context(), accountID, newKey, userID)
|
||||||
|
@ -168,6 +168,8 @@ func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context
|
|||||||
|
|
||||||
account.UpdatePeer(peer)
|
account.UpdatePeer(peer)
|
||||||
|
|
||||||
|
log.WithContext(ctx).Tracef("saving peer status for peer %s is connected: %t", peer.ID, connected)
|
||||||
|
|
||||||
err := am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus)
|
err := am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("failed to save peer status: %w", err)
|
return false, fmt.Errorf("failed to save peer status: %w", err)
|
||||||
@ -669,6 +671,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
|
|||||||
|
|
||||||
updated := peer.UpdateMetaIfNew(sync.Meta)
|
updated := peer.UpdateMetaIfNew(sync.Meta)
|
||||||
if updated {
|
if updated {
|
||||||
|
log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID)
|
||||||
err = am.Store.SavePeer(ctx, account.Id, peer)
|
err = am.Store.SavePeer(ctx, account.Id, peer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, fmt.Errorf("failed to save peer: %w", err)
|
return nil, nil, nil, fmt.Errorf("failed to save peer: %w", err)
|
||||||
|
@ -12,9 +12,10 @@ import (
|
|||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -276,7 +277,7 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
|
|||||||
// SaveSetupKey saves the provided SetupKey to the database overriding the existing one.
|
// SaveSetupKey saves the provided SetupKey to the database overriding the existing one.
|
||||||
// Due to the unique nature of a SetupKey certain properties must not be overwritten
|
// Due to the unique nature of a SetupKey certain properties must not be overwritten
|
||||||
// (e.g. the key itself, creation date, ID, etc).
|
// (e.g. the key itself, creation date, ID, etc).
|
||||||
// These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key.
|
// These properties are overwritten: AutoGroups, Revoked (only from false to true), and the UpdatedAt. The rest is copied from the existing key.
|
||||||
func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) {
|
func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) {
|
||||||
if keyToSave == nil {
|
if keyToSave == nil {
|
||||||
return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil")
|
return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil")
|
||||||
@ -312,9 +313,12 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// only auto groups, revoked status, and name can be updated for now
|
if oldKey.Revoked && !keyToSave.Revoked {
|
||||||
|
return status.Errorf(status.InvalidArgument, "can't un-revoke a revoked setup key")
|
||||||
|
}
|
||||||
|
|
||||||
|
// only auto groups, revoked status (from false to true) can be updated
|
||||||
newKey = oldKey.Copy()
|
newKey = oldKey.Copy()
|
||||||
newKey.Name = keyToSave.Name
|
|
||||||
newKey.AutoGroups = keyToSave.AutoGroups
|
newKey.AutoGroups = keyToSave.AutoGroups
|
||||||
newKey.Revoked = keyToSave.Revoked
|
newKey.Revoked = keyToSave.Revoked
|
||||||
newKey.UpdatedAt = time.Now().UTC()
|
newKey.UpdatedAt = time.Now().UTC()
|
||||||
|
@ -56,11 +56,9 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
autoGroups := []string{"group_1", "group_2"}
|
autoGroups := []string{"group_1", "group_2"}
|
||||||
newKeyName := "my-new-test-key"
|
|
||||||
revoked := true
|
revoked := true
|
||||||
newKey, err := manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{
|
newKey, err := manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{
|
||||||
Id: key.Id,
|
Id: key.Id,
|
||||||
Name: newKeyName,
|
|
||||||
Revoked: revoked,
|
Revoked: revoked,
|
||||||
AutoGroups: autoGroups,
|
AutoGroups: autoGroups,
|
||||||
}, userID)
|
}, userID)
|
||||||
@ -68,7 +66,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
assertKey(t, newKey, newKeyName, revoked, "reusable", 0, key.CreatedAt, key.ExpiresAt,
|
assertKey(t, newKey, keyName, revoked, "reusable", 0, key.CreatedAt, key.ExpiresAt,
|
||||||
key.Id, time.Now().UTC(), autoGroups, true)
|
key.Id, time.Now().UTC(), autoGroups, true)
|
||||||
|
|
||||||
// check the corresponding events that should have been generated
|
// check the corresponding events that should have been generated
|
||||||
@ -76,7 +74,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
|||||||
|
|
||||||
assert.NotNil(t, ev)
|
assert.NotNil(t, ev)
|
||||||
assert.Equal(t, account.Id, ev.AccountID)
|
assert.Equal(t, account.Id, ev.AccountID)
|
||||||
assert.Equal(t, newKeyName, ev.Meta["name"])
|
assert.Equal(t, keyName, ev.Meta["name"])
|
||||||
assert.Equal(t, fmt.Sprint(key.Type), fmt.Sprint(ev.Meta["type"]))
|
assert.Equal(t, fmt.Sprint(key.Type), fmt.Sprint(ev.Meta["type"]))
|
||||||
assert.NotEmpty(t, ev.Meta["key"])
|
assert.NotEmpty(t, ev.Meta["key"])
|
||||||
assert.Equal(t, userID, ev.InitiatorID)
|
assert.Equal(t, userID, ev.InitiatorID)
|
||||||
@ -89,7 +87,6 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
|||||||
autoGroups = append(autoGroups, groupAll.ID)
|
autoGroups = append(autoGroups, groupAll.ID)
|
||||||
_, err = manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{
|
_, err = manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{
|
||||||
Id: key.Id,
|
Id: key.Id,
|
||||||
Name: newKeyName,
|
|
||||||
Revoked: revoked,
|
Revoked: revoked,
|
||||||
AutoGroups: autoGroups,
|
AutoGroups: autoGroups,
|
||||||
}, userID)
|
}, userID)
|
||||||
@ -449,3 +446,31 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDefaultAccountManager_CreateSetupKey_ShouldNotAllowToUpdateRevokedKey(t *testing.T) {
|
||||||
|
manager, err := createManager(t)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
userID := "testingUser"
|
||||||
|
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
key, err := manager.CreateSetupKey(context.Background(), account.Id, "testName", SetupKeyReusable, time.Hour, nil, SetupKeyUnlimitedUsage, userID, false)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// revoke the key
|
||||||
|
updateKey := key.Copy()
|
||||||
|
updateKey.Revoked = true
|
||||||
|
_, err = manager.SaveSetupKey(context.Background(), account.Id, updateKey, userID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// re-activate revoked key
|
||||||
|
updateKey.Revoked = false
|
||||||
|
_, err = manager.SaveSetupKey(context.Background(), account.Id, updateKey, userID)
|
||||||
|
assert.Error(t, err, "should not allow to update revoked key")
|
||||||
|
|
||||||
|
}
|
||||||
|
@ -1278,7 +1278,7 @@ func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, a
|
|||||||
Delete(&nbgroup.Group{}, accountAndIDsQueryCondition, accountID, groupIDs)
|
Delete(&nbgroup.Group{}, accountAndIDsQueryCondition, accountID, groupIDs)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error)
|
log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error)
|
||||||
return status.Errorf(status.Internal, "failed to delete groups from store: %v", result.Error)
|
return status.Errorf(status.Internal, "failed to delete groups from store")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -16,6 +16,8 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
bufferSize = 8820
|
bufferSize = 8820
|
||||||
|
|
||||||
|
errCloseConn = "failed to close connection to peer: %s"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Peer represents a peer connection
|
// Peer represents a peer connection
|
||||||
@ -46,6 +48,12 @@ func NewPeer(metrics *metrics.Metrics, id []byte, conn net.Conn, store *Store) *
|
|||||||
// It manages the protocol (healthcheck, transport, close). Read the message and determine the message type and handle
|
// It manages the protocol (healthcheck, transport, close). Read the message and determine the message type and handle
|
||||||
// the message accordingly.
|
// the message accordingly.
|
||||||
func (p *Peer) Work() {
|
func (p *Peer) Work() {
|
||||||
|
defer func() {
|
||||||
|
if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||||
|
p.log.Errorf(errCloseConn, err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
@ -97,7 +105,7 @@ func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *
|
|||||||
case messages.MsgTypeClose:
|
case messages.MsgTypeClose:
|
||||||
p.log.Infof("peer exited gracefully")
|
p.log.Infof("peer exited gracefully")
|
||||||
if err := p.conn.Close(); err != nil {
|
if err := p.conn.Close(); err != nil {
|
||||||
log.Errorf("failed to close connection to peer: %s", err)
|
log.Errorf(errCloseConn, err)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
p.log.Warnf("received unexpected message type: %s", msgType)
|
p.log.Warnf("received unexpected message type: %s", msgType)
|
||||||
@ -121,9 +129,8 @@ func (p *Peer) CloseGracefully(ctx context.Context) {
|
|||||||
p.log.Errorf("failed to send close message to peer: %s", p.String())
|
p.log.Errorf("failed to send close message to peer: %s", p.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
err = p.conn.Close()
|
if err := p.conn.Close(); err != nil {
|
||||||
if err != nil {
|
p.log.Errorf(errCloseConn, err)
|
||||||
p.log.Errorf("failed to close connection to peer: %s", err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -132,7 +139,7 @@ func (p *Peer) Close() {
|
|||||||
defer p.connMu.Unlock()
|
defer p.connMu.Unlock()
|
||||||
|
|
||||||
if err := p.conn.Close(); err != nil {
|
if err := p.conn.Close(); err != nil {
|
||||||
p.log.Errorf("failed to close connection to peer: %s", err)
|
p.log.Errorf(errCloseConn, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
23
util/file.go
23
util/file.go
@ -15,7 +15,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// WriteJsonWithRestrictedPermission writes JSON config object to a file. Enforces permission on the parent directory
|
// WriteJsonWithRestrictedPermission writes JSON config object to a file. Enforces permission on the parent directory
|
||||||
func WriteJsonWithRestrictedPermission(file string, obj interface{}) error {
|
func WriteJsonWithRestrictedPermission(ctx context.Context, file string, obj interface{}) error {
|
||||||
configDir, configFileName, err := prepareConfigFileDir(file)
|
configDir, configFileName, err := prepareConfigFileDir(file)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -26,18 +26,18 @@ func WriteJsonWithRestrictedPermission(file string, obj interface{}) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return writeJson(file, obj, configDir, configFileName)
|
return writeJson(ctx, file, obj, configDir, configFileName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// WriteJson writes JSON config object to a file creating parent directories if required
|
// WriteJson writes JSON config object to a file creating parent directories if required
|
||||||
// The output JSON is pretty-formatted
|
// The output JSON is pretty-formatted
|
||||||
func WriteJson(file string, obj interface{}) error {
|
func WriteJson(ctx context.Context, file string, obj interface{}) error {
|
||||||
configDir, configFileName, err := prepareConfigFileDir(file)
|
configDir, configFileName, err := prepareConfigFileDir(file)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return writeJson(file, obj, configDir, configFileName)
|
return writeJson(ctx, file, obj, configDir, configFileName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DirectWriteJson writes JSON config object to a file creating parent directories if required without creating a temporary file
|
// DirectWriteJson writes JSON config object to a file creating parent directories if required without creating a temporary file
|
||||||
@ -79,7 +79,11 @@ func DirectWriteJson(ctx context.Context, file string, obj interface{}) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeJson(file string, obj interface{}, configDir string, configFileName string) error {
|
func writeJson(ctx context.Context, file string, obj interface{}, configDir string, configFileName string) error {
|
||||||
|
// Check context before expensive operations
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
// make it pretty
|
// make it pretty
|
||||||
bs, err := json.MarshalIndent(obj, "", " ")
|
bs, err := json.MarshalIndent(obj, "", " ")
|
||||||
@ -87,6 +91,10 @@ func writeJson(file string, obj interface{}, configDir string, configFileName st
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
tempFile, err := os.CreateTemp(configDir, ".*"+configFileName)
|
tempFile, err := os.CreateTemp(configDir, ".*"+configFileName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -111,6 +119,11 @@ func writeJson(file string, obj interface{}, configDir string, configFileName st
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check context again
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
err = os.Rename(tempFileName, file)
|
err = os.Rename(tempFileName, file)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package util
|
package util
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/md5"
|
"crypto/md5"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"io"
|
"io"
|
||||||
@ -39,7 +40,7 @@ func TestConfigJSON(t *testing.T) {
|
|||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
err := WriteJson(tmpDir+"/testconfig.json", tt.config)
|
err := WriteJson(context.Background(), tmpDir+"/testconfig.json", tt.config)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
read, err := ReadJson(tmpDir+"/testconfig.json", &TestConfig{})
|
read, err := ReadJson(tmpDir+"/testconfig.json", &TestConfig{})
|
||||||
@ -73,7 +74,7 @@ func TestCopyFileContents(t *testing.T) {
|
|||||||
src := tmpDir + "/copytest_src"
|
src := tmpDir + "/copytest_src"
|
||||||
dst := tmpDir + "/copytest_dst"
|
dst := tmpDir + "/copytest_dst"
|
||||||
|
|
||||||
err := WriteJson(src, tt.srcContent)
|
err := WriteJson(context.Background(), src, tt.srcContent)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = CopyFileContents(src, dst)
|
err = CopyFileContents(src, dst)
|
||||||
@ -127,7 +128,7 @@ func TestHandleConfigFileWithoutFullPath(t *testing.T) {
|
|||||||
_ = os.Remove(cfgFile)
|
_ = os.Remove(cfgFile)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err := WriteJson(cfgFile, tt.config)
|
err := WriteJson(context.Background(), cfgFile, tt.config)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
read, err := ReadJson(cfgFile, &TestConfig{})
|
read, err := ReadJson(cfgFile, &TestConfig{})
|
||||||
|
Reference in New Issue
Block a user