Merge branch 'main' into groups-get-account-refactoring

# Conflicts:
#	management/server/group.go
This commit is contained in:
bcmmbaga 2024-11-15 13:34:59 +03:00
commit 92b9e11d3f
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547
14 changed files with 188 additions and 72 deletions

View File

@ -83,9 +83,11 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
} }
// persist early to ensure cleanup of chains // persist early to ensure cleanup of chains
go func() {
if err := stateManager.PersistState(context.Background()); err != nil { if err := stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err) log.Errorf("failed to persist state: %v", err)
} }
}()
return nil return nil
} }

View File

@ -99,9 +99,11 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
} }
// persist early // persist early
go func() {
if err := stateManager.PersistState(context.Background()); err != nil { if err := stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err) log.Errorf("failed to persist state: %v", err)
} }
}()
return nil return nil
} }

View File

@ -164,7 +164,7 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = util.WriteJsonWithRestrictedPermission(input.ConfigPath, cfg) err = util.WriteJsonWithRestrictedPermission(context.Background(), input.ConfigPath, cfg)
return cfg, err return cfg, err
} }
@ -185,7 +185,7 @@ func CreateInMemoryConfig(input ConfigInput) (*Config, error) {
// WriteOutConfig write put the prepared config to the given path // WriteOutConfig write put the prepared config to the given path
func WriteOutConfig(path string, config *Config) error { func WriteOutConfig(path string, config *Config) error {
return util.WriteJson(path, config) return util.WriteJson(context.Background(), path, config)
} }
// createNewConfig creates a new config generating a new Wireguard key and saving to file // createNewConfig creates a new config generating a new Wireguard key and saving to file
@ -215,7 +215,7 @@ func update(input ConfigInput) (*Config, error) {
} }
if updated { if updated {
if err := util.WriteJson(input.ConfigPath, config); err != nil { if err := util.WriteJson(context.Background(), input.ConfigPath, config); err != nil {
return nil, err return nil, err
} }
} }

View File

@ -326,9 +326,13 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
// persist dns state right away // persist dns state right away
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second) ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
defer cancel() defer cancel()
// don't block
go func() {
if err := s.stateManager.PersistState(ctx); err != nil { if err := s.stateManager.PersistState(ctx); err != nil {
log.Errorf("Failed to persist dns state: %v", err) log.Errorf("Failed to persist dns state: %v", err)
} }
}()
if s.searchDomainNotifier != nil { if s.searchDomainNotifier != nil {
s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains()) s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains())

View File

