Feature/peer validator (#1553)

Follow up management-integrations changes

move groups to separated packages to avoid circle dependencies
save location information in Login action
This commit is contained in:
Zoltan Papp 2024-03-27 18:48:48 +01:00 committed by GitHub
parent ea2d060f93
commit 2d76b058fc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
45 changed files with 790 additions and 351 deletions

View File

@ -13,6 +13,7 @@ import (
"google.golang.org/grpc" "google.golang.org/grpc"
"github.com/netbirdio/management-integrations/integrations"
clientProto "github.com/netbirdio/netbird/client/proto" clientProto "github.com/netbirdio/netbird/client/proto"
client "github.com/netbirdio/netbird/client/server" client "github.com/netbirdio/netbird/client/server"
mgmtProto "github.com/netbirdio/netbird/management/proto" mgmtProto "github.com/netbirdio/netbird/management/proto"
@ -78,7 +79,8 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
if err != nil { if err != nil {
return nil, nil return nil, nil
} }
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false) iv, _ := integrations.NewIntegratedValidator(eventStore)
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -21,6 +21,7 @@ import (
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
@ -1050,7 +1051,8 @@ func startManagement(dataDir string) (*grpc.Server, string, error) {
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false) ia, _ := integrations.NewIntegratedValidator(eventStore)
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }

View File

@ -2,6 +2,7 @@ package server
import ( import (
"context" "context"
"github.com/netbirdio/management-integrations/integrations"
"net" "net"
"testing" "testing"
"time" "time"
@ -114,7 +115,8 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false) ia, _ := integrations.NewIntegratedValidator(eventStore)
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }

4
go.mod
View File

@ -46,6 +46,7 @@ require (
github.com/golang/mock v1.6.0 github.com/golang/mock v1.6.0
github.com/google/go-cmp v0.5.9 github.com/google/go-cmp v0.5.9
github.com/google/gopacket v1.1.19 github.com/google/gopacket v1.1.19
github.com/google/martian/v3 v3.0.0
github.com/google/nftables v0.0.0-20220808154552-2eca00135732 github.com/google/nftables v0.0.0-20220808154552-2eca00135732
github.com/gopacket/gopacket v1.1.1 github.com/gopacket/gopacket v1.1.1
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357
@ -59,8 +60,7 @@ require (
github.com/miekg/dns v1.1.43 github.com/miekg/dns v1.1.43
github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0 github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/additions v0.0.0-20240212121739-8ea8c89a4552 github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98
github.com/netbirdio/management-integrations/integrations v0.0.0-20240212121739-8ea8c89a4552
github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0 github.com/oschwald/maxminddb-golang v1.12.0
github.com/patrickmn/go-cache v2.1.0+incompatible github.com/patrickmn/go-cache v2.1.0+incompatible

7
go.sum
View File

@ -255,6 +255,7 @@ github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
github.com/google/martian/v3 v3.0.0 h1:pMen7vLs8nvgEYhywH3KDWJIJTeEr2ULsVWHWYHQyBs=
github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
github.com/google/nftables v0.0.0-20220808154552-2eca00135732 h1:csc7dT82JiSLvq4aMyQMIQDL7986NH6Wxf/QrvOj55A= github.com/google/nftables v0.0.0-20220808154552-2eca00135732 h1:csc7dT82JiSLvq4aMyQMIQDL7986NH6Wxf/QrvOj55A=
github.com/google/nftables v0.0.0-20220808154552-2eca00135732/go.mod h1:b97ulCCFipUC+kSin+zygkvUVpx0vyIAwxXFdY3PlNc= github.com/google/nftables v0.0.0-20220808154552-2eca00135732/go.mod h1:b97ulCCFipUC+kSin+zygkvUVpx0vyIAwxXFdY3PlNc=
@ -382,10 +383,8 @@ github.com/nadoo/ipset v0.5.0 h1:5GJUAuZ7ITQQQGne5J96AmFjRtI8Avlbk6CabzYWVUc=
github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ= github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
github.com/netbirdio/management-integrations/additions v0.0.0-20240212121739-8ea8c89a4552 h1:yzcQKizAK9YufCHMMCIsr467Dw/OU/4xyHbWizGb1E4= github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98 h1:i6AtenTLu/CqhTmj0g1K/GWkkpMJMhQM6Vjs46x25nA=
github.com/netbirdio/management-integrations/additions v0.0.0-20240212121739-8ea8c89a4552/go.mod h1:31FhBNvQ+riHEIu6LSTmqr8IeuSIsGfQffqV4LFmbwA= github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98/go.mod h1:kxks50DrZnhW+oRTdHOkVOJbcTcyo766am8RBugo+Yc=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240212121739-8ea8c89a4552 h1:OFlzVZtkXCoJsfDKrMigFpuad8ZXTm8epq6x27K0irA=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240212121739-8ea8c89a4552/go.mod h1:B0nMS3es77gOvPYhc0K91fAzTkQLi/jRq5TffUN3klM=
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g= github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g=
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM= github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM=

View File

@ -3,6 +3,7 @@ package client
import ( import (
"context" "context"
"net" "net"
"os"
"path/filepath" "path/filepath"
"sync" "sync"
"testing" "testing"
@ -15,6 +16,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
mgmtProto "github.com/netbirdio/netbird/management/proto" mgmtProto "github.com/netbirdio/netbird/management/proto"
mgmt "github.com/netbirdio/netbird/management/server" mgmt "github.com/netbirdio/netbird/management/server"
@ -30,6 +32,12 @@ import (
const ValidKey = "A2C8E62B-38F5-4553-B31E-DD66C696CEBB" const ValidKey = "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
func TestMain(m *testing.M) {
_ = util.InitLog("debug", "console")
code := m.Run()
os.Exit(code)
}
func startManagement(t *testing.T) (*grpc.Server, net.Listener) { func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
t.Helper() t.Helper()
level, _ := log.ParseLevel("debug") level, _ := log.ParseLevel("debug")
@ -60,7 +68,8 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
peersUpdateManager := mgmt.NewPeersUpdateManager(nil) peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false) ia, _ := integrations.NewIntegratedValidator(eventStore)
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -31,6 +31,7 @@ import (
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
@ -172,8 +173,12 @@ var (
log.Infof("geo location service has been initialized from %s", config.Datadir) log.Infof("geo location service has been initialized from %s", config.Datadir)
} }
integratedPeerValidator, err := integrations.NewIntegratedValidator(eventStore)
if err != nil {
return fmt.Errorf("failed to initialize integrated peer validator: %v", err)
}
accountManager, err := server.BuildManager(store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain, accountManager, err := server.BuildManager(store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled) dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator)
if err != nil { if err != nil {
return fmt.Errorf("failed to build default manager: %v", err) return fmt.Errorf("failed to build default manager: %v", err)
} }
@ -323,6 +328,7 @@ var (
SetupCloseHandler() SetupCloseHandler()
<-stopCh <-stopCh
integratedPeerValidator.Stop()
if geo != nil { if geo != nil {
_ = geo.Stop() _ = geo.Stop()
} }

View File

@ -21,14 +21,15 @@ import (
"github.com/rs/xid" "github.com/rs/xid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/management-integrations/additions"
"github.com/netbirdio/netbird/base62" "github.com/netbirdio/netbird/base62"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/geolocation"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/integrated_validator"
"github.com/netbirdio/netbird/management/server/integration_reference"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
@ -85,12 +86,12 @@ type AccountManager interface {
GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error)
UpdatePeerSSHKey(peerID string, sshKey string) error UpdatePeerSSHKey(peerID string, sshKey string) error
GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error)
GetGroup(accountId, groupID, userID string) (*Group, error) GetGroup(accountId, groupID, userID string) (*nbgroup.Group, error)
GetAllGroups(accountID, userID string) ([]*Group, error) GetAllGroups(accountID, userID string) ([]*nbgroup.Group, error)
GetGroupByName(groupName, accountID string) (*Group, error) GetGroupByName(groupName, accountID string) (*nbgroup.Group, error)
SaveGroup(accountID, userID string, group *Group) error SaveGroup(accountID, userID string, group *nbgroup.Group) error
DeleteGroup(accountId, userId, groupID string) error DeleteGroup(accountId, userId, groupID string) error
ListGroups(accountId string) ([]*Group, error) ListGroups(accountId string) ([]*nbgroup.Group, error)
GroupAddPeer(accountId, groupID, peerID string) error GroupAddPeer(accountId, groupID, peerID string) error
GroupDeletePeer(accountId, groupID, peerID string) error GroupDeletePeer(accountId, groupID, peerID string) error
GetPolicy(accountID, policyID, userID string) (*Policy, error) GetPolicy(accountID, policyID, userID string) (*Policy, error)
@ -124,6 +125,9 @@ type AccountManager interface {
DeletePostureChecks(accountID, postureChecksID, userID string) error DeletePostureChecks(accountID, postureChecksID, userID string) error
ListPostureChecks(accountID, userID string) ([]*posture.Checks, error) ListPostureChecks(accountID, userID string) ([]*posture.Checks, error)
GetIdpManager() idp.Manager GetIdpManager() idp.Manager
UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error
GroupValidation(accountId string, groups []string) (bool, error)
GetValidatedPeers(account *Account) (map[string]struct{}, error)
} }
type DefaultAccountManager struct { type DefaultAccountManager struct {
@ -152,6 +156,8 @@ type DefaultAccountManager struct {
// userDeleteFromIDPEnabled allows to delete user from IDP when user is deleted from account // userDeleteFromIDPEnabled allows to delete user from IDP when user is deleted from account
userDeleteFromIDPEnabled bool userDeleteFromIDPEnabled bool
integratedPeerValidator integrated_validator.IntegratedValidator
} }
// Settings represents Account settings structure that can be modified via API and Dashboard // Settings represents Account settings structure that can be modified via API and Dashboard
@ -218,8 +224,8 @@ type Account struct {
PeersG []nbpeer.Peer `json:"-" gorm:"foreignKey:AccountID;references:id"` PeersG []nbpeer.Peer `json:"-" gorm:"foreignKey:AccountID;references:id"`
Users map[string]*User `gorm:"-"` Users map[string]*User `gorm:"-"`
UsersG []User `json:"-" gorm:"foreignKey:AccountID;references:id"` UsersG []User `json:"-" gorm:"foreignKey:AccountID;references:id"`
Groups map[string]*Group `gorm:"-"` Groups map[string]*nbgroup.Group `gorm:"-"`
GroupsG []Group `json:"-" gorm:"foreignKey:AccountID;references:id"` GroupsG []nbgroup.Group `json:"-" gorm:"foreignKey:AccountID;references:id"`
Policies []*Policy `gorm:"foreignKey:AccountID;references:id"` Policies []*Policy `gorm:"foreignKey:AccountID;references:id"`
Routes map[string]*route.Route `gorm:"-"` Routes map[string]*route.Route `gorm:"-"`
RoutesG []route.Route `json:"-" gorm:"foreignKey:AccountID;references:id"` RoutesG []route.Route `json:"-" gorm:"foreignKey:AccountID;references:id"`
@ -247,7 +253,7 @@ type UserInfo struct {
NonDeletable bool `json:"non_deletable"` NonDeletable bool `json:"non_deletable"`
LastLogin time.Time `json:"last_login"` LastLogin time.Time `json:"last_login"`
Issued string `json:"issued"` Issued string `json:"issued"`
IntegrationReference IntegrationReference `json:"-"` IntegrationReference integration_reference.IntegrationReference `json:"-"`
Permissions UserPermissions `json:"permissions"` Permissions UserPermissions `json:"permissions"`
} }
@ -372,25 +378,26 @@ func (a *Account) GetRoutesByPrefix(prefix netip.Prefix) []*route.Route {
} }
// GetGroup returns a group by ID if exists, nil otherwise // GetGroup returns a group by ID if exists, nil otherwise
func (a *Account) GetGroup(groupID string) *Group { func (a *Account) GetGroup(groupID string) *nbgroup.Group {
return a.Groups[groupID] return a.Groups[groupID]
} }
// GetPeerNetworkMap returns a group by ID if exists, nil otherwise // GetPeerNetworkMap returns a group by ID if exists, nil otherwise
func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string) *NetworkMap { func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string, validatedPeersMap map[string]struct{}) *NetworkMap {
peer := a.Peers[peerID] peer := a.Peers[peerID]
if peer == nil { if peer == nil {
return &NetworkMap{ return &NetworkMap{
Network: a.Network.Copy(), Network: a.Network.Copy(),
} }
} }
validatedPeers := additions.ValidatePeers([]*nbpeer.Peer{peer})
if len(validatedPeers) == 0 { if _, ok := validatedPeersMap[peerID]; !ok {
return &NetworkMap{ return &NetworkMap{
Network: a.Network.Copy(), Network: a.Network.Copy(),
} }
} }
aclPeers, firewallRules := a.getPeerConnectionResources(peerID)
aclPeers, firewallRules := a.getPeerConnectionResources(peerID, validatedPeersMap)
// exclude expired peers // exclude expired peers
var peersToConnect []*nbpeer.Peer var peersToConnect []*nbpeer.Peer
var expiredPeers []*nbpeer.Peer var expiredPeers []*nbpeer.Peer
@ -564,7 +571,7 @@ func (a *Account) FindUser(userID string) (*User, error) {
} }
// FindGroupByName looks for a given group in the Account by name or returns error if the group wasn't found. // FindGroupByName looks for a given group in the Account by name or returns error if the group wasn't found.
func (a *Account) FindGroupByName(groupName string) (*Group, error) { func (a *Account) FindGroupByName(groupName string) (*nbgroup.Group, error) {
for _, group := range a.Groups { for _, group := range a.Groups {
if group.Name == groupName { if group.Name == groupName {
return group, nil return group, nil
@ -583,6 +590,20 @@ func (a *Account) FindSetupKey(setupKey string) (*SetupKey, error) {
return key, nil return key, nil
} }
// GetPeerGroupsList return with the list of groups ID.
func (a *Account) GetPeerGroupsList(peerID string) []string {
var grps []string
for groupID, group := range a.Groups {
for _, id := range group.Peers {
if id == peerID {
grps = append(grps, groupID)
break
}
}
}
return grps
}
func (a *Account) getUserGroups(userID string) ([]string, error) { func (a *Account) getUserGroups(userID string) ([]string, error) {
user, err := a.FindUser(userID) user, err := a.FindUser(userID)
if err != nil { if err != nil {
@ -660,7 +681,7 @@ func (a *Account) Copy() *Account {
setupKeys[id] = key.Copy() setupKeys[id] = key.Copy()
} }
groups := map[string]*Group{} groups := map[string]*nbgroup.Group{}
for id, group := range a.Groups { for id, group := range a.Groups {
groups[id] = group.Copy() groups[id] = group.Copy()
} }
@ -713,7 +734,7 @@ func (a *Account) Copy() *Account {
} }
} }
func (a *Account) GetGroupAll() (*Group, error) { func (a *Account) GetGroupAll() (*nbgroup.Group, error) {
for _, g := range a.Groups { for _, g := range a.Groups {
if g.Name == "All" { if g.Name == "All" {
return g, nil return g, nil
@ -734,7 +755,7 @@ func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool {
return false return false
} }
existedGroupsByName := make(map[string]*Group) existedGroupsByName := make(map[string]*nbgroup.Group)
for _, group := range a.Groups { for _, group := range a.Groups {
existedGroupsByName[group.Name] = group existedGroupsByName[group.Name] = group
} }
@ -743,7 +764,7 @@ func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool {
removed := 0 removed := 0
jwtAutoGroups := make(map[string]struct{}) jwtAutoGroups := make(map[string]struct{})
for i, id := range user.AutoGroups { for i, id := range user.AutoGroups {
if group, ok := a.Groups[id]; ok && group.Issued == GroupIssuedJWT { if group, ok := a.Groups[id]; ok && group.Issued == nbgroup.GroupIssuedJWT {
jwtAutoGroups[group.Name] = struct{}{} jwtAutoGroups[group.Name] = struct{}{}
user.AutoGroups = append(user.AutoGroups[:i-removed], user.AutoGroups[i-removed+1:]...) user.AutoGroups = append(user.AutoGroups[:i-removed], user.AutoGroups[i-removed+1:]...)
removed++ removed++
@ -756,15 +777,15 @@ func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool {
for _, name := range groupsNames { for _, name := range groupsNames {
group, ok := existedGroupsByName[name] group, ok := existedGroupsByName[name]
if !ok { if !ok {
group = &Group{ group = &nbgroup.Group{
ID: xid.New().String(), ID: xid.New().String(),
Name: name, Name: name,
Issued: GroupIssuedJWT, Issued: nbgroup.GroupIssuedJWT,
} }
a.Groups[group.ID] = group a.Groups[group.ID] = group
} }
// only JWT groups will be synced // only JWT groups will be synced
if group.Issued == GroupIssuedJWT { if group.Issued == nbgroup.GroupIssuedJWT {
user.AutoGroups = append(user.AutoGroups, group.ID) user.AutoGroups = append(user.AutoGroups, group.ID)
if _, ok := jwtAutoGroups[name]; !ok { if _, ok := jwtAutoGroups[name]; !ok {
modified = true modified = true
@ -837,6 +858,7 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) {
func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager, func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager,
singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, geo *geolocation.Geolocation, singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, geo *geolocation.Geolocation,
userDeleteFromIDPEnabled bool, userDeleteFromIDPEnabled bool,
integratedPeerValidator integrated_validator.IntegratedValidator,
) (*DefaultAccountManager, error) { ) (*DefaultAccountManager, error) {
am := &DefaultAccountManager{ am := &DefaultAccountManager{
Store: store, Store: store,
@ -850,6 +872,7 @@ func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManage
eventStore: eventStore, eventStore: eventStore,
peerLoginExpiry: NewDefaultScheduler(), peerLoginExpiry: NewDefaultScheduler(),
userDeleteFromIDPEnabled: userDeleteFromIDPEnabled, userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
integratedPeerValidator: integratedPeerValidator,
} }
allAccounts := store.GetAllAccounts() allAccounts := store.GetAllAccounts()
// enable single account mode only if configured by user and number of existing accounts is not grater than 1 // enable single account mode only if configured by user and number of existing accounts is not grater than 1
@ -906,6 +929,8 @@ func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManage
}() }()
} }
am.integratedPeerValidator.SetPeerInvalidationListener(am.onPeersInvalidated)
return am, nil return am, nil
} }
@ -948,7 +973,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string,
return nil, status.Errorf(status.PermissionDenied, "user is not allowed to update account") return nil, status.Errorf(status.PermissionDenied, "user is not allowed to update account")
} }
err = additions.ValidateExtraSettings(newSettings.Extra, account.Settings.Extra, account.Peers, userID, accountID, am.eventStore) err = am.integratedPeerValidator.ValidateExtraSettings(newSettings.Extra, account.Settings.Extra, account.Peers, userID, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1823,18 +1848,27 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(claims jwtclaims.Aut
return nil return nil
} }
func (am *DefaultAccountManager) onPeersInvalidated(accountID string) {
updatedAccount, err := am.Store.GetAccount(accountID)
if err != nil {
log.Errorf("failed to get account %s: %v", accountID, err)
return
}
am.updateAccountPeers(updatedAccount)
}
// addAllGroup to account object if it doesn't exist // addAllGroup to account object if it doesn't exist
func addAllGroup(account *Account) error { func addAllGroup(account *Account) error {
if len(account.Groups) == 0 { if len(account.Groups) == 0 {
allGroup := &Group{ allGroup := &nbgroup.Group{
ID: xid.New().String(), ID: xid.New().String(),
Name: "All", Name: "All",
Issued: GroupIssuedAPI, Issued: nbgroup.GroupIssuedAPI,
} }
for _, peer := range account.Peers { for _, peer := range account.Peers {
allGroup.Peers = append(allGroup.Peers, peer.ID) allGroup.Peers = append(allGroup.Peers, peer.ID)
} }
account.Groups = map[string]*Group{allGroup.ID: allGroup} account.Groups = map[string]*nbgroup.Group{allGroup.ID: allGroup}
id := xid.New().String() id := xid.New().String()

View File

@ -3,11 +3,17 @@ package account
type ExtraSettings struct { type ExtraSettings struct {
// PeerApprovalEnabled enables or disables the need for peers bo be approved by an administrator // PeerApprovalEnabled enables or disables the need for peers bo be approved by an administrator
PeerApprovalEnabled bool PeerApprovalEnabled bool
// IntegratedValidatorGroups list of group IDs to be used with integrated approval configurations
IntegratedValidatorGroups []string `gorm:"serializer:json"`
} }
// Copy copies the ExtraSettings struct // Copy copies the ExtraSettings struct
func (e *ExtraSettings) Copy() *ExtraSettings { func (e *ExtraSettings) Copy() *ExtraSettings {
var cpGroup []string
return &ExtraSettings{ return &ExtraSettings{
PeerApprovalEnabled: e.PeerApprovalEnabled, PeerApprovalEnabled: e.PeerApprovalEnabled,
IntegratedValidatorGroups: append(cpGroup, e.IntegratedValidatorGroups...),
} }
} }

View File

@ -12,20 +12,57 @@ import (
"time" "time"
"github.com/golang-jwt/jwt" "github.com/golang-jwt/jwt"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/route"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/route"
) )
type MocIntegratedValidator struct {
}
func (a MocIntegratedValidator) ValidateExtraSettings(newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
return nil
}
func (a MocIntegratedValidator) ValidatePeer(update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) {
return update, nil
}
func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) {
validatedPeers := make(map[string]struct{})
for _, peer := range peers {
validatedPeers[peer.ID] = struct{}{}
}
return validatedPeers, nil
}
func (MocIntegratedValidator) PreparePeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer {
return peer
}
func (MocIntegratedValidator) IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool) {
return false, false
}
func (MocIntegratedValidator) PeerDeleted(_, _ string) error {
return nil
}
func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) {
}
func (MocIntegratedValidator) Stop() {
}
func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *Account, userID string) { func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *Account, userID string) {
t.Helper() t.Helper()
peer := &nbpeer.Peer{ peer := &nbpeer.Peer{
@ -367,7 +404,12 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
account.Groups[all.ID].Peers = append(account.Groups[all.ID].Peers, peer.ID) account.Groups[all.ID].Peers = append(account.Groups[all.ID].Peers, peer.ID)
} }
networkMap := account.GetPeerNetworkMap(testCase.peerID, "netbird.io") validatedPeers := map[string]struct{}{}
for p := range account.Peers {
validatedPeers[p] = struct{}{}
}
networkMap := account.GetPeerNetworkMap(testCase.peerID, "netbird.io", validatedPeers)
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers)) assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers)) assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
} }
@ -667,7 +709,7 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
require.NoError(t, err, "get account by token failed") require.NoError(t, err, "get account by token failed")
require.Len(t, account.Groups, 3, "groups should be added to the account") require.Len(t, account.Groups, 3, "groups should be added to the account")
groupsByNames := map[string]*Group{} groupsByNames := map[string]*group.Group{}
for _, g := range account.Groups { for _, g := range account.Groups {
groupsByNames[g.Name] = g groupsByNames[g.Name] = g
} }
@ -675,12 +717,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
g1, ok := groupsByNames["group1"] g1, ok := groupsByNames["group1"]
require.True(t, ok, "group1 should be added to the account") require.True(t, ok, "group1 should be added to the account")
require.Equal(t, g1.Name, "group1", "group1 name should match") require.Equal(t, g1.Name, "group1", "group1 name should match")
require.Equal(t, g1.Issued, GroupIssuedJWT, "group1 issued should match") require.Equal(t, g1.Issued, group.GroupIssuedJWT, "group1 issued should match")
g2, ok := groupsByNames["group2"] g2, ok := groupsByNames["group2"]
require.True(t, ok, "group2 should be added to the account") require.True(t, ok, "group2 should be added to the account")
require.Equal(t, g2.Name, "group2", "group2 name should match") require.Equal(t, g2.Name, "group2", "group2 name should match")
require.Equal(t, g2.Issued, GroupIssuedJWT, "group2 issued should match") require.Equal(t, g2.Issued, group.GroupIssuedJWT, "group2 issued should match")
}) })
} }
@ -800,7 +842,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
t.Fatalf("expected to create an account for a user %s", userId) t.Fatalf("expected to create an account for a user %s", userId)
} }
if account.Domain != domain { if account != nil && account.Domain != domain {
t.Errorf("setting account domain failed, expected %s, got %s", domain, account.Domain) t.Errorf("setting account domain failed, expected %s, got %s", domain, account.Domain)
} }
@ -815,7 +857,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
t.Fatalf("expected to get an account for a user %s", userId) t.Fatalf("expected to get an account for a user %s", userId)
} }
if account.Domain != domain { if account != nil && account.Domain != domain {
t.Errorf("updating domain. expected %s got %s", domain, account.Domain) t.Errorf("updating domain. expected %s got %s", domain, account.Domain)
} }
} }
@ -835,13 +877,12 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
} }
if account == nil { if account == nil {
t.Fatalf("expected to create an account for a user %s", userId) t.Fatalf("expected to create an account for a user %s", userId)
return
} }
accountId := account.Id _, err = manager.GetAccountByUserOrAccountID("", account.Id, "")
_, err = manager.GetAccountByUserOrAccountID("", accountId, "")
if err != nil { if err != nil {
t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountId) t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", account.Id)
} }
_, err = manager.GetAccountByUserOrAccountID("", "", "") _, err = manager.GetAccountByUserOrAccountID("", "", "")
@ -1124,7 +1165,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
updMsg := manager.peersUpdateManager.CreateChannel(peer1.ID) updMsg := manager.peersUpdateManager.CreateChannel(peer1.ID)
defer manager.peersUpdateManager.CloseChannel(peer1.ID) defer manager.peersUpdateManager.CloseChannel(peer1.ID)
group := Group{ group := group.Group{
ID: "group-id", ID: "group-id",
Name: "GroupA", Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID}, Peers: []string{peer1.ID, peer2.ID, peer3.ID},
@ -1417,7 +1458,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
Peers: map[string]*nbpeer.Peer{ Peers: map[string]*nbpeer.Peer{
"peer-1": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-1": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}},
}, },
Groups: map[string]*Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}}, Groups: map[string]*group.Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}},
Routes: map[string]*route.Route{ Routes: map[string]*route.Route{
"route-1": { "route-1": {
ID: "route-1", ID: "route-1",
@ -1518,7 +1559,7 @@ func TestAccount_Copy(t *testing.T) {
}, },
}, },
}, },
Groups: map[string]*Group{ Groups: map[string]*group.Group{
"group1": { "group1": {
ID: "group1", ID: "group1",
Peers: []string{"peer1"}, Peers: []string{"peer1"},
@ -2112,8 +2153,8 @@ func TestAccount_SetJWTGroups(t *testing.T) {
"peer4": {ID: "peer4", Key: "key4", UserID: "user2"}, "peer4": {ID: "peer4", Key: "key4", UserID: "user2"},
"peer5": {ID: "peer5", Key: "key5", UserID: "user2"}, "peer5": {ID: "peer5", Key: "key5", UserID: "user2"},
}, },
Groups: map[string]*Group{ Groups: map[string]*group.Group{
"group1": {ID: "group1", Name: "group1", Issued: GroupIssuedAPI, Peers: []string{}}, "group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{}},
}, },
Settings: &Settings{GroupsPropagationEnabled: true}, Settings: &Settings{GroupsPropagationEnabled: true},
Users: map[string]*User{ Users: map[string]*User{
@ -2160,10 +2201,10 @@ func TestAccount_UserGroupsAddToPeers(t *testing.T) {
"peer4": {ID: "peer4", Key: "key4", UserID: "user2"}, "peer4": {ID: "peer4", Key: "key4", UserID: "user2"},
"peer5": {ID: "peer5", Key: "key5", UserID: "user2"}, "peer5": {ID: "peer5", Key: "key5", UserID: "user2"},
}, },
Groups: map[string]*Group{ Groups: map[string]*group.Group{
"group1": {ID: "group1", Name: "group1", Issued: GroupIssuedAPI, Peers: []string{}}, "group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{}},
"group2": {ID: "group2", Name: "group2", Issued: GroupIssuedAPI, Peers: []string{}}, "group2": {ID: "group2", Name: "group2", Issued: group.GroupIssuedAPI, Peers: []string{}},
"group3": {ID: "group3", Name: "group3", Issued: GroupIssuedAPI, Peers: []string{}}, "group3": {ID: "group3", Name: "group3", Issued: group.GroupIssuedAPI, Peers: []string{}},
}, },
Users: map[string]*User{"user1": {Id: "user1"}, "user2": {Id: "user2"}}, Users: map[string]*User{"user1": {Id: "user1"}, "user2": {Id: "user2"}},
} }
@ -2196,10 +2237,10 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) {
"peer4": {ID: "peer4", Key: "key4", UserID: "user2"}, "peer4": {ID: "peer4", Key: "key4", UserID: "user2"},
"peer5": {ID: "peer5", Key: "key5", UserID: "user2"}, "peer5": {ID: "peer5", Key: "key5", UserID: "user2"},
}, },
Groups: map[string]*Group{ Groups: map[string]*group.Group{
"group1": {ID: "group1", Name: "group1", Issued: GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3"}}, "group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3"}},
"group2": {ID: "group2", Name: "group2", Issued: GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3", "peer4", "peer5"}}, "group2": {ID: "group2", Name: "group2", Issued: group.GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3", "peer4", "peer5"}},
"group3": {ID: "group3", Name: "group3", Issued: GroupIssuedAPI, Peers: []string{"peer4", "peer5"}}, "group3": {ID: "group3", Name: "group3", Issued: group.GroupIssuedAPI, Peers: []string{"peer4", "peer5"}},
}, },
Users: map[string]*User{"user1": {Id: "user1"}, "user2": {Id: "user2"}}, Users: map[string]*User{"user1": {Id: "user1"}, "user2": {Id: "user2"}},
} }
@ -2223,7 +2264,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) {
return nil, err return nil, err
} }
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false) return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{})
} }
func createStore(t *testing.T) (Store, error) { func createStore(t *testing.T) (Store, error) {

View File

@ -8,6 +8,7 @@ import (
"github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
) )
@ -193,7 +194,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
return nil, err return nil, err
} }
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false) return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{})
} }
func createDNSStore(t *testing.T) (Store, error) { func createDNSStore(t *testing.T) (Store, error) {
@ -278,13 +279,13 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
return nil, err return nil, err
} }
newGroup1 := &Group{ newGroup1 := &group.Group{
ID: dnsGroup1ID, ID: dnsGroup1ID,
Peers: []string{peer1.ID}, Peers: []string{peer1.ID},
Name: dnsGroup1ID, Name: dnsGroup1ID,
} }
newGroup2 := &Group{ newGroup2 := &group.Group{
ID: dnsGroup2ID, ID: dnsGroup2ID,
Name: dnsGroup2ID, Name: dnsGroup2ID,
} }

