mirror of
https://github.com/netbirdio/netbird.git
synced 2025-02-08 22:39:55 +01:00
Add rules for ACL (#306)
Add rules HTTP endpoint for frontend - CRUD operations. Add Default rule - allow all. Send network map to peers based on rules.
This commit is contained in:
parent
11a3863c28
commit
3ce3ccc39a
@ -68,7 +68,10 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
|
|||||||
}
|
}
|
||||||
|
|
||||||
peersUpdateManager := mgmt.NewPeersUpdateManager()
|
peersUpdateManager := mgmt.NewPeersUpdateManager()
|
||||||
accountManager := mgmt.NewManager(store, peersUpdateManager, nil)
|
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||||
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager)
|
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -455,7 +455,10 @@ func startManagement(port int, dataDir string) (*grpc.Server, error) {
|
|||||||
log.Fatalf("failed creating a store: %s: %v", config.Datadir, err)
|
log.Fatalf("failed creating a store: %s: %v", config.Datadir, err)
|
||||||
}
|
}
|
||||||
peersUpdateManager := server.NewPeersUpdateManager()
|
peersUpdateManager := server.NewPeersUpdateManager()
|
||||||
accountManager := server.NewManager(store, peersUpdateManager, nil)
|
accountManager, err := server.BuildManager(store, peersUpdateManager, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||||
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager)
|
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -28,7 +28,6 @@ import (
|
|||||||
const ValidKey = "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
|
const ValidKey = "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
|
||||||
|
|
||||||
func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
|
func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
|
||||||
|
|
||||||
level, _ := log.ParseLevel("debug")
|
level, _ := log.ParseLevel("debug")
|
||||||
log.SetLevel(level)
|
log.SetLevel(level)
|
||||||
|
|
||||||
@ -56,7 +55,10 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
peersUpdateManager := mgmt.NewPeersUpdateManager()
|
peersUpdateManager := mgmt.NewPeersUpdateManager()
|
||||||
accountManager := mgmt.NewManager(store, peersUpdateManager, nil)
|
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||||
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager)
|
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -256,6 +258,7 @@ func TestClient_Sync(t *testing.T) {
|
|||||||
}
|
}
|
||||||
if len(resp.GetRemotePeers()) != 1 {
|
if len(resp.GetRemotePeers()) != 1 {
|
||||||
t.Errorf("expecting RemotePeers size %d got %d", 1, len(resp.GetRemotePeers()))
|
t.Errorf("expecting RemotePeers size %d got %d", 1, len(resp.GetRemotePeers()))
|
||||||
|
return
|
||||||
}
|
}
|
||||||
if resp.GetRemotePeersIsEmpty() == true {
|
if resp.GetRemotePeersIsEmpty() == true {
|
||||||
t.Error("expecting RemotePeers property to be false, got true")
|
t.Error("expecting RemotePeers property to be false, got true")
|
||||||
@ -295,8 +298,7 @@ func Test_SystemMetaDataFromClient(t *testing.T) {
|
|||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
|
|
||||||
mgmtMockServer.LoginFunc =
|
mgmtMockServer.LoginFunc = func(ctx context.Context, msg *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||||
func(ctx context.Context, msg *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
|
||||||
peerKey, err := wgtypes.ParseKey(msg.GetWgPubKey())
|
peerKey, err := wgtypes.ParseKey(msg.GetWgPubKey())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("error while parsing peer's Wireguard public key %s on Sync request.", msg.WgPubKey)
|
log.Warnf("error while parsing peer's Wireguard public key %s on Sync request.", msg.WgPubKey)
|
||||||
@ -370,9 +372,7 @@ func Test_GetDeviceAuthorizationFlow(t *testing.T) {
|
|||||||
ProviderConfig: &proto.ProviderConfig{ClientID: "client"},
|
ProviderConfig: &proto.ProviderConfig{ClientID: "client"},
|
||||||
}
|
}
|
||||||
|
|
||||||
mgmtMockServer.GetDeviceAuthorizationFlowFunc =
|
mgmtMockServer.GetDeviceAuthorizationFlowFunc = func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||||
func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
|
||||||
|
|
||||||
encryptedResp, err := encryption.EncryptMessage(serverKey, client.key, expectedFlowInfo)
|
encryptedResp, err := encryption.EncryptMessage(serverKey, client.key, expectedFlowInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -108,7 +108,10 @@ var (
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
accountManager := server.NewManager(store, peersUpdateManager, idpManager)
|
accountManager, err := server.BuildManager(store, peersUpdateManager, idpManager)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln("failed build default manager: ", err)
|
||||||
|
}
|
||||||
|
|
||||||
var opts []grpc.ServerOption
|
var opts []grpc.ServerOption
|
||||||
|
|
||||||
@ -309,5 +312,4 @@ func init() {
|
|||||||
mgmtCmd.Flags().StringVar(&certFile, "cert-file", "", "Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
|
mgmtCmd.Flags().StringVar(&certFile, "cert-file", "", "Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
|
||||||
mgmtCmd.Flags().StringVar(&certKey, "cert-key", "", "Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
|
mgmtCmd.Flags().StringVar(&certKey, "cert-key", "", "Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
|
||||||
rootCmd.MarkFlagRequired("config") //nolint
|
rootCmd.MarkFlagRequired("config") //nolint
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -24,7 +24,12 @@ const (
|
|||||||
type AccountManager interface {
|
type AccountManager interface {
|
||||||
GetOrCreateAccountByUser(userId, domain string) (*Account, error)
|
GetOrCreateAccountByUser(userId, domain string) (*Account, error)
|
||||||
GetAccountByUser(userId string) (*Account, error)
|
GetAccountByUser(userId string) (*Account, error)
|
||||||
AddSetupKey(accountId string, keyName string, keyType SetupKeyType, expiresIn *util.Duration) (*SetupKey, error)
|
AddSetupKey(
|
||||||
|
accountId string,
|
||||||
|
keyName string,
|
||||||
|
keyType SetupKeyType,
|
||||||
|
expiresIn *util.Duration,
|
||||||
|
) (*SetupKey, error)
|
||||||
RevokeSetupKey(accountId string, keyId string) (*SetupKey, error)
|
RevokeSetupKey(accountId string, keyId string) (*SetupKey, error)
|
||||||
RenameSetupKey(accountId string, keyId string, newName string) (*SetupKey, error)
|
RenameSetupKey(accountId string, keyId string, newName string) (*SetupKey, error)
|
||||||
GetAccountById(accountId string) (*Account, error)
|
GetAccountById(accountId string) (*Account, error)
|
||||||
@ -47,6 +52,10 @@ type AccountManager interface {
|
|||||||
GroupAddPeer(accountId, groupID, peerKey string) error
|
GroupAddPeer(accountId, groupID, peerKey string) error
|
||||||
GroupDeletePeer(accountId, groupID, peerKey string) error
|
GroupDeletePeer(accountId, groupID, peerKey string) error
|
||||||
GroupListPeers(accountId, groupID string) ([]*Peer, error)
|
GroupListPeers(accountId, groupID string) ([]*Peer, error)
|
||||||
|
GetRule(accountId, ruleID string) (*Rule, error)
|
||||||
|
SaveRule(accountID string, rule *Rule) error
|
||||||
|
DeleteRule(accountId, ruleID string) error
|
||||||
|
ListRules(accountId string) ([]*Rule, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type DefaultAccountManager struct {
|
type DefaultAccountManager struct {
|
||||||
@ -70,6 +79,7 @@ type Account struct {
|
|||||||
Peers map[string]*Peer
|
Peers map[string]*Peer
|
||||||
Users map[string]*User
|
Users map[string]*User
|
||||||
Groups map[string]*Group
|
Groups map[string]*Group
|
||||||
|
Rules map[string]*Rule
|
||||||
}
|
}
|
||||||
|
|
||||||
type UserInfo struct {
|
type UserInfo struct {
|
||||||
@ -101,6 +111,16 @@ func (a *Account) Copy() *Account {
|
|||||||
setupKeys[id] = key.Copy()
|
setupKeys[id] = key.Copy()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
groups := map[string]*Group{}
|
||||||
|
for id, group := range a.Groups {
|
||||||
|
groups[id] = group.Copy()
|
||||||
|
}
|
||||||
|
|
||||||
|
rules := map[string]*Rule{}
|
||||||
|
for id, rule := range a.Rules {
|
||||||
|
rules[id] = rule.Copy()
|
||||||
|
}
|
||||||
|
|
||||||
return &Account{
|
return &Account{
|
||||||
Id: a.Id,
|
Id: a.Id,
|
||||||
CreatedBy: a.CreatedBy,
|
CreatedBy: a.CreatedBy,
|
||||||
@ -108,17 +128,43 @@ func (a *Account) Copy() *Account {
|
|||||||
Network: a.Network.Copy(),
|
Network: a.Network.Copy(),
|
||||||
Peers: peers,
|
Peers: peers,
|
||||||
Users: users,
|
Users: users,
|
||||||
|
Groups: groups,
|
||||||
|
Rules: rules,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewManager creates a new DefaultAccountManager with a provided Store
|
func (a *Account) GetGroupAll() (*Group, error) {
|
||||||
func NewManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager) *DefaultAccountManager {
|
for _, g := range a.Groups {
|
||||||
return &DefaultAccountManager{
|
if g.Name == "All" {
|
||||||
|
return g, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("no group ALL found")
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildManager creates a new DefaultAccountManager with a provided Store
|
||||||
|
func BuildManager(
|
||||||
|
store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager,
|
||||||
|
) (*DefaultAccountManager, error) {
|
||||||
|
dam := &DefaultAccountManager{
|
||||||
Store: store,
|
Store: store,
|
||||||
mux: sync.Mutex{},
|
mux: sync.Mutex{},
|
||||||
peersUpdateManager: peersUpdateManager,
|
peersUpdateManager: peersUpdateManager,
|
||||||
idpManager: idpManager,
|
idpManager: idpManager,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if account has not default account
|
||||||
|
// we build 'all' group and add all peers into it
|
||||||
|
// also we create default rule with source an destination
|
||||||
|
// groups 'all'
|
||||||
|
for _, account := range store.GetAllAccounts() {
|
||||||
|
dam.addAllGroup(account)
|
||||||
|
if err := store.SaveAccount(account); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return dam, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddSetupKey generates a new setup key with a given name and type, and adds it to the specified account
|
// AddSetupKey generates a new setup key with a given name and type, and adds it to the specified account
|
||||||
@ -223,7 +269,9 @@ func (am *DefaultAccountManager) GetAccountById(accountId string) (*Account, err
|
|||||||
|
|
||||||
// GetAccountByUserOrAccountId look for an account by user or account Id, if no account is provided and
|
// GetAccountByUserOrAccountId look for an account by user or account Id, if no account is provided and
|
||||||
// user id doesn't have an account associated with it, one account is created
|
// user id doesn't have an account associated with it, one account is created
|
||||||
func (am *DefaultAccountManager) GetAccountByUserOrAccountId(userId, accountId, domain string) (*Account, error) {
|
func (am *DefaultAccountManager) GetAccountByUserOrAccountId(
|
||||||
|
userId, accountId, domain string,
|
||||||
|
) (*Account, error) {
|
||||||
if accountId != "" {
|
if accountId != "" {
|
||||||
return am.GetAccountById(accountId)
|
return am.GetAccountById(accountId)
|
||||||
} else if userId != "" {
|
} else if userId != "" {
|
||||||
@ -490,6 +538,8 @@ func (am *DefaultAccountManager) AddAccount(accountId, userId, domain string) (*
|
|||||||
func (am *DefaultAccountManager) createAccount(accountId, userId, domain string) (*Account, error) {
|
func (am *DefaultAccountManager) createAccount(accountId, userId, domain string) (*Account, error) {
|
||||||
account := newAccountWithId(accountId, userId, domain)
|
account := newAccountWithId(accountId, userId, domain)
|
||||||
|
|
||||||
|
am.addAllGroup(account)
|
||||||
|
|
||||||
err := am.Store.SaveAccount(account)
|
err := am.Store.SaveAccount(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.Internal, "failed creating account")
|
return nil, status.Errorf(codes.Internal, "failed creating account")
|
||||||
@ -498,6 +548,28 @@ func (am *DefaultAccountManager) createAccount(accountId, userId, domain string)
|
|||||||
return account, nil
|
return account, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// addAllGroup to account object it it doesn't exists
|
||||||
|
func (am *DefaultAccountManager) addAllGroup(account *Account) {
|
||||||
|
if len(account.Groups) == 0 {
|
||||||
|
allGroup := &Group{
|
||||||
|
ID: xid.New().String(),
|
||||||
|
Name: "All",
|
||||||
|
}
|
||||||
|
for _, peer := range account.Peers {
|
||||||
|
allGroup.Peers = append(allGroup.Peers, peer.Key)
|
||||||
|
}
|
||||||
|
account.Groups = map[string]*Group{allGroup.ID: allGroup}
|
||||||
|
|
||||||
|
defaultRule := &Rule{
|
||||||
|
ID: xid.New().String(),
|
||||||
|
Name: "Default",
|
||||||
|
Source: []string{allGroup.ID},
|
||||||
|
Destination: []string{allGroup.ID},
|
||||||
|
}
|
||||||
|
account.Rules = map[string]*Rule{defaultRule.ID: defaultRule}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id
|
// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id
|
||||||
func newAccountWithId(accountId, userId, domain string) *Account {
|
func newAccountWithId(accountId, userId, domain string) *Account {
|
||||||
log.Debugf("creating new account")
|
log.Debugf("creating new account")
|
||||||
|
@ -37,7 +37,6 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
|
func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
|
||||||
|
|
||||||
type initUserParams jwtclaims.AuthorizationClaims
|
type initUserParams jwtclaims.AuthorizationClaims
|
||||||
|
|
||||||
type test struct {
|
type test struct {
|
||||||
@ -165,7 +164,6 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
|
|||||||
}
|
}
|
||||||
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5, testCase6} {
|
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5, testCase6} {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
|
||||||
manager, err := createManager(t)
|
manager, err := createManager(t)
|
||||||
require.NoError(t, err, "unable to create account manager")
|
require.NoError(t, err, "unable to create account manager")
|
||||||
|
|
||||||
@ -346,7 +344,6 @@ func TestAccountManager_AccountExists(t *testing.T) {
|
|||||||
if !*exists {
|
if !*exists {
|
||||||
t.Errorf("expected account to exist after creation, got false")
|
t.Errorf("expected account to exist after creation, got false")
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAccountManager_GetAccount(t *testing.T) {
|
func TestAccountManager_GetAccount(t *testing.T) {
|
||||||
@ -385,7 +382,6 @@ func TestAccountManager_GetAccount(t *testing.T) {
|
|||||||
t.Errorf("expected account to have setup key %s, not found", key.Key)
|
t.Errorf("expected account to have setup key %s, not found", key.Key)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAccountManager_AddPeer(t *testing.T) {
|
func TestAccountManager_AddPeer(t *testing.T) {
|
||||||
@ -457,7 +453,6 @@ func TestAccountManager_AddPeer(t *testing.T) {
|
|||||||
if account.Network.CurrentSerial() != 1 {
|
if account.Network.CurrentSerial() != 1 {
|
||||||
t.Errorf("expecting Network Serial=%d to be incremented by 1 and be equal to %d when adding new peer to account", serial, account.Network.CurrentSerial())
|
t.Errorf("expecting Network Serial=%d to be incremented by 1 and be equal to %d when adding new peer to account", serial, account.Network.CurrentSerial())
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAccountManager_AddPeerWithUserID(t *testing.T) {
|
func TestAccountManager_AddPeerWithUserID(t *testing.T) {
|
||||||
@ -521,7 +516,6 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
|
|||||||
if account.Network.CurrentSerial() != 1 {
|
if account.Network.CurrentSerial() != 1 {
|
||||||
t.Errorf("expecting Network Serial=%d to be incremented by 1 and be equal to %d when adding new peer to account", serial, account.Network.CurrentSerial())
|
t.Errorf("expecting Network Serial=%d to be incremented by 1 and be equal to %d when adding new peer to account", serial, account.Network.CurrentSerial())
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAccountManager_DeletePeer(t *testing.T) {
|
func TestAccountManager_DeletePeer(t *testing.T) {
|
||||||
@ -573,7 +567,6 @@ func TestAccountManager_DeletePeer(t *testing.T) {
|
|||||||
if account.Network.CurrentSerial() != 2 {
|
if account.Network.CurrentSerial() != 2 {
|
||||||
t.Errorf("expecting Network Serial=%d to be incremented and be equal to 2 after adding and deleteing a peer", account.Network.CurrentSerial())
|
t.Errorf("expecting Network Serial=%d to be incremented and be equal to 2 after adding and deleteing a peer", account.Network.CurrentSerial())
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetUsersFromAccount(t *testing.T) {
|
func TestGetUsersFromAccount(t *testing.T) {
|
||||||
@ -614,7 +607,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return NewManager(store, NewPeersUpdateManager(), nil), nil
|
return BuildManager(store, NewPeersUpdateManager(), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func createStore(t *testing.T) (Store, error) {
|
func createStore(t *testing.T) (Store, error) {
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
@ -22,6 +23,8 @@ type FileStore struct {
|
|||||||
PeerKeyId2AccountId map[string]string `json:"-"`
|
PeerKeyId2AccountId map[string]string `json:"-"`
|
||||||
UserId2AccountId map[string]string `json:"-"`
|
UserId2AccountId map[string]string `json:"-"`
|
||||||
PrivateDomain2AccountId map[string]string `json:"-"`
|
PrivateDomain2AccountId map[string]string `json:"-"`
|
||||||
|
PeerKeyId2SrcRulesId map[string]map[string]struct{} `json:"-"`
|
||||||
|
PeerKeyId2DstRulesId map[string]map[string]struct{} `json:"-"`
|
||||||
|
|
||||||
// mutex to synchronise Store read/write operations
|
// mutex to synchronise Store read/write operations
|
||||||
mux sync.Mutex `json:"-"`
|
mux sync.Mutex `json:"-"`
|
||||||
@ -47,6 +50,8 @@ func restore(file string) (*FileStore, error) {
|
|||||||
PeerKeyId2AccountId: make(map[string]string),
|
PeerKeyId2AccountId: make(map[string]string),
|
||||||
UserId2AccountId: make(map[string]string),
|
UserId2AccountId: make(map[string]string),
|
||||||
PrivateDomain2AccountId: make(map[string]string),
|
PrivateDomain2AccountId: make(map[string]string),
|
||||||
|
PeerKeyId2SrcRulesId: make(map[string]map[string]struct{}),
|
||||||
|
PeerKeyId2DstRulesId: make(map[string]map[string]struct{}),
|
||||||
storeFile: file,
|
storeFile: file,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -69,10 +74,39 @@ func restore(file string) (*FileStore, error) {
|
|||||||
store.PeerKeyId2AccountId = make(map[string]string)
|
store.PeerKeyId2AccountId = make(map[string]string)
|
||||||
store.UserId2AccountId = make(map[string]string)
|
store.UserId2AccountId = make(map[string]string)
|
||||||
store.PrivateDomain2AccountId = make(map[string]string)
|
store.PrivateDomain2AccountId = make(map[string]string)
|
||||||
|
store.PeerKeyId2SrcRulesId = map[string]map[string]struct{}{}
|
||||||
|
store.PeerKeyId2DstRulesId = map[string]map[string]struct{}{}
|
||||||
|
|
||||||
for accountId, account := range store.Accounts {
|
for accountId, account := range store.Accounts {
|
||||||
for setupKeyId := range account.SetupKeys {
|
for setupKeyId := range account.SetupKeys {
|
||||||
store.SetupKeyId2AccountId[strings.ToUpper(setupKeyId)] = accountId
|
store.SetupKeyId2AccountId[strings.ToUpper(setupKeyId)] = accountId
|
||||||
}
|
}
|
||||||
|
for _, rule := range account.Rules {
|
||||||
|
for _, groupID := range rule.Source {
|
||||||
|
if group, ok := account.Groups[groupID]; ok {
|
||||||
|
for _, peerID := range group.Peers {
|
||||||
|
rules := store.PeerKeyId2SrcRulesId[peerID]
|
||||||
|
if rules == nil {
|
||||||
|
rules = map[string]struct{}{}
|
||||||
|
store.PeerKeyId2SrcRulesId[peerID] = rules
|
||||||
|
}
|
||||||
|
rules[rule.ID] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, groupID := range rule.Destination {
|
||||||
|
if group, ok := account.Groups[groupID]; ok {
|
||||||
|
for _, peerID := range group.Peers {
|
||||||
|
rules := store.PeerKeyId2DstRulesId[peerID]
|
||||||
|
if rules == nil {
|
||||||
|
rules = map[string]struct{}{}
|
||||||
|
store.PeerKeyId2DstRulesId[peerID] = rules
|
||||||
|
}
|
||||||
|
rules[rule.ID] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
for _, peer := range account.Peers {
|
for _, peer := range account.Peers {
|
||||||
store.PeerKeyId2AccountId[peer.Key] = accountId
|
store.PeerKeyId2AccountId[peer.Key] = accountId
|
||||||
}
|
}
|
||||||
@ -82,7 +116,8 @@ func restore(file string) (*FileStore, error) {
|
|||||||
for _, user := range account.Users {
|
for _, user := range account.Users {
|
||||||
store.UserId2AccountId[user.Id] = accountId
|
store.UserId2AccountId[user.Id] = accountId
|
||||||
}
|
}
|
||||||
if account.Domain != "" && account.DomainCategory == PrivateCategory && account.IsDomainPrimaryAccount {
|
if account.Domain != "" && account.DomainCategory == PrivateCategory &&
|
||||||
|
account.IsDomainPrimaryAccount {
|
||||||
store.PrivateDomain2AccountId[account.Domain] = accountId
|
store.PrivateDomain2AccountId[account.Domain] = accountId
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -106,6 +141,24 @@ func (s *FileStore) SavePeer(accountId string, peer *Peer) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if it is new peer, add it to default 'All' group
|
||||||
|
allGroup, err := account.GetGroupAll()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ind := -1
|
||||||
|
for i, pid := range allGroup.Peers {
|
||||||
|
if pid == peer.Key {
|
||||||
|
ind = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if ind < 0 {
|
||||||
|
allGroup.Peers = append(allGroup.Peers, peer.Key)
|
||||||
|
}
|
||||||
|
|
||||||
account.Peers[peer.Key] = peer
|
account.Peers[peer.Key] = peer
|
||||||
return s.persist(s.storeFile)
|
return s.persist(s.storeFile)
|
||||||
}
|
}
|
||||||
@ -176,6 +229,29 @@ func (s *FileStore) SaveAccount(account *Account) error {
|
|||||||
s.PeerKeyId2AccountId[peer.Key] = account.Id
|
s.PeerKeyId2AccountId[peer.Key] = account.Id
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, rule := range account.Rules {
|
||||||
|
for _, gid := range rule.Source {
|
||||||
|
for _, pid := range account.Groups[gid].Peers {
|
||||||
|
rules := s.PeerKeyId2SrcRulesId[pid]
|
||||||
|
if rules == nil {
|
||||||
|
rules = map[string]struct{}{}
|
||||||
|
s.PeerKeyId2SrcRulesId[pid] = rules
|
||||||
|
}
|
||||||
|
rules[rule.ID] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, gid := range rule.Destination {
|
||||||
|
for _, pid := range account.Groups[gid].Peers {
|
||||||
|
rules := s.PeerKeyId2DstRulesId[pid]
|
||||||
|
if rules == nil {
|
||||||
|
rules = map[string]struct{}{}
|
||||||
|
s.PeerKeyId2DstRulesId[pid] = rules
|
||||||
|
}
|
||||||
|
rules[rule.ID] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for _, user := range account.Users {
|
for _, user := range account.Users {
|
||||||
s.UserId2AccountId[user.Id] = account.Id
|
s.UserId2AccountId[user.Id] = account.Id
|
||||||
}
|
}
|
||||||
@ -190,7 +266,10 @@ func (s *FileStore) SaveAccount(account *Account) error {
|
|||||||
func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) {
|
func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) {
|
||||||
accountId, accountIdFound := s.PrivateDomain2AccountId[strings.ToLower(domain)]
|
accountId, accountIdFound := s.PrivateDomain2AccountId[strings.ToLower(domain)]
|
||||||
if !accountIdFound {
|
if !accountIdFound {
|
||||||
return nil, status.Errorf(codes.NotFound, "provided domain is not registered or is not private")
|
return nil, status.Errorf(
|
||||||
|
codes.NotFound,
|
||||||
|
"provided domain is not registered or is not private",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := s.GetAccount(accountId)
|
account, err := s.GetAccount(accountId)
|
||||||
@ -232,6 +311,14 @@ func (s *FileStore) GetAccountPeers(accountId string) ([]*Peer, error) {
|
|||||||
return peers, nil
|
return peers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetAllAccounts() (all []*Account) {
|
||||||
|
for _, a := range s.Accounts {
|
||||||
|
all = append(all, a)
|
||||||
|
}
|
||||||
|
|
||||||
|
return all
|
||||||
|
}
|
||||||
|
|
||||||
func (s *FileStore) GetAccount(accountId string) (*Account, error) {
|
func (s *FileStore) GetAccount(accountId string) (*Account, error) {
|
||||||
account, accountFound := s.Accounts[accountId]
|
account, accountFound := s.Accounts[accountId]
|
||||||
if !accountFound {
|
if !accountFound {
|
||||||
@ -265,18 +352,52 @@ func (s *FileStore) GetPeerAccount(peerKey string) (*Account, error) {
|
|||||||
return s.GetAccount(accountId)
|
return s.GetAccount(accountId)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *FileStore) GetGroup(groupID string) (*Group, error) {
|
func (s *FileStore) GetPeerSrcRules(accountId, peerKey string) ([]*Rule, error) {
|
||||||
return nil, nil
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
account, err := s.GetAccount(accountId)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *FileStore) SaveGroup(group *Group) error {
|
ruleIDs, ok := s.PeerKeyId2SrcRulesId[peerKey]
|
||||||
return nil
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("no rules for peer: %v", ruleIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *FileStore) DeleteGroup(groupID string) error {
|
rules := []*Rule{}
|
||||||
return nil
|
for id := range ruleIDs {
|
||||||
|
rule, ok := account.Rules[id]
|
||||||
|
if ok {
|
||||||
|
rules = append(rules, rule)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *FileStore) ListGroups() ([]*Group, error) {
|
return rules, nil
|
||||||
return nil, nil
|
}
|
||||||
|
|
||||||
|
func (s *FileStore) GetPeerDstRules(accountId, peerKey string) ([]*Rule, error) {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
account, err := s.GetAccount(accountId)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleIDs, ok := s.PeerKeyId2DstRulesId[peerKey]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("no rules for peer: %v", ruleIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
rules := []*Rule{}
|
||||||
|
for id := range ruleIDs {
|
||||||
|
rule, ok := account.Rules[id]
|
||||||
|
if ok {
|
||||||
|
rules = append(rules, rule)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return rules, nil
|
||||||
}
|
}
|
||||||
|
@ -17,6 +17,14 @@ type Group struct {
|
|||||||
Peers []string
|
Peers []string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (g *Group) Copy() *Group {
|
||||||
|
return &Group{
|
||||||
|
ID: g.ID,
|
||||||
|
Name: g.Name,
|
||||||
|
Peers: g.Peers[:],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// GetGroup object of the peers
|
// GetGroup object of the peers
|
||||||
func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, error) {
|
func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, error) {
|
||||||
am.mux.Lock()
|
am.mux.Lock()
|
||||||
|
@ -3,11 +3,12 @@ package server
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||||
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
|
|
||||||
"github.com/golang/protobuf/ptypes/timestamp"
|
"github.com/golang/protobuf/ptypes/timestamp"
|
||||||
"github.com/netbirdio/netbird/encryption"
|
"github.com/netbirdio/netbird/encryption"
|
||||||
"github.com/netbirdio/netbird/management/proto"
|
"github.com/netbirdio/netbird/management/proto"
|
||||||
@ -64,7 +65,6 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) {
|
func (s *Server) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) {
|
||||||
|
|
||||||
// todo introduce something more meaningful with the key expiration/rotation
|
// todo introduce something more meaningful with the key expiration/rotation
|
||||||
now := time.Now().Add(24 * time.Hour)
|
now := time.Now().Add(24 * time.Hour)
|
||||||
secs := int64(now.Second())
|
secs := int64(now.Second())
|
||||||
@ -80,7 +80,6 @@ func (s *Server) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.Ser
|
|||||||
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
|
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
|
||||||
// notifies the connected peer of any updates (e.g. new peers under the same account)
|
// notifies the connected peer of any updates (e.g. new peers under the same account)
|
||||||
func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
|
func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
|
||||||
|
|
||||||
log.Debugf("Sync request from peer %s", req.WgPubKey)
|
log.Debugf("Sync request from peer %s", req.WgPubKey)
|
||||||
|
|
||||||
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
||||||
@ -155,7 +154,6 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest) (*Peer, error) {
|
func (s *Server) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest) (*Peer, error) {
|
||||||
|
|
||||||
var (
|
var (
|
||||||
reqSetupKey string
|
reqSetupKey string
|
||||||
userId string
|
userId string
|
||||||
@ -240,7 +238,6 @@ func (s *Server) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest) (*Pe
|
|||||||
// In case it isn't, the endpoint checks whether setup key is provided within the request and tries to register a peer.
|
// In case it isn't, the endpoint checks whether setup key is provided within the request and tries to register a peer.
|
||||||
// In case of the successful registration login is also successful
|
// In case of the successful registration login is also successful
|
||||||
func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||||
|
|
||||||
log.Debugf("Login request from peer %s", req.WgPubKey)
|
log.Debugf("Login request from peer %s", req.WgPubKey)
|
||||||
|
|
||||||
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
||||||
@ -309,7 +306,6 @@ func ToResponseProto(configProto Protocol) proto.HostConfig_Protocol {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *proto.WiretrusteeConfig {
|
func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *proto.WiretrusteeConfig {
|
||||||
|
|
||||||
var stuns []*proto.HostConfig
|
var stuns []*proto.HostConfig
|
||||||
for _, stun := range config.Stuns {
|
for _, stun := range config.Stuns {
|
||||||
stuns = append(stuns, &proto.HostConfig{
|
stuns = append(stuns, &proto.HostConfig{
|
||||||
@ -355,7 +351,6 @@ func toPeerConfig(peer *Peer) *proto.PeerConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func toRemotePeerConfig(peers []*Peer) []*proto.RemotePeerConfig {
|
func toRemotePeerConfig(peers []*Peer) []*proto.RemotePeerConfig {
|
||||||
|
|
||||||
remotePeers := []*proto.RemotePeerConfig{}
|
remotePeers := []*proto.RemotePeerConfig{}
|
||||||
for _, rPeer := range peers {
|
for _, rPeer := range peers {
|
||||||
remotePeers = append(remotePeers, &proto.RemotePeerConfig{
|
remotePeers = append(remotePeers, &proto.RemotePeerConfig{
|
||||||
@ -365,11 +360,9 @@ func toRemotePeerConfig(peers []*Peer) []*proto.RemotePeerConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return remotePeers
|
return remotePeers
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func toSyncResponse(config *Config, peer *Peer, peers []*Peer, turnCredentials *TURNCredentials, serial uint64) *proto.SyncResponse {
|
func toSyncResponse(config *Config, peer *Peer, peers []*Peer, turnCredentials *TURNCredentials, serial uint64) *proto.SyncResponse {
|
||||||
|
|
||||||
wtConfig := toWiretrusteeConfig(config, turnCredentials)
|
wtConfig := toWiretrusteeConfig(config, turnCredentials)
|
||||||
|
|
||||||
pConfig := toPeerConfig(peer)
|
pConfig := toPeerConfig(peer)
|
||||||
@ -397,7 +390,6 @@ func (s *Server) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Empty,
|
|||||||
|
|
||||||
// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization
|
// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization
|
||||||
func (s *Server) sendInitialSync(peerKey wgtypes.Key, peer *Peer, srv proto.ManagementService_SyncServer) error {
|
func (s *Server) sendInitialSync(peerKey wgtypes.Key, peer *Peer, srv proto.ManagementService_SyncServer) error {
|
||||||
|
|
||||||
networkMap, err := s.accountManager.GetNetworkMap(peer.Key)
|
networkMap, err := s.accountManager.GetNetworkMap(peer.Key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("error getting a list of peers for a peer %s", peer.Key)
|
log.Warnf("error getting a list of peers for a peer %s", peer.Key)
|
||||||
@ -436,7 +428,6 @@ func (s *Server) sendInitialSync(peerKey wgtypes.Key, peer *Peer, srv proto.Mana
|
|||||||
// This is used for initiating an Oauth 2 device authorization grant flow
|
// This is used for initiating an Oauth 2 device authorization grant flow
|
||||||
// which will be used by our clients to Login
|
// which will be used by our clients to Login
|
||||||
func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||||
|
|
||||||
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetDeviceAuthorizationFlow request.", req.WgPubKey)
|
errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetDeviceAuthorizationFlow request.", req.WgPubKey)
|
||||||
|
@ -13,20 +13,33 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Groups is a handler that returns groups of the account
|
|
||||||
type Groups struct {
|
|
||||||
accountManager server.AccountManager
|
|
||||||
authAudience string
|
|
||||||
jwtExtractor jwtclaims.ClaimsExtractor
|
|
||||||
}
|
|
||||||
|
|
||||||
// GroupResponse is a response sent to the client
|
// GroupResponse is a response sent to the client
|
||||||
type GroupResponse struct {
|
type GroupResponse struct {
|
||||||
|
ID string
|
||||||
|
Name string
|
||||||
|
Peers []GroupPeerResponse `json:",omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupPeerResponse is a response sent to the client
|
||||||
|
type GroupPeerResponse struct {
|
||||||
|
Key string
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupRequest to create or update group
|
||||||
|
type GroupRequest struct {
|
||||||
ID string
|
ID string
|
||||||
Name string
|
Name string
|
||||||
Peers []string
|
Peers []string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Groups is a handler that returns groups of the account
|
||||||
|
type Groups struct {
|
||||||
|
jwtExtractor jwtclaims.ClaimsExtractor
|
||||||
|
accountManager server.AccountManager
|
||||||
|
authAudience string
|
||||||
|
}
|
||||||
|
|
||||||
func NewGroups(accountManager server.AccountManager, authAudience string) *Groups {
|
func NewGroups(accountManager server.AccountManager, authAudience string) *Groups {
|
||||||
return &Groups{
|
return &Groups{
|
||||||
accountManager: accountManager,
|
accountManager: accountManager,
|
||||||
@ -44,7 +57,12 @@ func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
writeJSONObject(w, account.Groups)
|
var groups []*GroupResponse
|
||||||
|
for _, g := range account.Groups {
|
||||||
|
groups = append(groups, toGroupResponse(account, g))
|
||||||
|
}
|
||||||
|
|
||||||
|
writeJSONObject(w, groups)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Groups) CreateOrUpdateGroupHandler(w http.ResponseWriter, r *http.Request) {
|
func (h *Groups) CreateOrUpdateGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
@ -54,7 +72,7 @@ func (h *Groups) CreateOrUpdateGroupHandler(w http.ResponseWriter, r *http.Reque
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var req server.Group
|
var req GroupRequest
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
@ -64,13 +82,19 @@ func (h *Groups) CreateOrUpdateGroupHandler(w http.ResponseWriter, r *http.Reque
|
|||||||
req.ID = xid.New().String()
|
req.ID = xid.New().String()
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.accountManager.SaveGroup(account.Id, &req); err != nil {
|
group := server.Group{
|
||||||
|
ID: req.ID,
|
||||||
|
Name: req.Name,
|
||||||
|
Peers: req.Peers,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.accountManager.SaveGroup(account.Id, &group); err != nil {
|
||||||
log.Errorf("failed updating group %s under account %s %v", req.ID, account.Id, err)
|
log.Errorf("failed updating group %s under account %s %v", req.ID, account.Id, err)
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
writeJSONObject(w, &req)
|
writeJSONObject(w, toGroupResponse(account, &group))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Groups) DeleteGroupHandler(w http.ResponseWriter, r *http.Request) {
|
func (h *Groups) DeleteGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
@ -117,7 +141,7 @@ func (h *Groups) GetGroupHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
writeJSONObject(w, group)
|
writeJSONObject(w, toGroupResponse(account, group))
|
||||||
default:
|
default:
|
||||||
http.Error(w, "", http.StatusNotFound)
|
http.Error(w, "", http.StatusNotFound)
|
||||||
}
|
}
|
||||||
@ -133,3 +157,29 @@ func (h *Groups) getGroupAccount(r *http.Request) (*server.Account, error) {
|
|||||||
|
|
||||||
return account, nil
|
return account, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func toGroupResponse(account *server.Account, group *server.Group) *GroupResponse {
|
||||||
|
cache := make(map[string]GroupPeerResponse)
|
||||||
|
gr := GroupResponse{
|
||||||
|
ID: group.ID,
|
||||||
|
Name: group.Name,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, pid := range group.Peers {
|
||||||
|
peerResp, ok := cache[pid]
|
||||||
|
if !ok {
|
||||||
|
peer, ok := account.Peers[pid]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
peerResp = GroupPeerResponse{
|
||||||
|
Key: peer.Key,
|
||||||
|
Name: peer.Name,
|
||||||
|
}
|
||||||
|
cache[pid] = peerResp
|
||||||
|
}
|
||||||
|
gr.Peers = append(gr.Peers, peerResp)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &gr
|
||||||
|
}
|
||||||
|
211
management/server/http/handler/rules.go
Normal file
211
management/server/http/handler/rules.go
Normal file
@ -0,0 +1,211 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server"
|
||||||
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
|
"github.com/rs/xid"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const FlowBidirectString = "bidirect"
|
||||||
|
|
||||||
|
// RuleResponse is a response sent to the client
|
||||||
|
type RuleResponse struct {
|
||||||
|
ID string
|
||||||
|
Name string
|
||||||
|
Source []RuleGroupResponse
|
||||||
|
Destination []RuleGroupResponse
|
||||||
|
Flow string
|
||||||
|
}
|
||||||
|
|
||||||
|
// RuleGroupResponse is a response sent to the client
|
||||||
|
type RuleGroupResponse struct {
|
||||||
|
ID string
|
||||||
|
Name string
|
||||||
|
PeersCount int
|
||||||
|
}
|
||||||
|
|
||||||
|
// RuleRequest to create or update rule
|
||||||
|
type RuleRequest struct {
|
||||||
|
ID string
|
||||||
|
Name string
|
||||||
|
Source []string
|
||||||
|
Destination []string
|
||||||
|
Flow string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rules is a handler that returns rules of the account
|
||||||
|
type Rules struct {
|
||||||
|
jwtExtractor jwtclaims.ClaimsExtractor
|
||||||
|
accountManager server.AccountManager
|
||||||
|
authAudience string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRules(accountManager server.AccountManager, authAudience string) *Rules {
|
||||||
|
return &Rules{
|
||||||
|
accountManager: accountManager,
|
||||||
|
authAudience: authAudience,
|
||||||
|
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAllRulesHandler list for the account
|
||||||
|
func (h *Rules) GetAllRulesHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
account, err := h.getRuleAccount(r)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var rules []*RuleResponse
|
||||||
|
for _, r := range account.Rules {
|
||||||
|
rules = append(rules, toRuleResponse(account, r))
|
||||||
|
}
|
||||||
|
|
||||||
|
writeJSONObject(w, rules)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Rules) CreateOrUpdateRuleHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
account, err := h.getRuleAccount(r)
|
||||||
|
if err != nil {
|
||||||
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req RuleRequest
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Method == http.MethodPost {
|
||||||
|
req.ID = xid.New().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
rule := server.Rule{
|
||||||
|
ID: req.ID,
|
||||||
|
Name: req.Name,
|
||||||
|
Source: req.Source,
|
||||||
|
Destination: req.Destination,
|
||||||
|
}
|
||||||
|
|
||||||
|
switch req.Flow {
|
||||||
|
case FlowBidirectString:
|
||||||
|
rule.Flow = server.TrafficFlowBidirect
|
||||||
|
default:
|
||||||
|
http.Error(w, "unknown flow type", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.accountManager.SaveRule(account.Id, &rule); err != nil {
|
||||||
|
log.Errorf("failed updating rule %s under account %s %v", req.ID, account.Id, err)
|
||||||
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
writeJSONObject(w, &req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Rules) DeleteRuleHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
account, err := h.getRuleAccount(r)
|
||||||
|
if err != nil {
|
||||||
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
aID := account.Id
|
||||||
|
|
||||||
|
rID := mux.Vars(r)["id"]
|
||||||
|
if len(rID) == 0 {
|
||||||
|
http.Error(w, "invalid rule ID", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.accountManager.DeleteRule(aID, rID); err != nil {
|
||||||
|
log.Errorf("failed delete rule %s under account %s %v", rID, aID, err)
|
||||||
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
writeJSONObject(w, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Rules) GetRuleHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
account, err := h.getRuleAccount(r)
|
||||||
|
if err != nil {
|
||||||
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch r.Method {
|
||||||
|
case http.MethodGet:
|
||||||
|
ruleID := mux.Vars(r)["id"]
|
||||||
|
if len(ruleID) == 0 {
|
||||||
|
http.Error(w, "invalid rule ID", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rule, err := h.accountManager.GetRule(account.Id, ruleID)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "rule not found", http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
writeJSONObject(w, toRuleResponse(account, rule))
|
||||||
|
default:
|
||||||
|
http.Error(w, "", http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Rules) getRuleAccount(r *http.Request) (*server.Account, error) {
|
||||||
|
jwtClaims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||||
|
|
||||||
|
account, err := h.accountManager.GetAccountWithAuthorizationClaims(jwtClaims)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed getting account of a user %s: %v", jwtClaims.UserId, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return account, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func toRuleResponse(account *server.Account, rule *server.Rule) *RuleResponse {
|
||||||
|
gr := RuleResponse{
|
||||||
|
ID: rule.ID,
|
||||||
|
Name: rule.Name,
|
||||||
|
}
|
||||||
|
|
||||||
|
switch rule.Flow {
|
||||||
|
case server.TrafficFlowBidirect:
|
||||||
|
gr.Flow = FlowBidirectString
|
||||||
|
default:
|
||||||
|
gr.Flow = "unknown"
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, gid := range rule.Source {
|
||||||
|
if group, ok := account.Groups[gid]; ok {
|
||||||
|
gr.Source = append(gr.Source, RuleGroupResponse{
|
||||||
|
ID: group.ID,
|
||||||
|
Name: group.Name,
|
||||||
|
PeersCount: len(group.Peers),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, gid := range rule.Destination {
|
||||||
|
if group, ok := account.Groups[gid]; ok {
|
||||||
|
gr.Destination = append(gr.Destination, RuleGroupResponse{
|
||||||
|
ID: group.ID,
|
||||||
|
Name: group.Name,
|
||||||
|
PeersCount: len(group.Peers),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &gr
|
||||||
|
}
|
211
management/server/http/handler/rules_test.go
Normal file
211
management/server/http/handler/rules_test.go
Normal file
@ -0,0 +1,211 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
|
|
||||||
|
"github.com/magiconair/properties/assert"
|
||||||
|
"github.com/netbirdio/netbird/management/server"
|
||||||
|
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||||
|
)
|
||||||
|
|
||||||
|
func initRulesTestData(rules ...*server.Rule) *Rules {
|
||||||
|
return &Rules{
|
||||||
|
accountManager: &mock_server.MockAccountManager{
|
||||||
|
SaveRuleFunc: func(_ string, rule *server.Rule) error {
|
||||||
|
if !strings.HasPrefix(rule.ID, "id-") {
|
||||||
|
rule.ID = "id-was-set"
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
GetRuleFunc: func(_, ruleID string) (*server.Rule, error) {
|
||||||
|
if ruleID != "idoftherule" {
|
||||||
|
return nil, fmt.Errorf("not found")
|
||||||
|
}
|
||||||
|
return &server.Rule{
|
||||||
|
ID: "idoftherule",
|
||||||
|
Name: "Rule",
|
||||||
|
Source: []string{"idofsrcrule"},
|
||||||
|
Destination: []string{"idofdestrule"},
|
||||||
|
Flow: server.TrafficFlowBidirect,
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
GetAccountWithAuthorizationClaimsFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
|
||||||
|
return &server.Account{
|
||||||
|
Id: claims.AccountId,
|
||||||
|
Domain: "hotmail.com",
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
authAudience: "",
|
||||||
|
jwtExtractor: jwtclaims.ClaimsExtractor{
|
||||||
|
ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims {
|
||||||
|
return jwtclaims.AuthorizationClaims{
|
||||||
|
UserId: "test_user",
|
||||||
|
Domain: "hotmail.com",
|
||||||
|
AccountId: "test_id",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRulesGetRule(t *testing.T) {
|
||||||
|
tt := []struct {
|
||||||
|
name string
|
||||||
|
expectedStatus int
|
||||||
|
expectedBody bool
|
||||||
|
requestType string
|
||||||
|
requestPath string
|
||||||
|
requestBody io.Reader
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "GetRule OK",
|
||||||
|
expectedBody: true,
|
||||||
|
requestType: http.MethodGet,
|
||||||
|
requestPath: "/api/rules/idoftherule",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GetRule not found",
|
||||||
|
requestType: http.MethodGet,
|
||||||
|
requestPath: "/api/rules/notexists",
|
||||||
|
expectedStatus: http.StatusNotFound,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rule := &server.Rule{
|
||||||
|
ID: "idoftherule",
|
||||||
|
Name: "Rule",
|
||||||
|
}
|
||||||
|
|
||||||
|
p := initRulesTestData(rule)
|
||||||
|
|
||||||
|
for _, tc := range tt {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||||
|
|
||||||
|
router := mux.NewRouter()
|
||||||
|
router.HandleFunc("/api/rules/{id}", p.GetRuleHandler).Methods("GET")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
res := recorder.Result()
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
if status := recorder.Code; status != tc.expectedStatus {
|
||||||
|
t.Errorf("handler returned wrong status code: got %v want %v",
|
||||||
|
status, tc.expectedStatus)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tc.expectedBody {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
content, err := io.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("I don't know what I expected; %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var got RuleResponse
|
||||||
|
if err = json.Unmarshal(content, &got); err != nil {
|
||||||
|
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, got.ID, rule.ID)
|
||||||
|
assert.Equal(t, got.Name, rule.Name)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRulesSaveRule(t *testing.T) {
|
||||||
|
tt := []struct {
|
||||||
|
name string
|
||||||
|
expectedStatus int
|
||||||
|
expectedBody bool
|
||||||
|
expectedRule *server.Rule
|
||||||
|
requestType string
|
||||||
|
requestPath string
|
||||||
|
requestBody io.Reader
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "SaveRule POST OK",
|
||||||
|
requestType: http.MethodPost,
|
||||||
|
requestPath: "/api/rules",
|
||||||
|
requestBody: bytes.NewBuffer(
|
||||||
|
[]byte(`{"Name":"Default POSTed Rule","Flow":"bidirect"}`)),
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedBody: true,
|
||||||
|
expectedRule: &server.Rule{
|
||||||
|
ID: "id-was-set",
|
||||||
|
Name: "Default POSTed Rule",
|
||||||
|
Flow: server.TrafficFlowBidirect,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SaveRule PUT OK",
|
||||||
|
requestType: http.MethodPut,
|
||||||
|
requestPath: "/api/rules",
|
||||||
|
requestBody: bytes.NewBuffer(
|
||||||
|
[]byte(`{"ID":"id-existed","Name":"Default POSTed Rule","Flow":"bidirect"}`)),
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedRule: &server.Rule{
|
||||||
|
ID: "id-existed",
|
||||||
|
Name: "Default POSTed Rule",
|
||||||
|
Flow: server.TrafficFlowBidirect,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
p := initRulesTestData()
|
||||||
|
|
||||||
|
for _, tc := range tt {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||||
|
|
||||||
|
router := mux.NewRouter()
|
||||||
|
router.HandleFunc("/api/rules", p.CreateOrUpdateRuleHandler).Methods("PUT", "POST")
|
||||||
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
|
res := recorder.Result()
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
content, err := io.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("I don't know what I expected; %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if status := recorder.Code; status != tc.expectedStatus {
|
||||||
|
t.Errorf("handler returned wrong status code: got %v want %v, content: %s",
|
||||||
|
status, tc.expectedStatus, string(content))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tc.expectedBody {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
got := &RuleRequest{}
|
||||||
|
if err = json.Unmarshal(content, &got); err != nil {
|
||||||
|
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc.requestType != http.MethodPost {
|
||||||
|
assert.Equal(t, got.ID, tc.expectedRule.ID)
|
||||||
|
}
|
||||||
|
assert.Equal(t, got.Name, tc.expectedRule.Name)
|
||||||
|
assert.Equal(t, got.Flow, "bidirect")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@ -96,6 +96,7 @@ func (s *Server) Start() error {
|
|||||||
r.Use(jwtMiddleware.Handler, corsMiddleware.Handler)
|
r.Use(jwtMiddleware.Handler, corsMiddleware.Handler)
|
||||||
|
|
||||||
groupsHandler := handler.NewGroups(s.accountManager, s.config.AuthAudience)
|
groupsHandler := handler.NewGroups(s.accountManager, s.config.AuthAudience)
|
||||||
|
rulesHandler := handler.NewRules(s.accountManager, s.config.AuthAudience)
|
||||||
peersHandler := handler.NewPeers(s.accountManager, s.config.AuthAudience)
|
peersHandler := handler.NewPeers(s.accountManager, s.config.AuthAudience)
|
||||||
keysHandler := handler.NewSetupKeysHandler(s.accountManager, s.config.AuthAudience)
|
keysHandler := handler.NewSetupKeysHandler(s.accountManager, s.config.AuthAudience)
|
||||||
r.HandleFunc("/api/peers", peersHandler.GetPeers).Methods("GET", "OPTIONS")
|
r.HandleFunc("/api/peers", peersHandler.GetPeers).Methods("GET", "OPTIONS")
|
||||||
@ -112,6 +113,12 @@ func (s *Server) Start() error {
|
|||||||
r.HandleFunc("/api/setup-keys/{id}", keysHandler.HandleKey).
|
r.HandleFunc("/api/setup-keys/{id}", keysHandler.HandleKey).
|
||||||
Methods("GET", "PUT", "DELETE", "OPTIONS")
|
Methods("GET", "PUT", "DELETE", "OPTIONS")
|
||||||
|
|
||||||
|
r.HandleFunc("/api/rules", rulesHandler.GetAllRulesHandler).Methods("GET", "OPTIONS")
|
||||||
|
r.HandleFunc("/api/rules", rulesHandler.CreateOrUpdateRuleHandler).
|
||||||
|
Methods("POST", "PUT", "OPTIONS")
|
||||||
|
r.HandleFunc("/api/rules/{id}", rulesHandler.GetRuleHandler).Methods("GET", "OPTIONS")
|
||||||
|
r.HandleFunc("/api/rules/{id}", rulesHandler.DeleteRuleHandler).Methods("DELETE", "OPTIONS")
|
||||||
|
|
||||||
r.HandleFunc("/api/groups", groupsHandler.GetAllGroupsHandler).Methods("GET", "OPTIONS")
|
r.HandleFunc("/api/groups", groupsHandler.GetAllGroupsHandler).Methods("GET", "OPTIONS")
|
||||||
r.HandleFunc("/api/groups", groupsHandler.CreateOrUpdateGroupHandler).
|
r.HandleFunc("/api/groups", groupsHandler.CreateOrUpdateGroupHandler).
|
||||||
Methods("POST", "PUT", "OPTIONS")
|
Methods("POST", "PUT", "OPTIONS")
|
||||||
|
@ -3,6 +3,13 @@ package server
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/encryption"
|
"github.com/netbirdio/netbird/encryption"
|
||||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
@ -11,12 +18,6 @@ import (
|
|||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
"net"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"runtime"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -39,8 +40,7 @@ const (
|
|||||||
|
|
||||||
// registerPeers registers peersNum peers on the management service and returns their Wireguard keys
|
// registerPeers registers peersNum peers on the management service and returns their Wireguard keys
|
||||||
func registerPeers(peersNum int, client mgmtProto.ManagementServiceClient) ([]*wgtypes.Key, error) {
|
func registerPeers(peersNum int, client mgmtProto.ManagementServiceClient) ([]*wgtypes.Key, error) {
|
||||||
|
peers := []*wgtypes.Key{}
|
||||||
var peers = []*wgtypes.Key{}
|
|
||||||
for i := 0; i < peersNum; i++ {
|
for i := 0; i < peersNum; i++ {
|
||||||
key, err := wgtypes.GeneratePrivateKey()
|
key, err := wgtypes.GeneratePrivateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -60,7 +60,6 @@ func registerPeers(peersNum int, client mgmtProto.ManagementServiceClient) ([]*w
|
|||||||
|
|
||||||
// getServerKey gets Management Service Wireguard public key
|
// getServerKey gets Management Service Wireguard public key
|
||||||
func getServerKey(client mgmtProto.ManagementServiceClient) (*wgtypes.Key, error) {
|
func getServerKey(client mgmtProto.ManagementServiceClient) (*wgtypes.Key, error) {
|
||||||
|
|
||||||
keyResp, err := client.GetServerKey(context.TODO(), &mgmtProto.Empty{})
|
keyResp, err := client.GetServerKey(context.TODO(), &mgmtProto.Empty{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -75,7 +74,6 @@ func getServerKey(client mgmtProto.ManagementServiceClient) (*wgtypes.Key, error
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Test_SyncProtocol(t *testing.T) {
|
func Test_SyncProtocol(t *testing.T) {
|
||||||
|
|
||||||
dir := t.TempDir()
|
dir := t.TempDir()
|
||||||
err := util.CopyFileContents("testdata/store.json", filepath.Join(dir, "store.json"))
|
err := util.CopyFileContents("testdata/store.json", filepath.Join(dir, "store.json"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -263,7 +261,6 @@ func Test_SyncProtocol(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func loginPeerWithValidSetupKey(key wgtypes.Key, client mgmtProto.ManagementServiceClient) (*mgmtProto.LoginResponse, error) {
|
func loginPeerWithValidSetupKey(key wgtypes.Key, client mgmtProto.ManagementServiceClient) (*mgmtProto.LoginResponse, error) {
|
||||||
|
|
||||||
serverKey, err := getServerKey(client)
|
serverKey, err := getServerKey(client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -298,11 +295,9 @@ func loginPeerWithValidSetupKey(key wgtypes.Key, client mgmtProto.ManagementServ
|
|||||||
}
|
}
|
||||||
|
|
||||||
return loginResp, nil
|
return loginResp, nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
|
func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
|
||||||
|
|
||||||
testingServerKey, err := wgtypes.GeneratePrivateKey()
|
testingServerKey, err := wgtypes.GeneratePrivateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("unable to generate server wg key for testing GetDeviceAuthorizationFlow, error: %v", err)
|
t.Errorf("unable to generate server wg key for testing GetDeviceAuthorizationFlow, error: %v", err)
|
||||||
@ -362,7 +357,6 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
|
|||||||
|
|
||||||
for _, testCase := range testCases {
|
for _, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
|
||||||
mgmtServer := &Server{
|
mgmtServer := &Server{
|
||||||
wgKey: testingServerKey,
|
wgKey: testingServerKey,
|
||||||
config: &Config{
|
config: &Config{
|
||||||
@ -397,7 +391,6 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func startManagement(t *testing.T, port int, config *Config) (*grpc.Server, error) {
|
func startManagement(t *testing.T, port int, config *Config) (*grpc.Server, error) {
|
||||||
|
|
||||||
lis, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", port))
|
lis, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", port))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -408,7 +401,10 @@ func startManagement(t *testing.T, port int, config *Config) (*grpc.Server, erro
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
peersUpdateManager := NewPeersUpdateManager()
|
peersUpdateManager := NewPeersUpdateManager()
|
||||||
accountManager := NewManager(store, peersUpdateManager, nil)
|
accountManager, err := BuildManager(store, peersUpdateManager, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
turnManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
turnManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||||
mgmtServer, err := NewServer(config, accountManager, peersUpdateManager, turnManager)
|
mgmtServer, err := NewServer(config, accountManager, peersUpdateManager, turnManager)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -2,8 +2,6 @@ package server_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
server "github.com/netbirdio/netbird/management/server"
|
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
@ -13,6 +11,9 @@ import (
|
|||||||
sync2 "sync"
|
sync2 "sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
server "github.com/netbirdio/netbird/management/server"
|
||||||
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
|
||||||
pb "github.com/golang/protobuf/proto" //nolint
|
pb "github.com/golang/protobuf/proto" //nolint
|
||||||
"github.com/netbirdio/netbird/encryption"
|
"github.com/netbirdio/netbird/encryption"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@ -31,7 +32,6 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var _ = Describe("Management service", func() {
|
var _ = Describe("Management service", func() {
|
||||||
|
|
||||||
var (
|
var (
|
||||||
addr string
|
addr string
|
||||||
s *grpc.Server
|
s *grpc.Server
|
||||||
@ -66,7 +66,6 @@ var _ = Describe("Management service", func() {
|
|||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
serverPubKey, err = wgtypes.ParseKey(resp.Key)
|
serverPubKey, err = wgtypes.ParseKey(resp.Key)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
})
|
})
|
||||||
|
|
||||||
AfterEach(func() {
|
AfterEach(func() {
|
||||||
@ -78,7 +77,6 @@ var _ = Describe("Management service", func() {
|
|||||||
|
|
||||||
Context("when calling IsHealthy endpoint", func() {
|
Context("when calling IsHealthy endpoint", func() {
|
||||||
Specify("a non-error result is returned", func() {
|
Specify("a non-error result is returned", func() {
|
||||||
|
|
||||||
healthy, err := client.IsHealthy(context.TODO(), &mgmtProto.Empty{})
|
healthy, err := client.IsHealthy(context.TODO(), &mgmtProto.Empty{})
|
||||||
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
@ -87,7 +85,6 @@ var _ = Describe("Management service", func() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
Context("when calling Sync endpoint", func() {
|
Context("when calling Sync endpoint", func() {
|
||||||
|
|
||||||
Context("when there is a new peer registered", func() {
|
Context("when there is a new peer registered", func() {
|
||||||
Specify("a proper configuration is returned", func() {
|
Specify("a proper configuration is returned", func() {
|
||||||
key, _ := wgtypes.GenerateKey()
|
key, _ := wgtypes.GenerateKey()
|
||||||
@ -168,7 +165,6 @@ var _ = Describe("Management service", func() {
|
|||||||
Expect(resp.GetRemotePeers()).To(HaveLen(2))
|
Expect(resp.GetRemotePeers()).To(HaveLen(2))
|
||||||
peers := []string{resp.GetRemotePeers()[0].WgPubKey, resp.GetRemotePeers()[1].WgPubKey}
|
peers := []string{resp.GetRemotePeers()[0].WgPubKey, resp.GetRemotePeers()[1].WgPubKey}
|
||||||
Expect(peers).To(ContainElements(key1.PublicKey().String(), key2.PublicKey().String()))
|
Expect(peers).To(ContainElements(key1.PublicKey().String(), key2.PublicKey().String()))
|
||||||
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -211,7 +207,6 @@ var _ = Describe("Management service", func() {
|
|||||||
resp = &mgmtProto.SyncResponse{}
|
resp = &mgmtProto.SyncResponse{}
|
||||||
err = pb.Unmarshal(decryptedBytes, resp)
|
err = pb.Unmarshal(decryptedBytes, resp)
|
||||||
wg.Done()
|
wg.Done()
|
||||||
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// register a new peer
|
// register a new peer
|
||||||
@ -229,7 +224,6 @@ var _ = Describe("Management service", func() {
|
|||||||
|
|
||||||
Context("when calling GetServerKey endpoint", func() {
|
Context("when calling GetServerKey endpoint", func() {
|
||||||
Specify("a public Wireguard key of the service is returned", func() {
|
Specify("a public Wireguard key of the service is returned", func() {
|
||||||
|
|
||||||
resp, err := client.GetServerKey(context.TODO(), &mgmtProto.Empty{})
|
resp, err := client.GetServerKey(context.TODO(), &mgmtProto.Empty{})
|
||||||
|
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
@ -241,15 +235,12 @@ var _ = Describe("Management service", func() {
|
|||||||
key, err := wgtypes.ParseKey(resp.Key)
|
key, err := wgtypes.ParseKey(resp.Key)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
Expect(key).ToNot(BeNil())
|
Expect(key).ToNot(BeNil())
|
||||||
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("when calling Login endpoint", func() {
|
Context("when calling Login endpoint", func() {
|
||||||
|
|
||||||
Context("with an invalid setup key", func() {
|
Context("with an invalid setup key", func() {
|
||||||
Specify("an error is returned", func() {
|
Specify("an error is returned", func() {
|
||||||
|
|
||||||
key, _ := wgtypes.GenerateKey()
|
key, _ := wgtypes.GenerateKey()
|
||||||
message, err := encryption.EncryptMessage(serverPubKey, key, &mgmtProto.LoginRequest{SetupKey: "invalid setup key"})
|
message, err := encryption.EncryptMessage(serverPubKey, key, &mgmtProto.LoginRequest{SetupKey: "invalid setup key"})
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
@ -261,24 +252,20 @@ var _ = Describe("Management service", func() {
|
|||||||
|
|
||||||
Expect(err).To(HaveOccurred())
|
Expect(err).To(HaveOccurred())
|
||||||
Expect(resp).To(BeNil())
|
Expect(resp).To(BeNil())
|
||||||
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("with a valid setup key", func() {
|
Context("with a valid setup key", func() {
|
||||||
It("a non error result is returned", func() {
|
It("a non error result is returned", func() {
|
||||||
|
|
||||||
key, _ := wgtypes.GenerateKey()
|
key, _ := wgtypes.GenerateKey()
|
||||||
resp := loginPeerWithValidSetupKey(serverPubKey, key, client)
|
resp := loginPeerWithValidSetupKey(serverPubKey, key, client)
|
||||||
|
|
||||||
Expect(resp).ToNot(BeNil())
|
Expect(resp).ToNot(BeNil())
|
||||||
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("with a registered peer", func() {
|
Context("with a registered peer", func() {
|
||||||
It("a non error result is returned", func() {
|
It("a non error result is returned", func() {
|
||||||
|
|
||||||
key, _ := wgtypes.GenerateKey()
|
key, _ := wgtypes.GenerateKey()
|
||||||
regResp := loginPeerWithValidSetupKey(serverPubKey, key, client)
|
regResp := loginPeerWithValidSetupKey(serverPubKey, key, client)
|
||||||
Expect(regResp).NotTo(BeNil())
|
Expect(regResp).NotTo(BeNil())
|
||||||
@ -324,7 +311,6 @@ var _ = Describe("Management service", func() {
|
|||||||
Context("when there are 50 peers registered under one account", func() {
|
Context("when there are 50 peers registered under one account", func() {
|
||||||
Context("when there are 10 more peers registered under the same account", func() {
|
Context("when there are 10 more peers registered under the same account", func() {
|
||||||
Specify("all of the 50 peers will get updates of 10 newly registered peers", func() {
|
Specify("all of the 50 peers will get updates of 10 newly registered peers", func() {
|
||||||
|
|
||||||
initialPeers := 20
|
initialPeers := 20
|
||||||
additionalPeers := 10
|
additionalPeers := 10
|
||||||
|
|
||||||
@ -397,7 +383,6 @@ var _ = Describe("Management service", func() {
|
|||||||
|
|
||||||
Context("when there are peers registered under one account concurrently", func() {
|
Context("when there are peers registered under one account concurrently", func() {
|
||||||
Specify("then there are no duplicate IPs", func() {
|
Specify("then there are no duplicate IPs", func() {
|
||||||
|
|
||||||
initialPeers := 30
|
initialPeers := 30
|
||||||
|
|
||||||
ipChannel := make(chan string, 20)
|
ipChannel := make(chan string, 20)
|
||||||
@ -423,7 +408,6 @@ var _ = Describe("Management service", func() {
|
|||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
|
||||||
ipChannel <- resp.GetPeerConfig().Address
|
ipChannel <- resp.GetPeerConfig().Address
|
||||||
|
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -443,6 +427,7 @@ var _ = Describe("Management service", func() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
func loginPeerWithValidSetupKey(serverPubKey wgtypes.Key, key wgtypes.Key, client mgmtProto.ManagementServiceClient) *mgmtProto.LoginResponse {
|
func loginPeerWithValidSetupKey(serverPubKey wgtypes.Key, key wgtypes.Key, client mgmtProto.ManagementServiceClient) *mgmtProto.LoginResponse {
|
||||||
|
defer GinkgoRecover()
|
||||||
|
|
||||||
meta := &mgmtProto.PeerSystemMeta{
|
meta := &mgmtProto.PeerSystemMeta{
|
||||||
Hostname: key.PublicKey().String(),
|
Hostname: key.PublicKey().String(),
|
||||||
@ -467,7 +452,6 @@ func loginPeerWithValidSetupKey(serverPubKey wgtypes.Key, key wgtypes.Key, clien
|
|||||||
err = encryption.DecryptMessage(serverPubKey, key, resp.Body, loginResp)
|
err = encryption.DecryptMessage(serverPubKey, key, resp.Body, loginResp)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
return loginResp
|
return loginResp
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.ClientConn) {
|
func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.ClientConn) {
|
||||||
@ -496,7 +480,10 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) {
|
|||||||
log.Fatalf("failed creating a store: %s: %v", config.Datadir, err)
|
log.Fatalf("failed creating a store: %s: %v", config.Datadir, err)
|
||||||
}
|
}
|
||||||
peersUpdateManager := server.NewPeersUpdateManager()
|
peersUpdateManager := server.NewPeersUpdateManager()
|
||||||
accountManager := server.NewManager(store, peersUpdateManager, nil)
|
accountManager, err := server.BuildManager(store, peersUpdateManager, nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("failed creating a manager: %v", err)
|
||||||
|
}
|
||||||
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||||
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager)
|
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
@ -33,6 +33,10 @@ type MockAccountManager struct {
|
|||||||
GroupAddPeerFunc func(accountID, groupID, peerKey string) error
|
GroupAddPeerFunc func(accountID, groupID, peerKey string) error
|
||||||
GroupDeletePeerFunc func(accountID, groupID, peerKey string) error
|
GroupDeletePeerFunc func(accountID, groupID, peerKey string) error
|
||||||
GroupListPeersFunc func(accountID, groupID string) ([]*server.Peer, error)
|
GroupListPeersFunc func(accountID, groupID string) ([]*server.Peer, error)
|
||||||
|
GetRuleFunc func(accountID, ruleID string) (*server.Rule, error)
|
||||||
|
SaveRuleFunc func(accountID string, rule *server.Rule) error
|
||||||
|
DeleteRuleFunc func(accountID, ruleID string) error
|
||||||
|
ListRulesFunc func(accountID string) ([]*server.Rule, error)
|
||||||
GetUsersFromAccountFunc func(accountID string) ([]*server.UserInfo, error)
|
GetUsersFromAccountFunc func(accountID string) ([]*server.UserInfo, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -41,7 +45,6 @@ func (am *MockAccountManager) GetUsersFromAccount(accountID string) ([]*server.U
|
|||||||
return am.GetUsersFromAccountFunc(accountID)
|
return am.GetUsersFromAccountFunc(accountID)
|
||||||
}
|
}
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method GetUsersFromAccount not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method GetUsersFromAccount not implemented")
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *MockAccountManager) GetOrCreateAccountByUser(
|
func (am *MockAccountManager) GetOrCreateAccountByUser(
|
||||||
@ -207,7 +210,7 @@ func (am *MockAccountManager) SaveGroup(accountID string, group *server.Group) e
|
|||||||
if am.SaveGroupFunc != nil {
|
if am.SaveGroupFunc != nil {
|
||||||
return am.SaveGroupFunc(accountID, group)
|
return am.SaveGroupFunc(accountID, group)
|
||||||
}
|
}
|
||||||
return status.Errorf(codes.Unimplemented, "method UpdateGroup not implemented")
|
return status.Errorf(codes.Unimplemented, "method SaveGroup not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *MockAccountManager) DeleteGroup(accountID, groupID string) error {
|
func (am *MockAccountManager) DeleteGroup(accountID, groupID string) error {
|
||||||
@ -244,3 +247,31 @@ func (am *MockAccountManager) GroupListPeers(accountID, groupID string) ([]*serv
|
|||||||
}
|
}
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method GroupListPeers not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method GroupListPeers not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (am *MockAccountManager) GetRule(accountID, ruleID string) (*server.Rule, error) {
|
||||||
|
if am.GetRuleFunc != nil {
|
||||||
|
return am.GetRuleFunc(accountID, ruleID)
|
||||||
|
}
|
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method GetRule not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *MockAccountManager) SaveRule(accountID string, rule *server.Rule) error {
|
||||||
|
if am.SaveRuleFunc != nil {
|
||||||
|
return am.SaveRuleFunc(accountID, rule)
|
||||||
|
}
|
||||||
|
return status.Errorf(codes.Unimplemented, "method SaveRule not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *MockAccountManager) DeleteRule(accountID, ruleID string) error {
|
||||||
|
if am.DeleteRuleFunc != nil {
|
||||||
|
return am.DeleteRuleFunc(accountID, ruleID)
|
||||||
|
}
|
||||||
|
return status.Errorf(codes.Unimplemented, "method DeleteRule not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *MockAccountManager) ListRules(accountID string) ([]*server.Rule, error) {
|
||||||
|
if am.ListRulesFunc != nil {
|
||||||
|
return am.ListRulesFunc(accountID)
|
||||||
|
}
|
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method ListRules not implemented")
|
||||||
|
}
|
||||||
|
@ -1,12 +1,13 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/netbirdio/netbird/management/proto"
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/proto"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
// PeerSystemMeta is a metadata of a Peer machine system
|
// PeerSystemMeta is a metadata of a Peer machine system
|
||||||
@ -97,7 +98,11 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerKey string, connected boo
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RenamePeer changes peer's name
|
// RenamePeer changes peer's name
|
||||||
func (am *DefaultAccountManager) RenamePeer(accountId string, peerKey string, newName string) (*Peer, error) {
|
func (am *DefaultAccountManager) RenamePeer(
|
||||||
|
accountId string,
|
||||||
|
peerKey string,
|
||||||
|
newName string,
|
||||||
|
) (*Peer, error) {
|
||||||
am.mux.Lock()
|
am.mux.Lock()
|
||||||
defer am.mux.Unlock()
|
defer am.mux.Unlock()
|
||||||
|
|
||||||
@ -149,7 +154,8 @@ func (am *DefaultAccountManager) DeletePeer(accountId string, peerKey string) (*
|
|||||||
RemotePeers: []*proto.RemotePeerConfig{},
|
RemotePeers: []*proto.RemotePeerConfig{},
|
||||||
RemotePeersIsEmpty: true,
|
RemotePeersIsEmpty: true,
|
||||||
},
|
},
|
||||||
}})
|
},
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -180,7 +186,8 @@ func (am *DefaultAccountManager) DeletePeer(accountId string, peerKey string) (*
|
|||||||
RemotePeers: update,
|
RemotePeers: update,
|
||||||
RemotePeersIsEmpty: len(update) == 0,
|
RemotePeersIsEmpty: len(update) == 0,
|
||||||
},
|
},
|
||||||
}})
|
},
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -220,12 +227,48 @@ func (am *DefaultAccountManager) GetNetworkMap(peerKey string) (*NetworkMap, err
|
|||||||
}
|
}
|
||||||
|
|
||||||
var res []*Peer
|
var res []*Peer
|
||||||
for _, peer := range account.Peers {
|
srcRules, err := am.Store.GetPeerSrcRules(account.Id, peerKey)
|
||||||
|
if err != nil {
|
||||||
|
return &NetworkMap{
|
||||||
|
Peers: res,
|
||||||
|
Network: account.Network.Copy(),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
dstRules, err := am.Store.GetPeerDstRules(account.Id, peerKey)
|
||||||
|
if err != nil {
|
||||||
|
return &NetworkMap{
|
||||||
|
Peers: res,
|
||||||
|
Network: account.Network.Copy(),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
groups := map[string]*Group{}
|
||||||
|
for _, r := range srcRules {
|
||||||
|
if r.Flow == TrafficFlowBidirect {
|
||||||
|
for _, gid := range r.Destination {
|
||||||
|
groups[gid] = account.Groups[gid]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, r := range dstRules {
|
||||||
|
if r.Flow == TrafficFlowBidirect {
|
||||||
|
for _, gid := range r.Source {
|
||||||
|
groups[gid] = account.Groups[gid]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, g := range groups {
|
||||||
|
for _, pid := range g.Peers {
|
||||||
|
peer := account.Peers[pid]
|
||||||
// exclude original peer
|
// exclude original peer
|
||||||
if peer.Key != peerKey {
|
if peer.Key != peerKey {
|
||||||
res = append(res, peer.Copy())
|
res = append(res, peer.Copy())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return &NetworkMap{
|
return &NetworkMap{
|
||||||
Peers: res,
|
Peers: res,
|
||||||
@ -240,7 +283,11 @@ func (am *DefaultAccountManager) GetNetworkMap(peerKey string) (*NetworkMap, err
|
|||||||
// to it. We also add the User ID to the peer metadata to identify registrant.
|
// to it. We also add the User ID to the peer metadata to identify registrant.
|
||||||
// Each new Peer will be assigned a new next net.IP from the Account.Network and Account.Network.LastIP will be updated (IP's are not reused).
|
// Each new Peer will be assigned a new next net.IP from the Account.Network and Account.Network.LastIP will be updated (IP's are not reused).
|
||||||
// The peer property is just a placeholder for the Peer properties to pass further
|
// The peer property is just a placeholder for the Peer properties to pass further
|
||||||
func (am *DefaultAccountManager) AddPeer(setupKey string, userID string, peer *Peer) (*Peer, error) {
|
func (am *DefaultAccountManager) AddPeer(
|
||||||
|
setupKey string,
|
||||||
|
userID string,
|
||||||
|
peer *Peer,
|
||||||
|
) (*Peer, error) {
|
||||||
am.mux.Lock()
|
am.mux.Lock()
|
||||||
defer am.mux.Unlock()
|
defer am.mux.Unlock()
|
||||||
|
|
||||||
@ -252,17 +299,28 @@ func (am *DefaultAccountManager) AddPeer(setupKey string, userID string, peer *P
|
|||||||
if len(upperKey) != 0 {
|
if len(upperKey) != 0 {
|
||||||
account, err = am.Store.GetAccountBySetupKey(upperKey)
|
account, err = am.Store.GetAccountBySetupKey(upperKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.NotFound, "unable to register peer, unable to find account with setupKey %s", upperKey)
|
return nil, status.Errorf(
|
||||||
|
codes.NotFound,
|
||||||
|
"unable to register peer, unable to find account with setupKey %s",
|
||||||
|
upperKey,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
sk = getAccountSetupKeyByKey(account, upperKey)
|
sk = getAccountSetupKeyByKey(account, upperKey)
|
||||||
if sk == nil {
|
if sk == nil {
|
||||||
// shouldn't happen actually
|
// shouldn't happen actually
|
||||||
return nil, status.Errorf(codes.NotFound, "unable to register peer, unknown setupKey %s", upperKey)
|
return nil, status.Errorf(
|
||||||
|
codes.NotFound,
|
||||||
|
"unable to register peer, unknown setupKey %s",
|
||||||
|
upperKey,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !sk.IsValid() {
|
if !sk.IsValid() {
|
||||||
return nil, status.Errorf(codes.FailedPrecondition, "unable to register peer, its setup key is invalid (expired, overused or revoked)")
|
return nil, status.Errorf(
|
||||||
|
codes.FailedPrecondition,
|
||||||
|
"unable to register peer, its setup key is invalid (expired, overused or revoked)",
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
} else if len(userID) != 0 {
|
} else if len(userID) != 0 {
|
||||||
@ -293,6 +351,13 @@ func (am *DefaultAccountManager) AddPeer(setupKey string, userID string, peer *P
|
|||||||
Status: &PeerStatus{Connected: false, LastSeen: time.Now()},
|
Status: &PeerStatus{Connected: false, LastSeen: time.Now()},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// add peer to 'All' group
|
||||||
|
group, err := account.GetGroupAll()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
group.Peers = append(group.Peers, newPeer.Key)
|
||||||
|
|
||||||
account.Peers[newPeer.Key] = newPeer
|
account.Peers[newPeer.Key] = newPeer
|
||||||
if len(upperKey) != 0 {
|
if len(upperKey) != 0 {
|
||||||
account.SetupKeys[sk.Key] = sk.IncrementUsage()
|
account.SetupKeys[sk.Key] = sk.IncrementUsage()
|
||||||
@ -305,5 +370,4 @@ func (am *DefaultAccountManager) AddPeer(setupKey string, userID string, peer *P
|
|||||||
}
|
}
|
||||||
|
|
||||||
return newPeer, nil
|
return newPeer, nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/rs/xid"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestAccountManager_GetNetworkMap(t *testing.T) {
|
func TestAccountManager_GetNetworkMap(t *testing.T) {
|
||||||
@ -70,7 +72,151 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if networkMap.Peers[0].Key != peerKey2.PublicKey().String() {
|
if networkMap.Peers[0].Key != peerKey2.PublicKey().String() {
|
||||||
t.Errorf("expecting Account NetworkMap to have peer with a key %s, got %s", peerKey2.PublicKey().String(), networkMap.Peers[0].Key)
|
t.Errorf(
|
||||||
|
"expecting Account NetworkMap to have peer with a key %s, got %s",
|
||||||
|
peerKey2.PublicKey().String(),
|
||||||
|
networkMap.Peers[0].Key,
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAccountManager_GetNetworkMapWithRule(t *testing.T) {
|
||||||
|
manager, err := createManager(t)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedId := "test_account"
|
||||||
|
userId := "account_creator"
|
||||||
|
account, err := manager.AddAccount(expectedId, userId, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var setupKey *SetupKey
|
||||||
|
for _, key := range account.SetupKeys {
|
||||||
|
if key.Type == SetupKeyReusable {
|
||||||
|
setupKey = key
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
peerKey1, err := wgtypes.GeneratePrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = manager.AddPeer(setupKey.Key, "", &Peer{
|
||||||
|
Key: peerKey1.PublicKey().String(),
|
||||||
|
Meta: PeerSystemMeta{},
|
||||||
|
Name: "test-peer-2",
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("expecting peer to be added, got failure %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
peerKey2, err := wgtypes.GeneratePrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, err = manager.AddPeer(setupKey.Key, "", &Peer{
|
||||||
|
Key: peerKey2.PublicKey().String(),
|
||||||
|
Meta: PeerSystemMeta{},
|
||||||
|
Name: "test-peer-2",
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("expecting peer to be added, got failure %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rules, err := manager.ListRules(account.Id)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("expecting to get a list of rules, got failure %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = manager.DeleteRule(account.Id, rules[0].ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("expecting to delete 1 group, got failure %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var (
|
||||||
|
group1 Group
|
||||||
|
group2 Group
|
||||||
|
rule Rule
|
||||||
|
)
|
||||||
|
|
||||||
|
group1.ID = xid.New().String()
|
||||||
|
group2.ID = xid.New().String()
|
||||||
|
group1.Name = "src"
|
||||||
|
group2.Name = "dst"
|
||||||
|
rule.ID = xid.New().String()
|
||||||
|
group1.Peers = append(group1.Peers, peerKey1.PublicKey().String())
|
||||||
|
group2.Peers = append(group2.Peers, peerKey2.PublicKey().String())
|
||||||
|
|
||||||
|
err = manager.SaveGroup(account.Id, &group1)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("expecting group1 to be added, got failure %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = manager.SaveGroup(account.Id, &group2)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("expecting group2 to be added, got failure %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rule.Name = "test"
|
||||||
|
rule.Source = append(rule.Source, group1.ID)
|
||||||
|
rule.Destination = append(rule.Destination, group2.ID)
|
||||||
|
rule.Flow = TrafficFlowBidirect
|
||||||
|
err = manager.SaveRule(account.Id, &rule)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("expecting rule to be added, got failure %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
networkMap1, err := manager.GetNetworkMap(peerKey1.PublicKey().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(networkMap1.Peers) != 1 {
|
||||||
|
t.Errorf(
|
||||||
|
"expecting Account NetworkMap to have 1 peers, got %v: %v",
|
||||||
|
len(networkMap1.Peers),
|
||||||
|
networkMap1.Peers,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if networkMap1.Peers[0].Key != peerKey2.PublicKey().String() {
|
||||||
|
t.Errorf(
|
||||||
|
"expecting Account NetworkMap to have peer with a key %s, got %s",
|
||||||
|
peerKey2.PublicKey().String(),
|
||||||
|
networkMap1.Peers[0].Key,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
networkMap2, err := manager.GetNetworkMap(peerKey2.PublicKey().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(networkMap2.Peers) != 1 {
|
||||||
|
t.Errorf("expecting Account NetworkMap to have 1 peers, got %v", len(networkMap2.Peers))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(networkMap2.Peers) > 0 && networkMap2.Peers[0].Key != peerKey1.PublicKey().String() {
|
||||||
|
t.Errorf(
|
||||||
|
"expecting Account NetworkMap to have peer with a key %s, got %s",
|
||||||
|
peerKey1.PublicKey().String(),
|
||||||
|
networkMap2.Peers[0].Key,
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
107
management/server/rule.go
Normal file
107
management/server/rule.go
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TrafficFlowType defines allowed direction of the traffic in the rule
|
||||||
|
type TrafficFlowType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// TrafficFlowBidirect allows traffic to both direction
|
||||||
|
TrafficFlowBidirect TrafficFlowType = iota
|
||||||
|
)
|
||||||
|
|
||||||
|
// Rule of ACL for groups
|
||||||
|
type Rule struct {
|
||||||
|
// ID of the rule
|
||||||
|
ID string
|
||||||
|
|
||||||
|
// Name of the rule visible in the UI
|
||||||
|
Name string
|
||||||
|
|
||||||
|
// Source list of groups IDs of peers
|
||||||
|
Source []string
|
||||||
|
|
||||||
|
// Destination list of groups IDs of peers
|
||||||
|
Destination []string
|
||||||
|
|
||||||
|
// Flow of the traffic allowed by the rule
|
||||||
|
Flow TrafficFlowType
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Rule) Copy() *Rule {
|
||||||
|
return &Rule{
|
||||||
|
ID: r.ID,
|
||||||
|
Name: r.Name,
|
||||||
|
Source: r.Source[:],
|
||||||
|
Destination: r.Destination[:],
|
||||||
|
Flow: r.Flow,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRule of ACL from the store
|
||||||
|
func (am *DefaultAccountManager) GetRule(accountID, ruleID string) (*Rule, error) {
|
||||||
|
am.mux.Lock()
|
||||||
|
defer am.mux.Unlock()
|
||||||
|
|
||||||
|
account, err := am.Store.GetAccount(accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
rule, ok := account.Rules[ruleID]
|
||||||
|
if ok {
|
||||||
|
return rule, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, status.Errorf(codes.NotFound, "rule with ID %s not found", ruleID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveRule of ACL in the store
|
||||||
|
func (am *DefaultAccountManager) SaveRule(accountID string, rule *Rule) error {
|
||||||
|
am.mux.Lock()
|
||||||
|
defer am.mux.Unlock()
|
||||||
|
|
||||||
|
account, err := am.Store.GetAccount(accountID)
|
||||||
|
if err != nil {
|
||||||
|
return status.Errorf(codes.NotFound, "account not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
account.Rules[rule.ID] = rule
|
||||||
|
return am.Store.SaveAccount(account)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteRule of ACL from the store
|
||||||
|
func (am *DefaultAccountManager) DeleteRule(accountID, ruleID string) error {
|
||||||
|
am.mux.Lock()
|
||||||
|
defer am.mux.Unlock()
|
||||||
|
|
||||||
|
account, err := am.Store.GetAccount(accountID)
|
||||||
|
if err != nil {
|
||||||
|
return status.Errorf(codes.NotFound, "account not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(account.Rules, ruleID)
|
||||||
|
|
||||||
|
return am.Store.SaveAccount(account)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListRules of ACL from the store
|
||||||
|
func (am *DefaultAccountManager) ListRules(accountID string) ([]*Rule, error) {
|
||||||
|
am.mux.Lock()
|
||||||
|
defer am.mux.Unlock()
|
||||||
|
|
||||||
|
account, err := am.Store.GetAccount(accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
rules := make([]*Rule, 0, len(account.Rules))
|
||||||
|
for _, item := range account.Rules {
|
||||||
|
rules = append(rules, item)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rules, nil
|
||||||
|
}
|
@ -4,10 +4,13 @@ type Store interface {
|
|||||||
GetPeer(peerKey string) (*Peer, error)
|
GetPeer(peerKey string) (*Peer, error)
|
||||||
DeletePeer(accountId string, peerKey string) (*Peer, error)
|
DeletePeer(accountId string, peerKey string) (*Peer, error)
|
||||||
SavePeer(accountId string, peer *Peer) error
|
SavePeer(accountId string, peer *Peer) error
|
||||||
|
GetAllAccounts() []*Account
|
||||||
GetAccount(accountId string) (*Account, error)
|
GetAccount(accountId string) (*Account, error)
|
||||||
GetUserAccount(userId string) (*Account, error)
|
GetUserAccount(userId string) (*Account, error)
|
||||||
GetAccountPeers(accountId string) ([]*Peer, error)
|
GetAccountPeers(accountId string) ([]*Peer, error)
|
||||||
GetPeerAccount(peerKey string) (*Account, error)
|
GetPeerAccount(peerKey string) (*Account, error)
|
||||||
|
GetPeerSrcRules(accountId, peerKey string) ([]*Rule, error)
|
||||||
|
GetPeerDstRules(accountId, peerKey string) ([]*Rule, error)
|
||||||
GetAccountBySetupKey(setupKey string) (*Account, error)
|
GetAccountBySetupKey(setupKey string) (*Account, error)
|
||||||
GetAccountByPrivateDomain(domain string) (*Account, error)
|
GetAccountByPrivateDomain(domain string) (*Account, error)
|
||||||
SaveAccount(account *Account) error
|
SaveAccount(account *Account) error
|
||||||
|
@ -58,6 +58,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userId, domain string)
|
|||||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||||
account = NewAccount(userId, lowerDomain)
|
account = NewAccount(userId, lowerDomain)
|
||||||
account.Users[userId] = NewAdminUser(userId)
|
account.Users[userId] = NewAdminUser(userId)
|
||||||
|
am.addAllGroup(account)
|
||||||
err = am.Store.SaveAccount(account)
|
err = am.Store.SaveAccount(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.Internal, "failed creating account")
|
return nil, status.Errorf(codes.Internal, "failed creating account")
|
||||||
|
Loading…
Reference in New Issue
Block a user