@ -38,7 +38,6 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbssh "github.com/netbirdio/netbird/client/ssh" nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
@ -641,6 +640,10 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
} }
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
if e.wgInterface == nil {
return errors.New("wireguard interface is not initialized")
}
if e.wgInterface.Address().String() != conf.Address { if e.wgInterface.Address().String() != conf.Address {
oldAddr := e.wgInterface.Address().String() oldAddr := e.wgInterface.Address().String()
log.Debugf("updating peer address from %s to %s", oldAddr, conf.Address) log.Debugf("updating peer address from %s to %s", oldAddr, conf.Address)

View File

@ -1,6 +1,7 @@
package routemanager package routemanager
import ( import (
"fmt"
"net/netip" "net/netip"
"testing" "testing"
"time" "time"
@ -227,6 +228,64 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
currentRoute: "route1", currentRoute: "route1",
expectedRouteID: "route1", expectedRouteID: "route1",
}, },
{
name: "relayed routes with latency 0 should maintain previous choice",
statuses: map[route.ID]routerPeerStatus{
"route1": {
connected: true,
relayed: true,
latency: 0 * time.Millisecond,
},
"route2": {
connected: true,
relayed: true,
latency: 0 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "route1",
expectedRouteID: "route1",
},
{
name: "p2p routes with latency 0 should maintain previous choice",
statuses: map[route.ID]routerPeerStatus{
"route1": {
connected: true,
relayed: false,
latency: 0 * time.Millisecond,
},
"route2": {
connected: true,
relayed: false,
latency: 0 * time.Millisecond,
},
},
existingRoutes: map[route.ID]*route.Route{
"route1": {
ID: "route1",
Metric: route.MaxMetric,
Peer: "peer1",
},
"route2": {
ID: "route2",
Metric: route.MaxMetric,
Peer: "peer2",
},
},
currentRoute: "route1",
expectedRouteID: "route1",
},
{ {
name: "current route with bad score should be changed to route with better score", name: "current route with bad score should be changed to route with better score",
statuses: map[route.ID]routerPeerStatus{ statuses: map[route.ID]routerPeerStatus{
@ -287,6 +346,45 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
}, },
} }
// fill the test data with random routes
for _, tc := range testCases {
for i := 0; i < 50; i++ {
dummyRoute := &route.Route{
ID: route.ID(fmt.Sprintf("dummy_p1_%d", i)),
Metric: route.MinMetric,
Peer: fmt.Sprintf("dummy_p1_%d", i),
}
tc.existingRoutes[dummyRoute.ID] = dummyRoute
}
for i := 0; i < 50; i++ {
dummyRoute := &route.Route{
ID: route.ID(fmt.Sprintf("dummy_p2_%d", i)),
Metric: route.MinMetric,
Peer: fmt.Sprintf("dummy_p1_%d", i),
}
tc.existingRoutes[dummyRoute.ID] = dummyRoute
}
for i := 0; i < 50; i++ {
id := route.ID(fmt.Sprintf("dummy_p1_%d", i))
dummyStatus := routerPeerStatus{
connected: false,
relayed: true,
latency: 0,
}
tc.statuses[id] = dummyStatus
}
for i := 0; i < 50; i++ {
id := route.ID(fmt.Sprintf("dummy_p2_%d", i))
dummyStatus := routerPeerStatus{
connected: false,
relayed: true,
latency: 0,
}
tc.statuses[id] = dummyStatus
}
}
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
currentRoute := &route.Route{ currentRoute := &route.Route{

View File

@ -16,6 +16,7 @@ import (
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/util"
) )
// State interface defines the methods that all state types must implement // State interface defines the methods that all state types must implement
@ -178,25 +179,14 @@ func (m *Manager) PersistState(ctx context.Context) error {
return nil return nil
} }
ctx, cancel := context.WithTimeout(ctx, 3*time.Second) ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel() defer cancel()
done := make(chan error, 1) done := make(chan error, 1)
start := time.Now()
go func() { go func() {
data, err := json.MarshalIndent(m.states, "", " ") done <- util.WriteJsonWithRestrictedPermission(ctx, m.filePath, m.states)
if err != nil {
done <- fmt.Errorf("marshal states: %w", err)
return
}
// nolint:gosec
if err := os.WriteFile(m.filePath, data, 0640); err != nil {
done <- fmt.Errorf("write state file: %w", err)
return
}
done <- nil
}() }()
select { select {
@ -208,7 +198,7 @@ func (m *Manager) PersistState(ctx context.Context) error {
} }
} }
log.Debugf("persisted shutdown states: %v", maps.Keys(m.dirty)) log.Debugf("persisted shutdown states: %v, took %v", maps.Keys(m.dirty), time.Since(start))
clear(m.dirty) clear(m.dirty)

View File

@ -4,32 +4,20 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
log "github.com/sirupsen/logrus"
) )
// GetDefaultStatePath returns the path to the state file based on the operating system // GetDefaultStatePath returns the path to the state file based on the operating system
// It returns an empty string if the path cannot be determined. It also creates the directory if it does not exist. // It returns an empty string if the path cannot be determined.
func GetDefaultStatePath() string { func GetDefaultStatePath() string {
var path string
switch runtime.GOOS { switch runtime.GOOS {
case "windows": case "windows":
path = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json") return filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json")
case "darwin", "linux": case "darwin", "linux":
path = "/var/lib/netbird/state.json" return "/var/lib/netbird/state.json"
case "freebsd", "openbsd", "netbsd", "dragonfly": case "freebsd", "openbsd", "netbsd", "dragonfly":
path = "/var/db/netbird/state.json" return "/var/db/netbird/state.json"
// ios/android don't need state
default:
return ""
} }
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0755); err != nil {
log.Errorf("Error creating directory %s: %v. Continuing without state support.", dir, err)
return "" return ""
}
return path
} }

View File