View File

@ -165,7 +165,7 @@ func (e *EphemeralManager) cleanup() {
log.Debugf("delete ephemeral peer: %s", id) log.Debugf("delete ephemeral peer: %s", id)
err := e.accountManager.DeletePeer(p.account.Id, id, activity.SystemInitiator) err := e.accountManager.DeletePeer(p.account.Id, id, activity.SystemInitiator)
if err != nil { if err != nil {
log.Tracef("failed to delete ephemeral peer: %s", err) log.Errorf("failed to delete ephemeral peer: %s", err)
} }
} }
} }

View File

@ -10,6 +10,7 @@ import (
"github.com/rs/xid" "github.com/rs/xid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
@ -170,7 +171,7 @@ func restore(file string) (*FileStore, error) {
// Set API as issuer for groups which has not this field // Set API as issuer for groups which has not this field
for _, group := range account.Groups { for _, group := range account.Groups {
if group.Issued == "" { if group.Issued == "" {
group.Issued = GroupIssuedAPI group.Issued = nbgroup.GroupIssuedAPI
} }
} }

View File

@ -10,6 +10,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@ -188,7 +189,7 @@ func TestStore(t *testing.T) {
Name: "peer name", Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
} }
account.Groups["all"] = &Group{ account.Groups["all"] = &group.Group{
ID: "all", ID: "all",
Name: "all", Name: "all",
Peers: []string{"testpeer"}, Peers: []string{"testpeer"},
@ -320,7 +321,7 @@ func TestRestoreGroups_Migration(t *testing.T) {
// create default group // create default group
account := store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"] account := store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"]
account.Groups = map[string]*Group{ account.Groups = map[string]*group.Group{
"cfefqs706sqkneg59g3g": { "cfefqs706sqkneg59g3g": {
ID: "cfefqs706sqkneg59g3g", ID: "cfefqs706sqkneg59g3g",
Name: "All", Name: "All",
@ -336,7 +337,7 @@ func TestRestoreGroups_Migration(t *testing.T) {
account = store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"] account = store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"]
require.Contains(t, account.Groups, "cfefqs706sqkneg59g3g", "failed to restore a FileStore file - missing Account Groups") require.Contains(t, account.Groups, "cfefqs706sqkneg59g3g", "failed to restore a FileStore file - missing Account Groups")
require.Equal(t, GroupIssuedAPI, account.Groups["cfefqs706sqkneg59g3g"].Issued, "default group should has API issued mark") require.Equal(t, group.GroupIssuedAPI, account.Groups["cfefqs706sqkneg59g3g"].Issued, "default group should has API issued mark")
} }
func TestGetAccountByPrivateDomain(t *testing.T) { func TestGetAccountByPrivateDomain(t *testing.T) {
@ -384,6 +385,7 @@ func TestFileStore_GetAccount(t *testing.T) {
expected := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"] expected := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"]
if expected == nil { if expected == nil {
t.Fatalf("expected account doesn't exist") t.Fatalf("expected account doesn't exist")
return
} }
account, err := store.GetAccount(expected.Id) account, err := store.GetAccount(expected.Id)

View File

@ -7,6 +7,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
) )
@ -19,51 +20,8 @@ func (e *GroupLinkError) Error() string {
return fmt.Sprintf("group has been linked to %s: %s", e.Resource, e.Name) return fmt.Sprintf("group has been linked to %s: %s", e.Resource, e.Name)
} }
const (
GroupIssuedAPI = "api"
GroupIssuedJWT = "jwt"
GroupIssuedIntegration = "integration"
)
// Group of the peers for ACL
type Group struct {
// ID of the group
ID string
// AccountID is a reference to Account that this object belongs
AccountID string `json:"-" gorm:"index"`
// Name visible in the UI
Name string
// Issued defines how this group was created (enum of "api", "integration" or "jwt")
Issued string
// Peers list of the group
Peers []string `gorm:"serializer:json"`
IntegrationReference IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"`
}
// EventMeta returns activity event meta related to the group
func (g *Group) EventMeta() map[string]any {
return map[string]any{"name": g.Name}
}
func (g *Group) Copy() *Group {
group := &Group{
ID: g.ID,
Name: g.Name,
Issued: g.Issued,
Peers: make([]string, len(g.Peers)),
IntegrationReference: g.IntegrationReference,
}
copy(group.Peers, g.Peers)
return group
}
// GetGroup object of the peers // GetGroup object of the peers
func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*Group, error) { func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)
defer unlock() defer unlock()
@ -90,7 +48,7 @@ func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*G
} }
// GetAllGroups returns all groups in an account // GetAllGroups returns all groups in an account
func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) ([]*Group, error) { func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)
defer unlock() defer unlock()
@ -108,7 +66,7 @@ func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) (
return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users") return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users")
} }
groups := make([]*Group, 0, len(account.Groups)) groups := make([]*nbgroup.Group, 0, len(account.Groups))
for _, item := range account.Groups { for _, item := range account.Groups {
groups = append(groups, item) groups = append(groups, item)
} }
@ -117,7 +75,7 @@ func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) (
} }
// GetGroupByName filters all groups in an account by name and returns the one with the most peers // GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*Group, error) { func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)
defer unlock() defer unlock()
@ -126,7 +84,7 @@ func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*G
return nil, err return nil, err
} }
matchingGroups := make([]*Group, 0) matchingGroups := make([]*nbgroup.Group, 0)
for _, group := range account.Groups { for _, group := range account.Groups {
if group.Name == groupName { if group.Name == groupName {
matchingGroups = append(matchingGroups, group) matchingGroups = append(matchingGroups, group)
@ -138,7 +96,7 @@ func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*G
} }
maxPeers := -1 maxPeers := -1
var groupWithMostPeers *Group var groupWithMostPeers *nbgroup.Group
for i, group := range matchingGroups { for i, group := range matchingGroups {
if len(group.Peers) > maxPeers { if len(group.Peers) > maxPeers {
maxPeers = len(group.Peers) maxPeers = len(group.Peers)
@ -150,7 +108,7 @@ func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*G
} }
// SaveGroup object of the peers // SaveGroup object of the peers
func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *Group) error { func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *nbgroup.Group) error {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)
defer unlock() defer unlock()
@ -159,11 +117,11 @@ func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *G
return err return err
} }
if newGroup.ID == "" && newGroup.Issued != GroupIssuedAPI { if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued) return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
} }
if newGroup.ID == "" && newGroup.Issued == GroupIssuedAPI { if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
existingGroup, err := account.FindGroupByName(newGroup.Name) existingGroup, err := account.FindGroupByName(newGroup.Name)
if err != nil { if err != nil {
@ -270,7 +228,7 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string)
} }
// disable a deleting integration group if the initiator is not an admin service user // disable a deleting integration group if the initiator is not an admin service user
if g.Issued == GroupIssuedIntegration { if g.Issued == nbgroup.GroupIssuedIntegration {
executingUser := account.Users[userId] executingUser := account.Users[userId]
if executingUser == nil { if executingUser == nil {
return status.Errorf(status.NotFound, "user not found") return status.Errorf(status.NotFound, "user not found")
@ -340,6 +298,15 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string)
} }
} }
// check integrated peer validator groups
if account.Settings.Extra != nil {
for _, integratedPeerValidatorGroups := range account.Settings.Extra.IntegratedValidatorGroups {
if groupID == integratedPeerValidatorGroups {
return &GroupLinkError{"integrated validator", g.Name}
}
}
}
delete(account.Groups, groupID) delete(account.Groups, groupID)
account.Network.IncSerial() account.Network.IncSerial()
@ -355,7 +322,7 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string)
} }
// ListGroups objects of the peers // ListGroups objects of the peers
func (am *DefaultAccountManager) ListGroups(accountID string) ([]*Group, error) { func (am *DefaultAccountManager) ListGroups(accountID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)
defer unlock() defer unlock()
@ -364,7 +331,7 @@ func (am *DefaultAccountManager) ListGroups(accountID string) ([]*Group, error)
return nil, err return nil, err
} }
groups := make([]*Group, 0, len(account.Groups)) groups := make([]*nbgroup.Group, 0, len(account.Groups))
for _, item := range account.Groups { for _, item := range account.Groups {
groups = append(groups, item) groups = append(groups, item)
} }

View File

@ -0,0 +1,46 @@
package group
import "github.com/netbirdio/netbird/management/server/integration_reference"
const (
GroupIssuedAPI = "api"
GroupIssuedJWT = "jwt"
GroupIssuedIntegration = "integration"
)
// Group of the peers for ACL
type Group struct {
// ID of the group
ID string
// AccountID is a reference to Account that this object belongs
AccountID string `json:"-" gorm:"index"`
// Name visible in the UI
Name string
// Issued defines how this group was created (enum of "api", "integration" or "jwt")
Issued string
// Peers list of the group
Peers []string `gorm:"serializer:json"`
IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"`
}
// EventMeta returns activity event meta related to the group
func (g *Group) EventMeta() map[string]any {
return map[string]any{"name": g.Name}
}
func (g *Group) Copy() *Group {
group := &Group{
ID: g.ID,
Name: g.Name,
Issued: g.Issued,
Peers: make([]string, len(g.Peers)),
IntegrationReference: g.IntegrationReference,
}
copy(group.Peers, g.Peers)
return group
}

View File

@ -5,6 +5,7 @@ import (
"testing" "testing"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@ -24,22 +25,22 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
t.Error("failed to init testing account") t.Error("failed to init testing account")
} }
for _, group := range account.Groups { for _, group := range account.Groups {
group.Issued = GroupIssuedIntegration group.Issued = nbgroup.GroupIssuedIntegration
err = am.SaveGroup(account.Id, groupAdminUserID, group) err = am.SaveGroup(account.Id, groupAdminUserID, group)
if err != nil { if err != nil {
t.Errorf("should allow to create %s groups", GroupIssuedIntegration) t.Errorf("should allow to create %s groups", nbgroup.GroupIssuedIntegration)
} }
} }
for _, group := range account.Groups { for _, group := range account.Groups {
group.Issued = GroupIssuedJWT group.Issued = nbgroup.GroupIssuedJWT
err = am.SaveGroup(account.Id, groupAdminUserID, group) err = am.SaveGroup(account.Id, groupAdminUserID, group)
if err != nil { if err != nil {
t.Errorf("should allow to create %s groups", GroupIssuedJWT) t.Errorf("should allow to create %s groups", nbgroup.GroupIssuedJWT)
} }
} }
for _, group := range account.Groups { for _, group := range account.Groups {
group.Issued = GroupIssuedAPI group.Issued = nbgroup.GroupIssuedAPI
group.ID = "" group.ID = ""
err = am.SaveGroup(account.Id, groupAdminUserID, group) err = am.SaveGroup(account.Id, groupAdminUserID, group)
if err == nil { if err == nil {
@ -129,51 +130,51 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) {
accountID := "testingAcc" accountID := "testingAcc"
domain := "example.com" domain := "example.com"
groupForRoute := &Group{ groupForRoute := &nbgroup.Group{
ID: "grp-for-route", ID: "grp-for-route",
AccountID: "account-id", AccountID: "account-id",
Name: "Group for route", Name: "Group for route",
Issued: GroupIssuedAPI, Issued: nbgroup.GroupIssuedAPI,
Peers: make([]string, 0), Peers: make([]string, 0),
} }
groupForNameServerGroups := &Group{ groupForNameServerGroups := &nbgroup.Group{
ID: "grp-for-name-server-grp", ID: "grp-for-name-server-grp",
AccountID: "account-id", AccountID: "account-id",
Name: "Group for name server groups", Name: "Group for name server groups",
Issued: GroupIssuedAPI, Issued: nbgroup.GroupIssuedAPI,
Peers: make([]string, 0), Peers: make([]string, 0),
} }
groupForPolicies := &Group{ groupForPolicies := &nbgroup.Group{
ID: "grp-for-policies", ID: "grp-for-policies",
AccountID: "account-id", AccountID: "account-id",
Name: "Group for policies", Name: "Group for policies",
Issued: GroupIssuedAPI, Issued: nbgroup.GroupIssuedAPI,
Peers: make([]string, 0), Peers: make([]string, 0),
} }
groupForSetupKeys := &Group{ groupForSetupKeys := &nbgroup.Group{
ID: "grp-for-keys", ID: "grp-for-keys",
AccountID: "account-id", AccountID: "account-id",
Name: "Group for setup keys", Name: "Group for setup keys",
Issued: GroupIssuedAPI, Issued: nbgroup.GroupIssuedAPI,
Peers: make([]string, 0), Peers: make([]string, 0),
} }
groupForUsers := &Group{ groupForUsers := &nbgroup.Group{
ID: "grp-for-users", ID: "grp-for-users",
AccountID: "account-id", AccountID: "account-id",
Name: "Group for users", Name: "Group for users",
Issued: GroupIssuedAPI, Issued: nbgroup.GroupIssuedAPI,
Peers: make([]string, 0), Peers: make([]string, 0),
} }
groupForIntegration := &Group{ groupForIntegration := &nbgroup.Group{
ID: "grp-for-integration", ID: "grp-for-integration",
AccountID: "account-id", AccountID: "account-id",
Name: "Group for users integration", Name: "Group for users integration",
Issued: GroupIssuedIntegration, Issued: nbgroup.GroupIssuedIntegration,
Peers: make([]string, 0), Peers: make([]string, 0),
} }

View File

@ -361,6 +361,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
Meta: extractPeerMeta(loginReq), Meta: extractPeerMeta(loginReq),
UserID: userID, UserID: userID,
SetupKey: loginReq.GetSetupKey(), SetupKey: loginReq.GetSetupKey(),
ConnectionIP: realIP,
}) })
if err != nil { if err != nil {

View File

@ -355,6 +355,7 @@ components:
- user_id - user_id
- version - version
- ui_version - ui_version
- approval_required
AccessiblePeer: AccessiblePeer:
allOf: allOf:
- $ref: '#/components/schemas/PeerMinimum' - $ref: '#/components/schemas/PeerMinimum'

View File

@ -470,7 +470,7 @@ type Peer struct {
AccessiblePeers []AccessiblePeer `json:"accessible_peers"` AccessiblePeers []AccessiblePeer `json:"accessible_peers"`
// ApprovalRequired (Cloud only) Indicates whether peer needs approval // ApprovalRequired (Cloud only) Indicates whether peer needs approval
ApprovalRequired *bool `json:"approval_required,omitempty"` ApprovalRequired bool `json:"approval_required"`
// CityName Commonly used English name of the city // CityName Commonly used English name of the city
CityName CityName `json:"city_name"` CityName CityName `json:"city_name"`
@ -539,7 +539,7 @@ type Peer struct {
// PeerBase defines model for PeerBase. // PeerBase defines model for PeerBase.
type PeerBase struct { type PeerBase struct {
// ApprovalRequired (Cloud only) Indicates whether peer needs approval // ApprovalRequired (Cloud only) Indicates whether peer needs approval
ApprovalRequired *bool `json:"approval_required,omitempty"` ApprovalRequired bool `json:"approval_required"`
// CityName Commonly used English name of the city // CityName Commonly used English name of the city
CityName CityName `json:"city_name"` CityName CityName `json:"city_name"`
@ -611,7 +611,7 @@ type PeerBatch struct {
AccessiblePeersCount int `json:"accessible_peers_count"` AccessiblePeersCount int `json:"accessible_peers_count"`
// ApprovalRequired (Cloud only) Indicates whether peer needs approval // ApprovalRequired (Cloud only) Indicates whether peer needs approval
ApprovalRequired *bool `json:"approval_required,omitempty"` ApprovalRequired bool `json:"approval_required"`
// CityName Commonly used English name of the city // CityName Commonly used English name of the city
CityName CityName `json:"city_name"` CityName CityName `json:"city_name"`

View File

@ -4,15 +4,15 @@ import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/gorilla/mux" "github.com/gorilla/mux"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
) )
// GroupsHandler is a handler that returns groups of the account // GroupsHandler is a handler that returns groups of the account
@ -110,7 +110,7 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
} else { } else {
peers = *req.Peers peers = *req.Peers
} }
group := server.Group{ group := nbgroup.Group{
ID: groupID, ID: groupID,
Name: req.Name, Name: req.Name,
Peers: peers, Peers: peers,
@ -154,10 +154,10 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
} else { } else {
peers = *req.Peers peers = *req.Peers
} }
group := server.Group{ group := nbgroup.Group{
Name: req.Name, Name: req.Name,
Peers: peers, Peers: peers,
Issued: server.GroupIssuedAPI, Issued: nbgroup.GroupIssuedAPI,
} }
err = h.accountManager.SaveGroup(account.Id, user.Id, &group) err = h.accountManager.SaveGroup(account.Id, user.Id, &group)
@ -240,7 +240,7 @@ func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) {
} }
} }
func toGroupResponse(account *server.Account, group *server.Group) *api.Group { func toGroupResponse(account *server.Account, group *nbgroup.Group) *api.Group {
cache := make(map[string]api.PeerMinimum) cache := make(map[string]api.PeerMinimum)
gr := api.Group{ gr := api.Group{
Id: group.ID, Id: group.ID,

View File

@ -15,6 +15,7 @@ import (
"github.com/magiconair/properties/assert" "github.com/magiconair/properties/assert"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
@ -28,30 +29,30 @@ var TestPeers = map[string]*nbpeer.Peer{
"B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")}, "B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")},
} }
func initGroupTestData(user *server.User, groups ...*server.Group) *GroupsHandler { func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler {
return &GroupsHandler{ return &GroupsHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
SaveGroupFunc: func(accountID, userID string, group *server.Group) error { SaveGroupFunc: func(accountID, userID string, group *nbgroup.Group) error {
if !strings.HasPrefix(group.ID, "id-") { if !strings.HasPrefix(group.ID, "id-") {
group.ID = "id-was-set" group.ID = "id-was-set"
} }
return nil return nil
}, },
GetGroupFunc: func(_, groupID, _ string) (*server.Group, error) { GetGroupFunc: func(_, groupID, _ string) (*nbgroup.Group, error) {
if groupID != "idofthegroup" { if groupID != "idofthegroup" {
return nil, status.Errorf(status.NotFound, "not found") return nil, status.Errorf(status.NotFound, "not found")
} }
if groupID == "id-jwt-group" { if groupID == "id-jwt-group" {
return &server.Group{ return &nbgroup.Group{
ID: "id-jwt-group", ID: "id-jwt-group",
Name: "Default Group", Name: "Default Group",
Issued: server.GroupIssuedJWT, Issued: nbgroup.GroupIssuedJWT,
}, nil }, nil
} }
return &server.Group{ return &nbgroup.Group{
ID: "idofthegroup", ID: "idofthegroup",
Name: "Group", Name: "Group",
Issued: server.GroupIssuedAPI, Issued: nbgroup.GroupIssuedAPI,
}, nil }, nil
}, },
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
@ -62,10 +63,10 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *GroupsHandle
Users: map[string]*server.User{ Users: map[string]*server.User{
user.Id: user, user.Id: user,
}, },
Groups: map[string]*server.Group{ Groups: map[string]*nbgroup.Group{
"id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: server.GroupIssuedJWT}, "id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: nbgroup.GroupIssuedJWT},
"id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: server.GroupIssuedAPI}, "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: nbgroup.GroupIssuedAPI},
"id-all": {ID: "id-all", Name: "All", Issued: server.GroupIssuedAPI}, "id-all": {ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI},
}, },
}, user, nil }, user, nil
}, },
@ -118,7 +119,7 @@ func TestGetGroup(t *testing.T) {
}, },
} }
group := &server.Group{ group := &nbgroup.Group{
ID: "idofthegroup", ID: "idofthegroup",
Name: "Group", Name: "Group",
} }
@ -153,7 +154,7 @@ func TestGetGroup(t *testing.T) {
t.Fatalf("I don't know what I expected; %v", err) t.Fatalf("I don't know what I expected; %v", err)
} }
got := &server.Group{} got := &nbgroup.Group{}
if err = json.Unmarshal(content, &got); err != nil { if err = json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err) t.Fatalf("Sent content is not in correct json format; %v", err)
} }