@ -223,7 +223,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
// It is recommended to call it with locking FileStore.mux // It is recommended to call it with locking FileStore.mux
func (s *FileStore) persist(ctx context.Context, file string) error { func (s *FileStore) persist(ctx context.Context, file string) error {
start := time.Now() start := time.Now()
err := util.WriteJson(file, s) err := util.WriteJson(context.Background(), file, s)
if err != nil { if err != nil {
return err return err
} }

View File

@ -6,11 +6,12 @@ import (
"fmt" "fmt"
"slices" "slices"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
"github.com/rs/xid" "github.com/rs/xid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group" nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
@ -27,11 +28,6 @@ func (e *GroupLinkError) Error() string {
// CheckGroupPermissions validates if a user has the necessary permissions to view groups // CheckGroupPermissions validates if a user has the necessary permissions to view groups
func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error { func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error {
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return err return err
@ -41,7 +37,7 @@ func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, acco
return status.NewUserNotPartOfAccountError() return status.NewUserNotPartOfAccountError()
} }
if user.IsRegularUser() && settings.RegularUsersViewBlocked { if user.IsRegularUser() {
return status.NewAdminPermissionError() return status.NewAdminPermissionError()
} }

View File

@ -184,14 +184,26 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
dnsDomain := h.accountManager.GetDNSDomain() dnsDomain := h.accountManager.GetDNSDomain()
respBody := make([]*api.PeerBatch, 0, len(account.Peers)) peers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
for _, peer := range account.Peers { if err != nil {
util.WriteError(r.Context(), err, w)
return
}
groupsMap := map[string]*nbgroup.Group{}
groups, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
for _, group := range groups {
groupsMap[group.ID] = group
}
respBody := make([]*api.PeerBatch, 0, len(peers))
for _, peer := range peers {
peerToReturn, err := h.checkPeerStatus(peer) peerToReturn, err := h.checkPeerStatus(peer)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
} }
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID) groupMinimumInfo := toGroupsInfo(groupsMap, peer.ID)
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0)) respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0))
} }
@ -304,7 +316,7 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee
} }
func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum { func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum {
var groupsInfo []api.GroupMinimum groupsInfo := []api.GroupMinimum{}
groupsChecked := make(map[string]struct{}) groupsChecked := make(map[string]struct{})
for _, group := range groups { for _, group := range groups {
_, ok := groupsChecked[group.ID] _, ok := groupsChecked[group.ID]

View File

@ -16,6 +16,8 @@ import (
const ( const (
bufferSize = 8820 bufferSize = 8820
errCloseConn = "failed to close connection to peer: %s"
) )
// Peer represents a peer connection // Peer represents a peer connection
@ -46,6 +48,12 @@ func NewPeer(metrics *metrics.Metrics, id []byte, conn net.Conn, store *Store) *
// It manages the protocol (healthcheck, transport, close). Read the message and determine the message type and handle // It manages the protocol (healthcheck, transport, close). Read the message and determine the message type and handle
// the message accordingly. // the message accordingly.
func (p *Peer) Work() { func (p *Peer) Work() {
defer func() {
if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
p.log.Errorf(errCloseConn, err)
}
}()
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
@ -97,7 +105,7 @@ func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *
case messages.MsgTypeClose: case messages.MsgTypeClose:
p.log.Infof("peer exited gracefully") p.log.Infof("peer exited gracefully")
if err := p.conn.Close(); err != nil { if err := p.conn.Close(); err != nil {
log.Errorf("failed to close connection to peer: %s", err) log.Errorf(errCloseConn, err)
} }
default: default:
p.log.Warnf("received unexpected message type: %s", msgType) p.log.Warnf("received unexpected message type: %s", msgType)
@ -121,9 +129,8 @@ func (p *Peer) CloseGracefully(ctx context.Context) {
p.log.Errorf("failed to send close message to peer: %s", p.String()) p.log.Errorf("failed to send close message to peer: %s", p.String())
} }
err = p.conn.Close() if err := p.conn.Close(); err != nil {
if err != nil { p.log.Errorf(errCloseConn, err)
p.log.Errorf("failed to close connection to peer: %s", err)
} }
} }
@ -132,7 +139,7 @@ func (p *Peer) Close() {
defer p.connMu.Unlock() defer p.connMu.Unlock()
if err := p.conn.Close(); err != nil { if err := p.conn.Close(); err != nil {
p.log.Errorf("failed to close connection to peer: %s", err) p.log.Errorf(errCloseConn, err)
} }
} }

View File

@ -15,7 +15,7 @@ import (
) )
// WriteJsonWithRestrictedPermission writes JSON config object to a file. Enforces permission on the parent directory // WriteJsonWithRestrictedPermission writes JSON config object to a file. Enforces permission on the parent directory
func WriteJsonWithRestrictedPermission(file string, obj interface{}) error { func WriteJsonWithRestrictedPermission(ctx context.Context, file string, obj interface{}) error {
configDir, configFileName, err := prepareConfigFileDir(file) configDir, configFileName, err := prepareConfigFileDir(file)
if err != nil { if err != nil {
return err return err
@ -26,18 +26,18 @@ func WriteJsonWithRestrictedPermission(file string, obj interface{}) error {
return err return err
} }
return writeJson(file, obj, configDir, configFileName) return writeJson(ctx, file, obj, configDir, configFileName)
} }
// WriteJson writes JSON config object to a file creating parent directories if required // WriteJson writes JSON config object to a file creating parent directories if required
// The output JSON is pretty-formatted // The output JSON is pretty-formatted
func WriteJson(file string, obj interface{}) error { func WriteJson(ctx context.Context, file string, obj interface{}) error {
configDir, configFileName, err := prepareConfigFileDir(file) configDir, configFileName, err := prepareConfigFileDir(file)
if err != nil { if err != nil {
return err return err
} }
return writeJson(file, obj, configDir, configFileName) return writeJson(ctx, file, obj, configDir, configFileName)
} }
// DirectWriteJson writes JSON config object to a file creating parent directories if required without creating a temporary file // DirectWriteJson writes JSON config object to a file creating parent directories if required without creating a temporary file
@ -79,7 +79,11 @@ func DirectWriteJson(ctx context.Context, file string, obj interface{}) error {
return nil return nil
} }
func writeJson(file string, obj interface{}, configDir string, configFileName string) error { func writeJson(ctx context.Context, file string, obj interface{}, configDir string, configFileName string) error {
// Check context before expensive operations
if ctx.Err() != nil {
return ctx.Err()
}
// make it pretty // make it pretty
bs, err := json.MarshalIndent(obj, "", " ") bs, err := json.MarshalIndent(obj, "", " ")
@ -87,6 +91,10 @@ func writeJson(file string, obj interface{}, configDir string, configFileName st
return err return err
} }
if ctx.Err() != nil {
return ctx.Err()
}
tempFile, err := os.CreateTemp(configDir, ".*"+configFileName) tempFile, err := os.CreateTemp(configDir, ".*"+configFileName)
if err != nil { if err != nil {
return err return err
@ -111,6 +119,11 @@ func writeJson(file string, obj interface{}, configDir string, configFileName st
return err return err
} }
// Check context again
if ctx.Err() != nil {
return ctx.Err()
}
err = os.Rename(tempFileName, file) err = os.Rename(tempFileName, file)
if err != nil { if err != nil {
return err return err

View File

@ -1,6 +1,7 @@
package util package util
import ( import (
"context"
"crypto/md5" "crypto/md5"
"encoding/hex" "encoding/hex"
"io" "io"
@ -39,7 +40,7 @@ func TestConfigJSON(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
err := WriteJson(tmpDir+"/testconfig.json", tt.config) err := WriteJson(context.Background(), tmpDir+"/testconfig.json", tt.config)
require.NoError(t, err) require.NoError(t, err)
read, err := ReadJson(tmpDir+"/testconfig.json", &TestConfig{}) read, err := ReadJson(tmpDir+"/testconfig.json", &TestConfig{})
@ -73,7 +74,7 @@ func TestCopyFileContents(t *testing.T) {
src := tmpDir + "/copytest_src" src := tmpDir + "/copytest_src"
dst := tmpDir + "/copytest_dst" dst := tmpDir + "/copytest_dst"
err := WriteJson(src, tt.srcContent) err := WriteJson(context.Background(), src, tt.srcContent)
require.NoError(t, err) require.NoError(t, err)
err = CopyFileContents(src, dst) err = CopyFileContents(src, dst)
@ -127,7 +128,7 @@ func TestHandleConfigFileWithoutFullPath(t *testing.T) {
_ = os.Remove(cfgFile) _ = os.Remove(cfgFile)
}() }()
err := WriteJson(cfgFile, tt.config) err := WriteJson(context.Background(), cfgFile, tt.config)
require.NoError(t, err) require.NoError(t, err)
read, err := ReadJson(cfgFile, &TestConfig{}) read, err := ReadJson(cfgFile, &TestConfig{})