View File

@ -9,7 +9,6 @@ import (
"github.com/rs/cors" "github.com/rs/cors"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
s "github.com/netbirdio/netbird/management/server" s "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/middleware" "github.com/netbirdio/netbird/management/server/http/middleware"

View File

@ -6,8 +6,10 @@ import (
"net/http" "net/http"
"github.com/gorilla/mux" "github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
@ -61,10 +63,18 @@ func (h *PeersHandler) getPeer(account *server.Account, peerID, userID string, w
groupsInfo := toGroupsInfo(account.Groups, peer.ID) groupsInfo := toGroupsInfo(account.Groups, peer.ID)
netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain()) validPeers, err := h.accountManager.GetValidatedPeers(account)
if err != nil {
log.Errorf("failed to list appreoved peers: %v", err)
util.WriteError(fmt.Errorf("internal error"), w)
return
}
netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain(), validPeers)
accessiblePeers := toAccessiblePeers(netMap, dnsDomain) accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
util.WriteJSONObject(w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, accessiblePeers)) _, valid := validPeers[peer.ID]
util.WriteJSONObject(w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, accessiblePeers, valid))
} }
func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) { func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) {
@ -75,11 +85,18 @@ func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, pe
return return
} }
update := &nbpeer.Peer{ID: peerID, SSHEnabled: req.SshEnabled, Name: req.Name, update := &nbpeer.Peer{
LoginExpirationEnabled: req.LoginExpirationEnabled} ID: peerID,
SSHEnabled: req.SshEnabled,
Name: req.Name,
LoginExpirationEnabled: req.LoginExpirationEnabled,
}
if req.ApprovalRequired != nil { if req.ApprovalRequired != nil {
update.Status = &nbpeer.PeerStatus{RequiresApproval: *req.ApprovalRequired} // todo: looks like that we reset all status property, is it right?
update.Status = &nbpeer.PeerStatus{
RequiresApproval: *req.ApprovalRequired,
}
} }
peer, err := h.accountManager.UpdatePeer(account.Id, user.Id, update) peer, err := h.accountManager.UpdatePeer(account.Id, user.Id, update)
@ -91,15 +108,24 @@ func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, pe
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID) groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain()) validPeers, err := h.accountManager.GetValidatedPeers(account)
if err != nil {
log.Errorf("failed to list appreoved peers: %v", err)
util.WriteError(fmt.Errorf("internal error"), w)
return
}
netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain(), validPeers)
accessiblePeers := toAccessiblePeers(netMap, dnsDomain) accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
util.WriteJSONObject(w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, accessiblePeers)) _, valid := validPeers[peer.ID]
util.WriteJSONObject(w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, accessiblePeers, valid))
} }
func (h *PeersHandler) deletePeer(accountID, userID string, peerID string, w http.ResponseWriter) { func (h *PeersHandler) deletePeer(accountID, userID string, peerID string, w http.ResponseWriter) {
err := h.accountManager.DeletePeer(accountID, peerID, userID) err := h.accountManager.DeletePeer(accountID, peerID, userID)
if err != nil { if err != nil {
log.Errorf("failed to delete peer: %v", err)
util.WriteError(err, w) util.WriteError(err, w)
return return
} }
@ -138,46 +164,68 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
// GetAllPeers returns a list of all peers associated with a provided account // GetAllPeers returns a list of all peers associated with a provided account
func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
switch r.Method { if r.Method != http.MethodGet {
case http.MethodGet:
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
return
}
peers, err := h.accountManager.GetPeers(account.Id, user.Id)
if err != nil {
util.WriteError(err, w)
return
}
dnsDomain := h.accountManager.GetDNSDomain()
respBody := make([]*api.PeerBatch, 0, len(peers))
for _, peer := range peers {
peerToReturn, err := h.checkPeerStatus(peer)
if err != nil {
util.WriteError(err, w)
return
}
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
accessiblePeerNumbers := h.accessiblePeersNumber(account, peer.ID)
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, accessiblePeerNumbers))
}
util.WriteJSONObject(w, respBody)
return
default:
util.WriteError(status.Errorf(status.NotFound, "unknown METHOD"), w) util.WriteError(status.Errorf(status.NotFound, "unknown METHOD"), w)
return
} }
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
return
}
peers, err := h.accountManager.GetPeers(account.Id, user.Id)
if err != nil {
util.WriteError(err, w)
return
}
dnsDomain := h.accountManager.GetDNSDomain()
respBody := make([]*api.PeerBatch, 0, len(peers))
for _, peer := range peers {
peerToReturn, err := h.checkPeerStatus(peer)
if err != nil {
util.WriteError(err, w)
return
}
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
accessiblePeerNumbers, _ := h.accessiblePeersNumber(account, peer.ID)
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, accessiblePeerNumbers))
}
validPeersMap, err := h.accountManager.GetValidatedPeers(account)
if err != nil {
log.Errorf("failed to list appreoved peers: %v", err)
util.WriteError(fmt.Errorf("internal error"), w)
return
}
h.setApprovalRequiredFlag(respBody, validPeersMap)
util.WriteJSONObject(w, respBody)
} }
func (h *PeersHandler) accessiblePeersNumber(account *server.Account, peerID string) int { func (h *PeersHandler) accessiblePeersNumber(account *server.Account, peerID string) (int, error) {
netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain()) validatedPeersMap, err := h.accountManager.GetValidatedPeers(account)
return len(netMap.Peers) + len(netMap.OfflinePeers) if err != nil {
return 0, err
}
netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain(), validatedPeersMap)
return len(netMap.Peers) + len(netMap.OfflinePeers), nil
}
func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) {
for _, peer := range respBody {
_, ok := approvedPeersMap[peer.Id]
if !ok {
peer.ApprovalRequired = true
}
}
} }
func toAccessiblePeers(netMap *server.NetworkMap, dnsDomain string) []api.AccessiblePeer { func toAccessiblePeers(netMap *server.NetworkMap, dnsDomain string) []api.AccessiblePeer {
@ -206,7 +254,7 @@ func toAccessiblePeers(netMap *server.NetworkMap, dnsDomain string) []api.Access
return accessiblePeers return accessiblePeers
} }
func toGroupsInfo(groups map[string]*server.Group, peerID string) []api.GroupMinimum { func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum {
var groupsInfo []api.GroupMinimum var groupsInfo []api.GroupMinimum
groupsChecked := make(map[string]struct{}) groupsChecked := make(map[string]struct{})
for _, group := range groups { for _, group := range groups {
@ -230,7 +278,7 @@ func toGroupsInfo(groups map[string]*server.Group, peerID string) []api.GroupMin
return groupsInfo return groupsInfo
} }
func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeer []api.AccessiblePeer) *api.Peer { func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeer []api.AccessiblePeer, approved bool) *api.Peer {
osVersion := peer.Meta.OSVersion osVersion := peer.Meta.OSVersion
if osVersion == "" { if osVersion == "" {
osVersion = peer.Meta.Core osVersion = peer.Meta.Core
@ -257,7 +305,7 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
LastLogin: peer.LastLogin, LastLogin: peer.LastLogin,
LoginExpired: peer.Status.LoginExpired, LoginExpired: peer.Status.LoginExpired,
AccessiblePeers: accessiblePeer, AccessiblePeers: accessiblePeer,
ApprovalRequired: &peer.Status.RequiresApproval, ApprovalRequired: !approved,
CountryCode: peer.Location.CountryCode, CountryCode: peer.Location.CountryCode,
CityName: peer.Location.CityName, CityName: peer.Location.CityName,
} }
@ -290,7 +338,6 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn
LastLogin: peer.LastLogin, LastLogin: peer.LastLogin,
LoginExpired: peer.Status.LoginExpired, LoginExpired: peer.Status.LoginExpired,
AccessiblePeersCount: accessiblePeersCount, AccessiblePeersCount: accessiblePeersCount,
ApprovalRequired: &peer.Status.RequiresApproval,
CountryCode: peer.Location.CountryCode, CountryCode: peer.Location.CountryCode,
CityName: peer.Location.CityName, CityName: peer.Location.CityName,
} }

View File

@ -9,6 +9,7 @@ import (
"strings" "strings"
"testing" "testing"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
@ -51,7 +52,7 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
Policies: []*server.Policy{ Policies: []*server.Policy{
{ID: "id-existed"}, {ID: "id-existed"},
}, },
Groups: map[string]*server.Group{ Groups: map[string]*nbgroup.Group{
"F": {ID: "F"}, "F": {ID: "F"},
"G": {ID: "G"}, "G": {ID: "G"},
}, },

View File

@ -13,13 +13,12 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/status"
) )
const ( const (
@ -44,7 +43,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
SetupKeys: map[string]*server.SetupKey{ SetupKeys: map[string]*server.SetupKey{
defaultKey.Key: defaultKey, defaultKey.Key: defaultKey,
}, },
Groups: map[string]*server.Group{ Groups: map[string]*nbgroup.Group{
"group-1": {ID: "group-1", Peers: []string{"A", "B"}}, "group-1": {ID: "group-1", Peers: []string{"A", "B"}},
"id-all": {ID: "id-all", Name: "All"}, "id-all": {ID: "id-all", Name: "All"},
}, },

View File

@ -99,6 +99,8 @@ func WriteError(err error, w http.ResponseWriter) {
httpStatus = http.StatusUnprocessableEntity httpStatus = http.StatusUnprocessableEntity
case status.Unauthorized: case status.Unauthorized:
httpStatus = http.StatusUnauthorized httpStatus = http.StatusUnauthorized
case status.BadRequest:
httpStatus = http.StatusBadRequest
default: default:
} }
msg = strings.ToLower(err.Error()) msg = strings.ToLower(err.Error())

View File

@ -0,0 +1,80 @@
package server
import (
"errors"
"github.com/google/martian/v3/log"
"github.com/netbirdio/netbird/management/server/account"
)
// UpdateIntegratedValidatorGroups updates the integrated validator groups for a specified account.
// It retrieves the account associated with the provided userID, then updates the integrated validator groups
// with the provided list of group ids. The updated account is then saved.
//
// Parameters:
// - accountID: The ID of the account for which integrated validator groups are to be updated.
// - userID: The ID of the user whose account is being updated.
// - groups: A slice of strings representing the ids of integrated validator groups to be updated.
//
// Returns:
// - error: An error if any occurred during the process, otherwise returns nil
func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error {
ok, err := am.GroupValidation(accountID, groups)
if err != nil {
log.Debugf("error validating groups: %s", err.Error())
return err
}
if !ok {
log.Debugf("invalid groups")
return errors.New("invalid groups")
}
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
a, err := am.Store.GetAccountByUser(userID)
if err != nil {
return err
}
var extra *account.ExtraSettings
if a.Settings.Extra != nil {
extra = a.Settings.Extra
} else {
extra = &account.ExtraSettings{}
a.Settings.Extra = extra
}
extra.IntegratedValidatorGroups = groups
return am.Store.SaveAccount(a)
}
func (am *DefaultAccountManager) GroupValidation(accountId string, groups []string) (bool, error) {
if len(groups) == 0 {
return true, nil
}
accountsGroups, err := am.ListGroups(accountId)
if err != nil {
return false, err
}
for _, group := range groups {
var found bool
for _, accountGroup := range accountsGroups {
if accountGroup.ID == group {
found = true
break
}
}
if !found {
return false, nil
}
}
return true, nil
}
func (am *DefaultAccountManager) GetValidatedPeers(account *Account) (map[string]struct{}, error) {
return am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra)
}

View File

@ -0,0 +1,19 @@
package integrated_validator
import (
"github.com/netbirdio/netbird/management/server/account"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
)
// IntegratedValidator interface exists to avoid the circle dependencies
type IntegratedValidator interface {
ValidateExtraSettings(newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error
ValidatePeer(update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error)
PreparePeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer
IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool)
GetValidatedPeers(accountID string, groups map[string]*nbgroup.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error)
PeerDeleted(accountID, peerID string) error
SetPeerInvalidationListener(fn func(accountID string))
Stop()
}

View File

@ -0,0 +1,23 @@
package integration_reference
import (
"fmt"
"strings"
)
// IntegrationReference holds the reference to a particular integration
type IntegrationReference struct {
ID int
IntegrationType string
}
func (ir IntegrationReference) String() string {
return fmt.Sprintf("%s:%d", ir.IntegrationType, ir.ID)
}
func (ir IntegrationReference) CacheKey(path ...string) string {
if len(path) == 0 {
return ir.String()
}
return fmt.Sprintf("%s:%s", ir.String(), strings.Join(path, ":"))
}

View File

@ -9,8 +9,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc" "google.golang.org/grpc"
@ -19,6 +17,7 @@ import (
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
mgmtProto "github.com/netbirdio/netbird/management/proto" mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@ -413,7 +412,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error)
peersUpdateManager := NewPeersUpdateManager(nil) peersUpdateManager := NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
accountManager, err := BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", accountManager, err := BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted",
eventStore, nil, false) eventStore, nil, false, MocIntegratedValidator{})
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }

View File

@ -10,24 +10,22 @@ import (
sync2 "sync" sync2 "sync"
"time" "time"
"github.com/netbirdio/netbird/management/server/activity"
"google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/management/server"
pb "github.com/golang/protobuf/proto" //nolint pb "github.com/golang/protobuf/proto" //nolint
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/encryption"
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/encryption"
mgmtProto "github.com/netbirdio/netbird/management/proto" mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@ -448,6 +446,43 @@ var _ = Describe("Management service", func() {
}) })
}) })
type MocIntegratedValidator struct {
}
func (a MocIntegratedValidator) ValidateExtraSettings(newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
return nil
}
func (a MocIntegratedValidator) ValidatePeer(update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) {
return update, nil
}
func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) {
validatedPeers := make(map[string]struct{})
for p := range peers {
validatedPeers[p] = struct{}{}
}
return validatedPeers, nil
}
func (MocIntegratedValidator) PreparePeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer {
return peer
}
func (MocIntegratedValidator) IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool) {
return false, false
}
func (MocIntegratedValidator) PeerDeleted(_, _ string) error {
return nil
}
func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) {
}
func (MocIntegratedValidator) Stop() {}
func loginPeerWithValidSetupKey(serverPubKey wgtypes.Key, key wgtypes.Key, client mgmtProto.ManagementServiceClient) *mgmtProto.LoginResponse { func loginPeerWithValidSetupKey(serverPubKey wgtypes.Key, key wgtypes.Key, client mgmtProto.ManagementServiceClient) *mgmtProto.LoginResponse {
defer GinkgoRecover() defer GinkgoRecover()
@ -504,7 +539,7 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) {
peersUpdateManager := server.NewPeersUpdateManager(nil) peersUpdateManager := server.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted",
eventStore, nil, false) eventStore, nil, false, MocIntegratedValidator{})
if err != nil { if err != nil {
log.Fatalf("failed creating a manager: %v", err) log.Fatalf("failed creating a manager: %v", err)
} }

View File

@ -5,6 +5,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
@ -32,7 +33,7 @@ func (mockDatasource) GetAllAccounts() []*server.Account {
UsedTimes: 1, UsedTimes: 1,
}, },
}, },
Groups: map[string]*server.Group{ Groups: map[string]*group.Group{
"1": {}, "1": {},
"2": {}, "2": {},
}, },
@ -117,7 +118,7 @@ func (mockDatasource) GetAllAccounts() []*server.Account {
UsedTimes: 1, UsedTimes: 1,
}, },
}, },
Groups: map[string]*server.Group{ Groups: map[string]*group.Group{
"1": {}, "1": {},
"2": {}, "2": {},
}, },

View File

@ -10,6 +10,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
@ -31,12 +32,12 @@ type MockAccountManager struct {
GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error) GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error)
GetPeerNetworkFunc func(peerKey string) (*server.Network, error) GetPeerNetworkFunc func(peerKey string) (*server.Network, error)
AddPeerFunc func(setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, error) AddPeerFunc func(setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, error)
GetGroupFunc func(accountID, groupID, userID string) (*server.Group, error) GetGroupFunc func(accountID, groupID, userID string) (*group.Group, error)
GetAllGroupsFunc func(accountID, userID string) ([]*server.Group, error) GetAllGroupsFunc func(accountID, userID string) ([]*group.Group, error)
GetGroupByNameFunc func(accountID, groupName string) (*server.Group, error) GetGroupByNameFunc func(accountID, groupName string) (*group.Group, error)
SaveGroupFunc func(accountID, userID string, group *server.Group) error SaveGroupFunc func(accountID, userID string, group *group.Group) error
DeleteGroupFunc func(accountID, userId, groupID string) error DeleteGroupFunc func(accountID, userId, groupID string) error
ListGroupsFunc func(accountID string) ([]*server.Group, error) ListGroupsFunc func(accountID string) ([]*group.Group, error)
GroupAddPeerFunc func(accountID, groupID, peerID string) error GroupAddPeerFunc func(accountID, groupID, peerID string) error
GroupDeletePeerFunc func(accountID, groupID, peerID string) error GroupDeletePeerFunc func(accountID, groupID, peerID string) error
DeleteRuleFunc func(accountID, ruleID, userID string) error DeleteRuleFunc func(accountID, ruleID, userID string) error
@ -91,10 +92,20 @@ type MockAccountManager struct {
DeletePostureChecksFunc func(accountID, postureChecksID, userID string) error DeletePostureChecksFunc func(accountID, postureChecksID, userID string) error
ListPostureChecksFunc func(accountID, userID string) ([]*posture.Checks, error) ListPostureChecksFunc func(accountID, userID string) ([]*posture.Checks, error)
GetIdpManagerFunc func() idp.Manager GetIdpManagerFunc func() idp.Manager
UpdateIntegratedValidatorGroupsFunc func(accountID string, userID string, groups []string) error
GroupValidationFunc func(accountId string, groups []string) (bool, error)
}
func (am *MockAccountManager) GetValidatedPeers(account *server.Account) (map[string]struct{}, error) {
approvedPeers := make(map[string]struct{})
for id := range account.Peers {
approvedPeers[id] = struct{}{}
}
return approvedPeers, nil
} }
// GetGroup mock implementation of GetGroup from server.AccountManager interface // GetGroup mock implementation of GetGroup from server.AccountManager interface
func (am *MockAccountManager) GetGroup(accountId, groupID, userID string) (*server.Group, error) { func (am *MockAccountManager) GetGroup(accountId, groupID, userID string) (*group.Group, error) {
if am.GetGroupFunc != nil { if am.GetGroupFunc != nil {
return am.GetGroupFunc(accountId, groupID, userID) return am.GetGroupFunc(accountId, groupID, userID)
} }
@ -102,7 +113,7 @@ func (am *MockAccountManager) GetGroup(accountId, groupID, userID string) (*serv
} }
// GetAllGroups mock implementation of GetAllGroups from server.AccountManager interface // GetAllGroups mock implementation of GetAllGroups from server.AccountManager interface
func (am *MockAccountManager) GetAllGroups(accountID, userID string) ([]*server.Group, error) { func (am *MockAccountManager) GetAllGroups(accountID, userID string) ([]*group.Group, error) {
if am.GetAllGroupsFunc != nil { if am.GetAllGroupsFunc != nil {
return am.GetAllGroupsFunc(accountID, userID) return am.GetAllGroupsFunc(accountID, userID)
} }
@ -261,7 +272,7 @@ func (am *MockAccountManager) AddPeer(
} }
// GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface // GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface
func (am *MockAccountManager) GetGroupByName(accountID, groupName string) (*server.Group, error) { func (am *MockAccountManager) GetGroupByName(accountID, groupName string) (*group.Group, error) {
if am.GetGroupFunc != nil { if am.GetGroupFunc != nil {
return am.GetGroupByNameFunc(accountID, groupName) return am.GetGroupByNameFunc(accountID, groupName)
} }
@ -269,7 +280,7 @@ func (am *MockAccountManager) GetGroupByName(accountID, groupName string) (*serv
} }
// SaveGroup mock implementation of SaveGroup from server.AccountManager interface // SaveGroup mock implementation of SaveGroup from server.AccountManager interface
func (am *MockAccountManager) SaveGroup(accountID, userID string, group *server.Group) error { func (am *MockAccountManager) SaveGroup(accountID, userID string, group *group.Group) error {
if am.SaveGroupFunc != nil { if am.SaveGroupFunc != nil {
return am.SaveGroupFunc(accountID, userID, group) return am.SaveGroupFunc(accountID, userID, group)
} }
@ -285,7 +296,7 @@ func (am *MockAccountManager) DeleteGroup(accountId, userId, groupID string) err
} }
// ListGroups mock implementation of ListGroups from server.AccountManager interface // ListGroups mock implementation of ListGroups from server.AccountManager interface
func (am *MockAccountManager) ListGroups(accountID string) ([]*server.Group, error) { func (am *MockAccountManager) ListGroups(accountID string) ([]*group.Group, error) {
if am.ListGroupsFunc != nil { if am.ListGroupsFunc != nil {
return am.ListGroupsFunc(accountID) return am.ListGroupsFunc(accountID)
} }
@ -694,3 +705,19 @@ func (am *MockAccountManager) GetIdpManager() idp.Manager {
} }
return nil return nil
} }
// UpdateIntegratedValidatedGroups mocks UpdateIntegratedApprovalGroups of the AccountManager interface
func (am *MockAccountManager) UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error {
if am.UpdateIntegratedValidatorGroupsFunc != nil {
return am.UpdateIntegratedValidatorGroupsFunc(accountID, userID, groups)
}
return status.Errorf(codes.Unimplemented, "method UpdateIntegratedValidatorGroups is not implemented")
}
// GroupValidation mocks GroupValidation of the AccountManager interface
func (am *MockAccountManager) GroupValidation(accountId string, groups []string) (bool, error) {
if am.GroupValidationFunc != nil {
return am.GroupValidationFunc(accountId, groups)
}
return false, status.Errorf(codes.Unimplemented, "method GroupValidation is not implemented")
}

View File

@ -10,6 +10,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
) )
@ -261,7 +262,7 @@ func validateNSList(list []nbdns.NameServer) error {
return nil return nil
} }
func validateGroups(list []string, groups map[string]*Group) error { func validateGroups(list []string, groups map[string]*nbgroup.Group) error {
if len(list) == 0 { if len(list) == 0 {
return status.Errorf(status.InvalidArgument, "the list of group IDs should not be empty") return status.Errorf(status.InvalidArgument, "the list of group IDs should not be empty")
} }

View File

@ -8,6 +8,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
) )
@ -759,7 +760,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
return nil, err return nil, err
} }
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false) return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{})
} }
func createNSStore(t *testing.T) (Store, error) { func createNSStore(t *testing.T) (Store, error) {
@ -831,12 +832,12 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error
account.NameServerGroups[existingNSGroup.ID] = &existingNSGroup account.NameServerGroups[existingNSGroup.ID] = &existingNSGroup
newGroup1 := &Group{ newGroup1 := &nbgroup.Group{
ID: group1ID, ID: group1ID,
Name: group1ID, Name: group1ID,
} }
newGroup2 := &Group{ newGroup2 := &nbgroup.Group{
ID: group2ID, ID: group2ID,
Name: group2ID, Name: group2ID,
} }

View File

@ -7,16 +7,12 @@ import (
"time" "time"
"github.com/rs/xid" "github.com/rs/xid"
"github.com/netbirdio/management-integrations/additions"
"github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
) )
// PeerSync used as a data object between the gRPC API and AccountManager on Sync request. // PeerSync used as a data object between the gRPC API and AccountManager on Sync request.
@ -37,6 +33,8 @@ type PeerLogin struct {
UserID string UserID string
// SetupKey references to a server.SetupKey to log in. Can be empty when UserID is used or auth is not required. // SetupKey references to a server.SetupKey to log in. Can be empty when UserID is used or auth is not required.
SetupKey string SetupKey string
// ConnectionIP is the real IP of the peer
ConnectionIP net.IP
} }
// GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if // GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if
@ -52,6 +50,10 @@ func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.P
return nil, err return nil, err
} }
approvedPeersMap, err := am.GetValidatedPeers(account)
if err != nil {
return nil, err
}
peers := make([]*nbpeer.Peer, 0) peers := make([]*nbpeer.Peer, 0)
peersMap := make(map[string]*nbpeer.Peer) peersMap := make(map[string]*nbpeer.Peer)
@ -71,7 +73,7 @@ func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.P
// fetch all the peers that have access to the user's peers // fetch all the peers that have access to the user's peers
for _, peer := range peers { for _, peer := range peers {
aclPeers, _ := account.getPeerConnectionResources(peer.ID) aclPeers, _ := account.getPeerConnectionResources(peer.ID, approvedPeersMap)
for _, p := range aclPeers { for _, p := range aclPeers {
peersMap[p.ID] = p peersMap[p.ID] = p
} }
@ -167,7 +169,7 @@ func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *nb
return nil, status.Errorf(status.NotFound, "peer %s not found", update.ID) return nil, status.Errorf(status.NotFound, "peer %s not found", update.ID)
} }
update, err = additions.ValidatePeersUpdateRequest(update, peer, userID, accountID, am.eventStore, am.GetDNSDomain()) update, err = am.integratedPeerValidator.ValidatePeer(update, peer, userID, accountID, am.GetDNSDomain(), account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -244,6 +246,12 @@ func (am *DefaultAccountManager) deletePeers(account *Account, peerIDs []string,
// the 2nd loop performs the actual modification // the 2nd loop performs the actual modification
for _, peer := range peers { for _, peer := range peers {
err := am.integratedPeerValidator.PeerDeleted(account.Id, peer.ID)
if err != nil {
return err
}
account.DeletePeer(peer.ID) account.DeletePeer(peer.ID)
am.peersUpdateManager.SendUpdate(peer.ID, am.peersUpdateManager.SendUpdate(peer.ID,
&UpdateMessage{ &UpdateMessage{
@ -304,7 +312,17 @@ func (am *DefaultAccountManager) GetNetworkMap(peerID string) (*NetworkMap, erro
if peer == nil { if peer == nil {
return nil, status.Errorf(status.NotFound, "peer with ID %s not found", peerID) return nil, status.Errorf(status.NotFound, "peer with ID %s not found", peerID)
} }
return account.GetPeerNetworkMap(peer.ID, am.dnsDomain), nil
groups := make(map[string][]string)
for groupID, group := range account.Groups {
groups[groupID] = group.Peers
}
validatedPeers, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra)
if err != nil {
return nil, err
}
return account.GetPeerNetworkMap(peer.ID, am.dnsDomain, validatedPeers), nil
} }
// GetPeerNetwork returns the Network for a given peer // GetPeerNetwork returns the Network for a given peer
@ -433,10 +451,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
CreatedAt: registrationTime, CreatedAt: registrationTime,
LoginExpirationEnabled: addedByUser, LoginExpirationEnabled: addedByUser,
Ephemeral: ephemeral, Ephemeral: ephemeral,
} Location: peer.Location,
if account.Settings.Extra != nil {
newPeer = additions.PreparePeer(newPeer, account.Settings.Extra)
} }
// add peer to 'All' group // add peer to 'All' group
@ -467,6 +482,8 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
} }
} }
newPeer = am.integratedPeerValidator.PreparePeer(account.Id, newPeer, account.GetPeerGroupsList(newPeer.ID), account.Settings.Extra)
if addedByUser { if addedByUser {
user, err := account.FindUser(userID) user, err := account.FindUser(userID)
if err != nil { if err != nil {
@ -492,7 +509,11 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
am.updateAccountPeers(account) am.updateAccountPeers(account)
networkMap := account.GetPeerNetworkMap(newPeer.ID, am.dnsDomain) approvedPeersMap, err := am.GetValidatedPeers(account)
if err != nil {
return nil, nil, err
}
networkMap := account.GetPeerNetworkMap(newPeer.ID, am.dnsDomain, approvedPeersMap)
return newPeer, networkMap, nil return newPeer, networkMap, nil
} }
@ -529,23 +550,53 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*nbpeer.Peer, *Network
if peerLoginExpired(peer, account) { if peerLoginExpired(peer, account) {
return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more") return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
} }
return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain), nil
requiresApproval, isStatusChanged := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
if requiresApproval {
emptyMap := &NetworkMap{
Network: account.Network.Copy(),
}
return peer, emptyMap, nil
}
if isStatusChanged {
am.updateAccountPeers(account)
}
approvedPeersMap, err := am.GetValidatedPeers(account)
if err != nil {
return nil, nil, err
}
return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap), nil
} }
// LoginPeer logs in or registers a peer. // LoginPeer logs in or registers a peer.
// If peer doesn't exist the function checks whether a setup key or a user is present and registers a new peer if so. // If peer doesn't exist the function checks whether a setup key or a user is present and registers a new peer if so.
func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, error) { func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, error) {
account, err := am.Store.GetAccountByPeerPubKey(login.WireGuardPubKey) account, err := am.Store.GetAccountByPeerPubKey(login.WireGuardPubKey)
if err != nil { if err != nil {
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound { if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
// we couldn't find this peer by its public key which can mean that peer hasn't been registered yet. // we couldn't find this peer by its public key which can mean that peer hasn't been registered yet.
// Try registering it. // Try registering it.
return am.AddPeer(login.SetupKey, login.UserID, &nbpeer.Peer{ newPeer := &nbpeer.Peer{
Key: login.WireGuardPubKey, Key: login.WireGuardPubKey,
Meta: login.Meta, Meta: login.Meta,
SSHKey: login.SSHKey, SSHKey: login.SSHKey,
}) }
if am.geo != nil && login.ConnectionIP != nil {
location, err := am.geo.Lookup(login.ConnectionIP)
if err != nil {
log.Warnf("failed to get location for new peer realip: [%s]: %v", login.ConnectionIP.String(), err)
} else {
newPeer.Location.ConnectionIP = login.ConnectionIP
newPeer.Location.CountryCode = location.Country.ISOCode
newPeer.Location.CityName = location.City.Names.En
newPeer.Location.GeoNameID = location.City.GeonameID
}
}
return am.AddPeer(login.SetupKey, login.UserID, newPeer)
} }
log.Errorf("failed while logging in peer %s: %v", login.WireGuardPubKey, err) log.Errorf("failed while logging in peer %s: %v", login.WireGuardPubKey, err)
return nil, nil, status.Errorf(status.Internal, "failed while logging in peer") return nil, nil, status.Errorf(status.Internal, "failed while logging in peer")
@ -595,6 +646,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
am.StoreEvent(login.UserID, peer.ID, account.Id, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain())) am.StoreEvent(login.UserID, peer.ID, account.Id, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain()))
} }
isRequiresApproval, isStatusChanged := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
peer, updated := updatePeerMeta(peer, login.Meta, account) peer, updated := updatePeerMeta(peer, login.Meta, account)
if updated { if updated {
shouldStoreAccount = true shouldStoreAccount = true
@ -612,10 +664,23 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
} }
} }
if updateRemotePeers { if updateRemotePeers || isStatusChanged {
am.updateAccountPeers(account) am.updateAccountPeers(account)
} }
return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain), nil
if isRequiresApproval {
emptyMap := &NetworkMap{
Network: account.Network.Copy(),
}
return peer, emptyMap, nil
}
approvedPeersMap, err := am.GetValidatedPeers(account)
if err != nil {
return nil, nil, err
}
return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap), nil
} }
func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, account *Account) error { func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, account *Account) error {
@ -764,8 +829,13 @@ func (am *DefaultAccountManager) GetPeer(accountID, peerID, userID string) (*nbp
return nil, err return nil, err
} }
approvedPeersMap, err := am.GetValidatedPeers(account)
if err != nil {
return nil, err
}
for _, p := range userPeers { for _, p := range userPeers {
aclPeers, _ := account.getPeerConnectionResources(p.ID) aclPeers, _ := account.getPeerConnectionResources(p.ID, approvedPeersMap)
for _, aclPeer := range aclPeers { for _, aclPeer := range aclPeers {
if aclPeer.ID == peerID { if aclPeer.ID == peerID {
return peer, nil return peer, nil
@ -789,8 +859,13 @@ func updatePeerMeta(peer *nbpeer.Peer, meta nbpeer.PeerSystemMeta, account *Acco
func (am *DefaultAccountManager) updateAccountPeers(account *Account) { func (am *DefaultAccountManager) updateAccountPeers(account *Account) {
peers := account.GetPeers() peers := account.GetPeers()
approvedPeersMap, err := am.GetValidatedPeers(account)
if err != nil {
log.Errorf("failed send out updates to peers, failed to validate peer: %v", err)
return
}
for _, peer := range peers { for _, peer := range peers {
remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, am.dnsDomain) remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap)
update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain()) update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain())
am.peersUpdateManager.SendUpdate(peer.ID, &UpdateMessage{Update: update}) am.peersUpdateManager.SendUpdate(peer.ID, &UpdateMessage{Update: update})
} }

View File

@ -8,6 +8,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
) )
@ -199,8 +200,8 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
return return
} }
var ( var (
group1 Group group1 nbgroup.Group
group2 Group group2 nbgroup.Group
policy Policy policy Policy
) )

View File

@ -5,11 +5,11 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/netbirdio/management-integrations/additions"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
@ -211,7 +211,8 @@ type FirewallRule struct {
// getPeerConnectionResources for a given peer // getPeerConnectionResources for a given peer
// //
// This function returns the list of peers and firewall rules that are applicable to a given peer. // This function returns the list of peers and firewall rules that are applicable to a given peer.
func (a *Account) getPeerConnectionResources(peerID string) ([]*nbpeer.Peer, []*FirewallRule) { func (a *Account) getPeerConnectionResources(peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
generateResources, getAccumulatedResources := a.connResourcesGenerator() generateResources, getAccumulatedResources := a.connResourcesGenerator()
for _, policy := range a.Policies { for _, policy := range a.Policies {
if !policy.Enabled { if !policy.Enabled {
@ -223,10 +224,8 @@ func (a *Account) getPeerConnectionResources(peerID string) ([]*nbpeer.Peer, []*
continue continue
} }
sourcePeers, peerInSources := getAllPeersFromGroups(a, rule.Sources, peerID, policy.SourcePostureChecks) sourcePeers, peerInSources := getAllPeersFromGroups(a, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap)
destinationPeers, peerInDestinations := getAllPeersFromGroups(a, rule.Destinations, peerID, nil) destinationPeers, peerInDestinations := getAllPeersFromGroups(a, rule.Destinations, peerID, nil, validatedPeersMap)
sourcePeers = additions.ValidatePeers(sourcePeers)
destinationPeers = additions.ValidatePeers(destinationPeers)
if rule.Bidirectional { if rule.Bidirectional {
if peerInSources { if peerInSources {
@ -264,7 +263,7 @@ func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*nbpeer.Peer, in
all, err := a.GetGroupAll() all, err := a.GetGroupAll()
if err != nil { if err != nil {
log.Errorf("failed to get group all: %v", err) log.Errorf("failed to get group all: %v", err)
all = &Group{} all = &nbgroup.Group{}
} }
return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) { return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) {
@ -491,7 +490,7 @@ func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule {
// //
// Important: Posture checks are applicable only to source group peers, // Important: Posture checks are applicable only to source group peers,
// for destination group peers, call this method with an empty list of sourcePostureChecksIDs // for destination group peers, call this method with an empty list of sourcePostureChecksIDs
func getAllPeersFromGroups(account *Account, groups []string, peerID string, sourcePostureChecksIDs []string) ([]*nbpeer.Peer, bool) { func getAllPeersFromGroups(account *Account, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) {
peerInGroups := false peerInGroups := false
filteredPeers := make([]*nbpeer.Peer, 0, len(groups)) filteredPeers := make([]*nbpeer.Peer, 0, len(groups))
for _, g := range groups { for _, g := range groups {
@ -512,6 +511,10 @@ func getAllPeersFromGroups(account *Account, groups []string, peerID string, sou
continue continue
} }
if _, ok := validatedPeersMap[peer.ID]; !ok {
continue
}
if peer.ID == peerID { if peer.ID == peerID {
peerInGroups = true peerInGroups = true
continue continue

View File

@ -8,6 +8,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
) )
@ -56,7 +57,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
Status: &nbpeer.PeerStatus{}, Status: &nbpeer.PeerStatus{},
}, },
}, },
Groups: map[string]*Group{ Groups: map[string]*nbgroup.Group{
"GroupAll": { "GroupAll": {
ID: "GroupAll", ID: "GroupAll",
Name: "All", Name: "All",
@ -135,16 +136,21 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
}, },
} }
validatedPeers := make(map[string]struct{})
for p := range account.Peers {
validatedPeers[p] = struct{}{}
}
t.Run("check that all peers get map", func(t *testing.T) { t.Run("check that all peers get map", func(t *testing.T) {
for _, p := range account.Peers { for _, p := range account.Peers {
peers, firewallRules := account.getPeerConnectionResources(p.ID) peers, firewallRules := account.getPeerConnectionResources(p.ID, validatedPeers)
assert.GreaterOrEqual(t, len(peers), 2, "minimum number peers should present") assert.GreaterOrEqual(t, len(peers), 2, "minimum number peers should present")
assert.GreaterOrEqual(t, len(firewallRules), 2, "minimum number of firewall rules should present") assert.GreaterOrEqual(t, len(firewallRules), 2, "minimum number of firewall rules should present")
} }
}) })
t.Run("check first peer map details", func(t *testing.T) { t.Run("check first peer map details", func(t *testing.T) {
peers, firewallRules := account.getPeerConnectionResources("peerB") peers, firewallRules := account.getPeerConnectionResources("peerB", validatedPeers)
assert.Len(t, peers, 7) assert.Len(t, peers, 7)
assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerA"])
assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerC"])
@ -299,7 +305,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
Status: &nbpeer.PeerStatus{}, Status: &nbpeer.PeerStatus{},
}, },
}, },
Groups: map[string]*Group{ Groups: map[string]*nbgroup.Group{
"GroupAll": { "GroupAll": {
ID: "GroupAll", ID: "GroupAll",
Name: "All", Name: "All",
@ -374,8 +380,13 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
}, },
} }
approvedPeers := make(map[string]struct{})
for p := range account.Peers {
approvedPeers[p] = struct{}{}
}
t.Run("check first peer map", func(t *testing.T) { t.Run("check first peer map", func(t *testing.T) {
peers, firewallRules := account.getPeerConnectionResources("peerB") peers, firewallRules := account.getPeerConnectionResources("peerB", approvedPeers)
assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerC"])
epectedFirewallRules := []*FirewallRule{ epectedFirewallRules := []*FirewallRule{
@ -403,7 +414,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
}) })
t.Run("check second peer map", func(t *testing.T) { t.Run("check second peer map", func(t *testing.T) {
peers, firewallRules := account.getPeerConnectionResources("peerC") peers, firewallRules := account.getPeerConnectionResources("peerC", approvedPeers)
assert.Contains(t, peers, account.Peers["peerB"]) assert.Contains(t, peers, account.Peers["peerB"])
epectedFirewallRules := []*FirewallRule{ epectedFirewallRules := []*FirewallRule{
@ -433,7 +444,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
account.Policies[1].Rules[0].Bidirectional = false account.Policies[1].Rules[0].Bidirectional = false
t.Run("check first peer map directional only", func(t *testing.T) { t.Run("check first peer map directional only", func(t *testing.T) {
peers, firewallRules := account.getPeerConnectionResources("peerB") peers, firewallRules := account.getPeerConnectionResources("peerB", approvedPeers)
assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerC"])
epectedFirewallRules := []*FirewallRule{ epectedFirewallRules := []*FirewallRule{
@ -454,7 +465,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
}) })
t.Run("check second peer map directional only", func(t *testing.T) { t.Run("check second peer map directional only", func(t *testing.T) {
peers, firewallRules := account.getPeerConnectionResources("peerC") peers, firewallRules := account.getPeerConnectionResources("peerC", approvedPeers)
assert.Contains(t, peers, account.Peers["peerB"]) assert.Contains(t, peers, account.Peers["peerB"])
epectedFirewallRules := []*FirewallRule{ epectedFirewallRules := []*FirewallRule{
@ -569,7 +580,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
}, },
}, },
}, },
Groups: map[string]*Group{ Groups: map[string]*nbgroup.Group{
"GroupAll": { "GroupAll": {
ID: "GroupAll", ID: "GroupAll",
Name: "All", Name: "All",
@ -644,10 +655,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
}, },
}) })
approvedPeers := make(map[string]struct{})
for p := range account.Peers {
approvedPeers[p] = struct{}{}
}
t.Run("verify peer's network map with default group peer list", func(t *testing.T) { t.Run("verify peer's network map with default group peer list", func(t *testing.T) {
// peerB doesn't fulfill the NB posture check but is included in the destination group Swarm, // peerB doesn't fulfill the NB posture check but is included in the destination group Swarm,
// will establish a connection with all source peers satisfying the NB posture check. // will establish a connection with all source peers satisfying the NB posture check.
peers, firewallRules := account.getPeerConnectionResources("peerB") peers, firewallRules := account.getPeerConnectionResources("peerB", approvedPeers)
assert.Len(t, peers, 4) assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4) assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerA"])
@ -657,7 +672,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerC satisfy the NB posture check, should establish connection to all destination group peer's // peerC satisfy the NB posture check, should establish connection to all destination group peer's
// We expect a single permissive firewall rule which all outgoing connections // We expect a single permissive firewall rule which all outgoing connections
peers, firewallRules = account.getPeerConnectionResources("peerC") peers, firewallRules = account.getPeerConnectionResources("peerC", approvedPeers)
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
assert.Len(t, firewallRules, 1) assert.Len(t, firewallRules, 1)
expectedFirewallRules := []*FirewallRule{ expectedFirewallRules := []*FirewallRule{
@ -673,7 +688,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection // all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.getPeerConnectionResources("peerE") peers, firewallRules = account.getPeerConnectionResources("peerE", approvedPeers)
assert.Len(t, peers, 4) assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4) assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerA"])
@ -683,7 +698,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm, // peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection // all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.getPeerConnectionResources("peerI") peers, firewallRules = account.getPeerConnectionResources("peerI", approvedPeers)
assert.Len(t, peers, 4) assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4) assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerA"])
@ -698,19 +713,19 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's // peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's
// no connection should be established to any peer of destination group // no connection should be established to any peer of destination group
peers, firewallRules := account.getPeerConnectionResources("peerB") peers, firewallRules := account.getPeerConnectionResources("peerB", approvedPeers)
assert.Len(t, peers, 0) assert.Len(t, peers, 0)
assert.Len(t, firewallRules, 0) assert.Len(t, firewallRules, 0)
// peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's // peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's
// no connection should be established to any peer of destination group // no connection should be established to any peer of destination group
peers, firewallRules = account.getPeerConnectionResources("peerI") peers, firewallRules = account.getPeerConnectionResources("peerI", approvedPeers)
assert.Len(t, peers, 0) assert.Len(t, peers, 0)
assert.Len(t, firewallRules, 0) assert.Len(t, firewallRules, 0)
// peerC satisfy the NB posture check, should establish connection to all destination group peer's // peerC satisfy the NB posture check, should establish connection to all destination group peer's
// We expect a single permissive firewall rule which all outgoing connections // We expect a single permissive firewall rule which all outgoing connections
peers, firewallRules = account.getPeerConnectionResources("peerC") peers, firewallRules = account.getPeerConnectionResources("peerC", approvedPeers)
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers))
@ -725,14 +740,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection // all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.getPeerConnectionResources("peerE") peers, firewallRules = account.getPeerConnectionResources("peerE", approvedPeers)
assert.Len(t, peers, 3) assert.Len(t, peers, 3)
assert.Len(t, firewallRules, 3) assert.Len(t, firewallRules, 3)
assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerA"])
assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerC"])
assert.Contains(t, peers, account.Peers["peerD"]) assert.Contains(t, peers, account.Peers["peerD"])
peers, firewallRules = account.getPeerConnectionResources("peerA") peers, firewallRules = account.getPeerConnectionResources("peerA", approvedPeers)
assert.Len(t, peers, 5) assert.Len(t, peers, 5)
// assert peers from Group Swarm // assert peers from Group Swarm
assert.Contains(t, peers, account.Peers["peerD"]) assert.Contains(t, peers, account.Peers["peerD"])

View File

@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@ -858,7 +859,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
groups, err := am.ListGroups(account.Id) groups, err := am.ListGroups(account.Id)
require.NoError(t, err) require.NoError(t, err)
var groupHA1, groupHA2 *Group var groupHA1, groupHA2 *nbgroup.Group
for _, group := range groups { for _, group := range groups {
switch group.Name { switch group.Name {
case routeGroupHA1: case routeGroupHA1:
@ -967,7 +968,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
require.Len(t, peer2Routes.Routes, 1, "we should receive one route") require.Len(t, peer2Routes.Routes, 1, "we should receive one route")
require.True(t, peer1Routes.Routes[0].IsEqual(peer2Routes.Routes[0]), "routes should be the same for peers in the same group") require.True(t, peer1Routes.Routes[0].IsEqual(peer2Routes.Routes[0]), "routes should be the same for peers in the same group")
newGroup := &Group{ newGroup := &nbgroup.Group{
ID: xid.New().String(), ID: xid.New().String(),
Name: "peer1 group", Name: "peer1 group",
Peers: []string{peer1ID}, Peers: []string{peer1ID},
@ -1014,7 +1015,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
return nil, err return nil, err
} }
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false) return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{})
} }
func createRouterStore(t *testing.T) (Store, error) { func createRouterStore(t *testing.T) (Store, error) {
@ -1195,7 +1196,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
return nil, err return nil, err
} }
newGroup := []*Group{ newGroup := []*nbgroup.Group{
{ {
ID: routeGroup1, ID: routeGroup1,
Name: routeGroup1, Name: routeGroup1,

View File

@ -10,6 +10,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group"
) )
func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
@ -24,7 +25,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
err = manager.SaveGroup(account.Id, userID, &Group{ err = manager.SaveGroup(account.Id, userID, &nbgroup.Group{
ID: "group_1", ID: "group_1",
Name: "group_name_1", Name: "group_name_1",
Peers: []string{}, Peers: []string{},
@ -82,7 +83,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
err = manager.SaveGroup(account.Id, userID, &Group{ err = manager.SaveGroup(account.Id, userID, &nbgroup.Group{
ID: "group_1", ID: "group_1",
Name: "group_name_1", Name: "group_name_1",
Peers: []string{}, Peers: []string{},
@ -91,7 +92,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
err = manager.SaveGroup(account.Id, userID, &Group{ err = manager.SaveGroup(account.Id, userID, &nbgroup.Group{
ID: "group_2", ID: "group_2",
Name: "group_name_2", Name: "group_name_2",
Peers: []string{}, Peers: []string{},
@ -178,7 +179,7 @@ func TestGetSetupKeys(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
err = manager.SaveGroup(account.Id, userID, &Group{ err = manager.SaveGroup(account.Id, userID, &nbgroup.Group{
ID: "group_1", ID: "group_1",
Name: "group_name_1", Name: "group_name_1",
Peers: []string{}, Peers: []string{},
@ -187,7 +188,7 @@ func TestGetSetupKeys(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
err = manager.SaveGroup(account.Id, userID, &Group{ err = manager.SaveGroup(account.Id, userID, &nbgroup.Group{
ID: "group_2", ID: "group_2",
Name: "group_name_2", Name: "group_name_2",
Peers: []string{}, Peers: []string{},

View File

@ -17,6 +17,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
@ -64,7 +65,7 @@ func NewSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*SqliteStore,
sql.SetMaxOpenConns(conns) // TODO: make it configurable sql.SetMaxOpenConns(conns) // TODO: make it configurable
err = db.AutoMigrate( err = db.AutoMigrate(
&SetupKey{}, &nbpeer.Peer{}, &User{}, &PersonalAccessToken{}, &Group{}, &SetupKey{}, &nbpeer.Peer{}, &User{}, &PersonalAccessToken{}, &nbgroup.Group{},
&Account{}, &Policy{}, &PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, &Account{}, &Policy{}, &PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
&installation{}, &account.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{}, &installation{}, &account.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
) )
@ -99,17 +100,17 @@ func NewSqliteStoreFromFileStore(filestore *FileStore, dataDir string, metrics t
// AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock // AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock
func (s *SqliteStore) AcquireGlobalLock() (unlock func()) { func (s *SqliteStore) AcquireGlobalLock() (unlock func()) {
log.Debugf("acquiring global lock") log.Tracef("acquiring global lock")
start := time.Now() start := time.Now()
s.globalAccountLock.Lock() s.globalAccountLock.Lock()
unlock = func() { unlock = func() {
s.globalAccountLock.Unlock() s.globalAccountLock.Unlock()
log.Debugf("released global lock in %v", time.Since(start)) log.Tracef("released global lock in %v", time.Since(start))
} }
took := time.Since(start) took := time.Since(start)
log.Debugf("took %v to acquire global lock", took) log.Tracef("took %v to acquire global lock", took)
if s.metrics != nil { if s.metrics != nil {
s.metrics.StoreMetrics().CountGlobalLockAcquisitionDuration(took) s.metrics.StoreMetrics().CountGlobalLockAcquisitionDuration(took)
} }
@ -118,7 +119,7 @@ func (s *SqliteStore) AcquireGlobalLock() (unlock func()) {
} }
func (s *SqliteStore) AcquireAccountLock(accountID string) (unlock func()) { func (s *SqliteStore) AcquireAccountLock(accountID string) (unlock func()) {
log.Debugf("acquiring lock for account %s", accountID) log.Tracef("acquiring lock for account %s", accountID)
start := time.Now() start := time.Now()
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{}) value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{})
@ -127,7 +128,7 @@ func (s *SqliteStore) AcquireAccountLock(accountID string) (unlock func()) {
unlock = func() { unlock = func() {
mtx.Unlock() mtx.Unlock()
log.Debugf("released lock for account %s in %v", accountID, time.Since(start)) log.Tracef("released lock for account %s in %v", accountID, time.Since(start))
} }
return unlock return unlock
@ -434,7 +435,7 @@ func (s *SqliteStore) GetAccount(accountID string) (*Account, error) {
} }
account.UsersG = nil account.UsersG = nil
account.Groups = make(map[string]*Group, len(account.GroupsG)) account.Groups = make(map[string]*nbgroup.Group, len(account.GroupsG))
for _, group := range account.GroupsG { for _, group := range account.GroupsG {
account.Groups[group.ID] = group.Copy() account.Groups[group.ID] = group.Copy()
} }

View File

@ -10,6 +10,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/integration_reference"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
@ -49,23 +50,6 @@ type UserStatus string
// UserRole is the role of a User // UserRole is the role of a User
type UserRole string type UserRole string
// IntegrationReference holds the reference to a particular integration
type IntegrationReference struct {
ID int
IntegrationType string
}
func (ir IntegrationReference) String() string {
return fmt.Sprintf("%s:%d", ir.IntegrationType, ir.ID)
}
func (ir IntegrationReference) CacheKey(path ...string) string {
if len(path) == 0 {
return ir.String()
}
return fmt.Sprintf("%s:%s", ir.String(), strings.Join(path, ":"))
}
// User represents a user of the system // User represents a user of the system
type User struct { type User struct {
Id string `gorm:"primaryKey"` Id string `gorm:"primaryKey"`
@ -91,7 +75,7 @@ type User struct {
// Issued of the user // Issued of the user
Issued string `gorm:"default:api"` Issued string `gorm:"default:api"`
IntegrationReference IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"` IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"`
} }
// IsBlocked returns true if the user is blocked, false otherwise // IsBlocked returns true if the user is blocked, false otherwise

View File

@ -16,6 +16,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/integration_reference"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
) )
@ -276,7 +277,7 @@ func TestUser_Copy(t *testing.T) {
LastLogin: time.Now().UTC(), LastLogin: time.Now().UTC(),
CreatedAt: time.Now().UTC(), CreatedAt: time.Now().UTC(),
Issued: "test", Issued: "test",
IntegrationReference: IntegrationReference{ IntegrationReference: integration_reference.IntegrationReference{
ID: 0, ID: 0,
IntegrationType: "test", IntegrationType: "test",
}, },
@ -603,8 +604,9 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
} }
am := DefaultAccountManager{ am := DefaultAccountManager{
Store: store, Store: store,
eventStore: &activity.InMemoryEventStore{}, eventStore: &activity.InMemoryEventStore{},
integratedPeerValidator: MocIntegratedValidator{},
} }
testCases := []struct { testCases := []struct {
@ -793,7 +795,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
Id: "externalUser", Id: "externalUser",
Role: UserRoleUser, Role: UserRoleUser,
Issued: UserIssuedIntegration, Issued: UserIssuedIntegration,
IntegrationReference: IntegrationReference{ IntegrationReference: integration_reference.IntegrationReference{
ID: 1, ID: 1,
IntegrationType: "external", IntegrationType: "external",
}, },