mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-07 08:44:07 +01:00
Add rules for ACL (#306)
Add rules HTTP endpoint for frontend - CRUD operations. Add Default rule - allow all. Send network map to peers based on rules.
This commit is contained in:
parent
11a3863c28
commit
3ce3ccc39a
@ -68,7 +68,10 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
|
||||
}
|
||||
|
||||
peersUpdateManager := mgmt.NewPeersUpdateManager()
|
||||
accountManager := mgmt.NewManager(store, peersUpdateManager, nil)
|
||||
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager)
|
||||
if err != nil {
|
||||
|
@ -455,7 +455,10 @@ func startManagement(port int, dataDir string) (*grpc.Server, error) {
|
||||
log.Fatalf("failed creating a store: %s: %v", config.Datadir, err)
|
||||
}
|
||||
peersUpdateManager := server.NewPeersUpdateManager()
|
||||
accountManager := server.NewManager(store, peersUpdateManager, nil)
|
||||
accountManager, err := server.BuildManager(store, peersUpdateManager, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager)
|
||||
if err != nil {
|
||||
|
@ -28,7 +28,6 @@ import (
|
||||
const ValidKey = "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
|
||||
|
||||
func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
|
||||
level, _ := log.ParseLevel("debug")
|
||||
log.SetLevel(level)
|
||||
|
||||
@ -56,7 +55,10 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
}
|
||||
|
||||
peersUpdateManager := mgmt.NewPeersUpdateManager()
|
||||
accountManager := mgmt.NewManager(store, peersUpdateManager, nil)
|
||||
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager)
|
||||
if err != nil {
|
||||
@ -256,6 +258,7 @@ func TestClient_Sync(t *testing.T) {
|
||||
}
|
||||
if len(resp.GetRemotePeers()) != 1 {
|
||||
t.Errorf("expecting RemotePeers size %d got %d", 1, len(resp.GetRemotePeers()))
|
||||
return
|
||||
}
|
||||
if resp.GetRemotePeersIsEmpty() == true {
|
||||
t.Error("expecting RemotePeers property to be false, got true")
|
||||
@ -295,37 +298,36 @@ func Test_SystemMetaDataFromClient(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
mgmtMockServer.LoginFunc =
|
||||
func(ctx context.Context, msg *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||
peerKey, err := wgtypes.ParseKey(msg.GetWgPubKey())
|
||||
if err != nil {
|
||||
log.Warnf("error while parsing peer's Wireguard public key %s on Sync request.", msg.WgPubKey)
|
||||
return nil, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", msg.WgPubKey)
|
||||
}
|
||||
|
||||
loginReq := &proto.LoginRequest{}
|
||||
err = encryption.DecryptMessage(peerKey, serverKey, msg.Body, loginReq)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
actualMeta = loginReq.GetMeta()
|
||||
actualValidKey = loginReq.GetSetupKey()
|
||||
wg.Done()
|
||||
|
||||
loginResp := &proto.LoginResponse{}
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, serverKey, loginResp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &mgmtProto.EncryptedMessage{
|
||||
WgPubKey: serverKey.PublicKey().String(),
|
||||
Body: encryptedResp,
|
||||
Version: 0,
|
||||
}, nil
|
||||
mgmtMockServer.LoginFunc = func(ctx context.Context, msg *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||
peerKey, err := wgtypes.ParseKey(msg.GetWgPubKey())
|
||||
if err != nil {
|
||||
log.Warnf("error while parsing peer's Wireguard public key %s on Sync request.", msg.WgPubKey)
|
||||
return nil, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", msg.WgPubKey)
|
||||
}
|
||||
|
||||
loginReq := &proto.LoginRequest{}
|
||||
err = encryption.DecryptMessage(peerKey, serverKey, msg.Body, loginReq)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
actualMeta = loginReq.GetMeta()
|
||||
actualValidKey = loginReq.GetSetupKey()
|
||||
wg.Done()
|
||||
|
||||
loginResp := &proto.LoginResponse{}
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, serverKey, loginResp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &mgmtProto.EncryptedMessage{
|
||||
WgPubKey: serverKey.PublicKey().String(),
|
||||
Body: encryptedResp,
|
||||
Version: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
info := system.GetInfo()
|
||||
_, err = testClient.Register(*key, ValidKey, "", info)
|
||||
if err != nil {
|
||||
@ -370,21 +372,19 @@ func Test_GetDeviceAuthorizationFlow(t *testing.T) {
|
||||
ProviderConfig: &proto.ProviderConfig{ClientID: "client"},
|
||||
}
|
||||
|
||||
mgmtMockServer.GetDeviceAuthorizationFlowFunc =
|
||||
func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||
|
||||
encryptedResp, err := encryption.EncryptMessage(serverKey, client.key, expectedFlowInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &mgmtProto.EncryptedMessage{
|
||||
WgPubKey: serverKey.PublicKey().String(),
|
||||
Body: encryptedResp,
|
||||
Version: 0,
|
||||
}, nil
|
||||
mgmtMockServer.GetDeviceAuthorizationFlowFunc = func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||
encryptedResp, err := encryption.EncryptMessage(serverKey, client.key, expectedFlowInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &mgmtProto.EncryptedMessage{
|
||||
WgPubKey: serverKey.PublicKey().String(),
|
||||
Body: encryptedResp,
|
||||
Version: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
flowInfo, err := client.GetDeviceAuthorizationFlow(serverKey)
|
||||
if err != nil {
|
||||
t.Error("error while retrieving device auth flow information")
|
||||
|
@ -108,20 +108,23 @@ var (
|
||||
}
|
||||
}
|
||||
|
||||
accountManager := server.NewManager(store, peersUpdateManager, idpManager)
|
||||
accountManager, err := server.BuildManager(store, peersUpdateManager, idpManager)
|
||||
if err != nil {
|
||||
log.Fatalln("failed build default manager: ", err)
|
||||
}
|
||||
|
||||
var opts []grpc.ServerOption
|
||||
|
||||
var httpServer *http.Server
|
||||
if config.HttpConfig.LetsEncryptDomain != "" {
|
||||
//automatically generate a new certificate with Let's Encrypt
|
||||
// automatically generate a new certificate with Let's Encrypt
|
||||
certManager := encryption.CreateCertManager(config.Datadir, config.HttpConfig.LetsEncryptDomain)
|
||||
transportCredentials := credentials.NewTLS(certManager.TLSConfig())
|
||||
opts = append(opts, grpc.Creds(transportCredentials))
|
||||
|
||||
httpServer = http.NewHttpsServer(config.HttpConfig, certManager, accountManager)
|
||||
} else if config.HttpConfig.CertFile != "" && config.HttpConfig.CertKey != "" {
|
||||
//use provided certificate
|
||||
// use provided certificate
|
||||
tlsConfig, err := loadTLSConfig(config.HttpConfig.CertFile, config.HttpConfig.CertKey)
|
||||
if err != nil {
|
||||
log.Fatal("cannot load TLS credentials: ", err)
|
||||
@ -130,7 +133,7 @@ var (
|
||||
opts = append(opts, grpc.Creds(transportCredentials))
|
||||
httpServer = http.NewHttpsServerWithTLSConfig(config.HttpConfig, tlsConfig, accountManager)
|
||||
} else {
|
||||
//start server without SSL
|
||||
// start server without SSL
|
||||
httpServer = http.NewHttpServer(config.HttpConfig, accountManager)
|
||||
}
|
||||
|
||||
@ -309,5 +312,4 @@ func init() {
|
||||
mgmtCmd.Flags().StringVar(&certFile, "cert-file", "", "Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
|
||||
mgmtCmd.Flags().StringVar(&certKey, "cert-key", "", "Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
|
||||
rootCmd.MarkFlagRequired("config") //nolint
|
||||
|
||||
}
|
||||
|
@ -24,7 +24,12 @@ const (
|
||||
type AccountManager interface {
|
||||
GetOrCreateAccountByUser(userId, domain string) (*Account, error)
|
||||
GetAccountByUser(userId string) (*Account, error)
|
||||
AddSetupKey(accountId string, keyName string, keyType SetupKeyType, expiresIn *util.Duration) (*SetupKey, error)
|
||||
AddSetupKey(
|
||||
accountId string,
|
||||
keyName string,
|
||||
keyType SetupKeyType,
|
||||
expiresIn *util.Duration,
|
||||
) (*SetupKey, error)
|
||||
RevokeSetupKey(accountId string, keyId string) (*SetupKey, error)
|
||||
RenameSetupKey(accountId string, keyId string, newName string) (*SetupKey, error)
|
||||
GetAccountById(accountId string) (*Account, error)
|
||||
@ -47,6 +52,10 @@ type AccountManager interface {
|
||||
GroupAddPeer(accountId, groupID, peerKey string) error
|
||||
GroupDeletePeer(accountId, groupID, peerKey string) error
|
||||
GroupListPeers(accountId, groupID string) ([]*Peer, error)
|
||||
GetRule(accountId, ruleID string) (*Rule, error)
|
||||
SaveRule(accountID string, rule *Rule) error
|
||||
DeleteRule(accountId, ruleID string) error
|
||||
ListRules(accountId string) ([]*Rule, error)
|
||||
}
|
||||
|
||||
type DefaultAccountManager struct {
|
||||
@ -70,6 +79,7 @@ type Account struct {
|
||||
Peers map[string]*Peer
|
||||
Users map[string]*User
|
||||
Groups map[string]*Group
|
||||
Rules map[string]*Rule
|
||||
}
|
||||
|
||||
type UserInfo struct {
|
||||
@ -101,6 +111,16 @@ func (a *Account) Copy() *Account {
|
||||
setupKeys[id] = key.Copy()
|
||||
}
|
||||
|
||||
groups := map[string]*Group{}
|
||||
for id, group := range a.Groups {
|
||||
groups[id] = group.Copy()
|
||||
}
|
||||
|
||||
rules := map[string]*Rule{}
|
||||
for id, rule := range a.Rules {
|
||||
rules[id] = rule.Copy()
|
||||
}
|
||||
|
||||
return &Account{
|
||||
Id: a.Id,
|
||||
CreatedBy: a.CreatedBy,
|
||||
@ -108,17 +128,43 @@ func (a *Account) Copy() *Account {
|
||||
Network: a.Network.Copy(),
|
||||
Peers: peers,
|
||||
Users: users,
|
||||
Groups: groups,
|
||||
Rules: rules,
|
||||
}
|
||||
}
|
||||
|
||||
// NewManager creates a new DefaultAccountManager with a provided Store
|
||||
func NewManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager) *DefaultAccountManager {
|
||||
return &DefaultAccountManager{
|
||||
func (a *Account) GetGroupAll() (*Group, error) {
|
||||
for _, g := range a.Groups {
|
||||
if g.Name == "All" {
|
||||
return g, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("no group ALL found")
|
||||
}
|
||||
|
||||
// BuildManager creates a new DefaultAccountManager with a provided Store
|
||||
func BuildManager(
|
||||
store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager,
|
||||
) (*DefaultAccountManager, error) {
|
||||
dam := &DefaultAccountManager{
|
||||
Store: store,
|
||||
mux: sync.Mutex{},
|
||||
peersUpdateManager: peersUpdateManager,
|
||||
idpManager: idpManager,
|
||||
}
|
||||
|
||||
// if account has not default account
|
||||
// we build 'all' group and add all peers into it
|
||||
// also we create default rule with source an destination
|
||||
// groups 'all'
|
||||
for _, account := range store.GetAllAccounts() {
|
||||
dam.addAllGroup(account)
|
||||
if err := store.SaveAccount(account); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return dam, nil
|
||||
}
|
||||
|
||||
// AddSetupKey generates a new setup key with a given name and type, and adds it to the specified account
|
||||
@ -223,7 +269,9 @@ func (am *DefaultAccountManager) GetAccountById(accountId string) (*Account, err
|
||||
|
||||
// GetAccountByUserOrAccountId look for an account by user or account Id, if no account is provided and
|
||||
// user id doesn't have an account associated with it, one account is created
|
||||
func (am *DefaultAccountManager) GetAccountByUserOrAccountId(userId, accountId, domain string) (*Account, error) {
|
||||
func (am *DefaultAccountManager) GetAccountByUserOrAccountId(
|
||||
userId, accountId, domain string,
|
||||
) (*Account, error) {
|
||||
if accountId != "" {
|
||||
return am.GetAccountById(accountId)
|
||||
} else if userId != "" {
|
||||
@ -490,6 +538,8 @@ func (am *DefaultAccountManager) AddAccount(accountId, userId, domain string) (*
|
||||
func (am *DefaultAccountManager) createAccount(accountId, userId, domain string) (*Account, error) {
|
||||
account := newAccountWithId(accountId, userId, domain)
|
||||
|
||||
am.addAllGroup(account)
|
||||
|
||||
err := am.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed creating account")
|
||||
@ -498,6 +548,28 @@ func (am *DefaultAccountManager) createAccount(accountId, userId, domain string)
|
||||
return account, nil
|
||||
}
|
||||
|
||||
// addAllGroup to account object it it doesn't exists
|
||||
func (am *DefaultAccountManager) addAllGroup(account *Account) {
|
||||
if len(account.Groups) == 0 {
|
||||
allGroup := &Group{
|
||||
ID: xid.New().String(),
|
||||
Name: "All",
|
||||
}
|
||||
for _, peer := range account.Peers {
|
||||
allGroup.Peers = append(allGroup.Peers, peer.Key)
|
||||
}
|
||||
account.Groups = map[string]*Group{allGroup.ID: allGroup}
|
||||
|
||||
defaultRule := &Rule{
|
||||
ID: xid.New().String(),
|
||||
Name: "Default",
|
||||
Source: []string{allGroup.ID},
|
||||
Destination: []string{allGroup.ID},
|
||||
}
|
||||
account.Rules = map[string]*Rule{defaultRule.ID: defaultRule}
|
||||
}
|
||||
}
|
||||
|
||||
// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id
|
||||
func newAccountWithId(accountId, userId, domain string) *Account {
|
||||
log.Debugf("creating new account")
|
||||
|
@ -37,7 +37,6 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
|
||||
|
||||
type initUserParams jwtclaims.AuthorizationClaims
|
||||
|
||||
type test struct {
|
||||
@ -165,7 +164,6 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
|
||||
}
|
||||
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5, testCase6} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
@ -346,7 +344,6 @@ func TestAccountManager_AccountExists(t *testing.T) {
|
||||
if !*exists {
|
||||
t.Errorf("expected account to exist after creation, got false")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestAccountManager_GetAccount(t *testing.T) {
|
||||
@ -363,7 +360,7 @@ func TestAccountManager_GetAccount(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
//AddAccount has been already tested so we can assume it is correct and compare results
|
||||
// AddAccount has been already tested so we can assume it is correct and compare results
|
||||
getAccount, err := manager.GetAccountById(expectedId)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@ -385,7 +382,6 @@ func TestAccountManager_GetAccount(t *testing.T) {
|
||||
t.Errorf("expected account to have setup key %s, not found", key.Key)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestAccountManager_AddPeer(t *testing.T) {
|
||||
@ -400,7 +396,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
serial := account.Network.CurrentSerial() //should be 0
|
||||
serial := account.Network.CurrentSerial() // should be 0
|
||||
|
||||
var setupKey *SetupKey
|
||||
for _, key := range account.SetupKeys {
|
||||
@ -457,7 +453,6 @@ func TestAccountManager_AddPeer(t *testing.T) {
|
||||
if account.Network.CurrentSerial() != 1 {
|
||||
t.Errorf("expecting Network Serial=%d to be incremented by 1 and be equal to %d when adding new peer to account", serial, account.Network.CurrentSerial())
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestAccountManager_AddPeerWithUserID(t *testing.T) {
|
||||
@ -474,7 +469,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
serial := account.Network.CurrentSerial() //should be 0
|
||||
serial := account.Network.CurrentSerial() // should be 0
|
||||
|
||||
if account.Network.Serial != 0 {
|
||||
t.Errorf("expecting account network to have an initial Serial=0")
|
||||
@ -521,7 +516,6 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
|
||||
if account.Network.CurrentSerial() != 1 {
|
||||
t.Errorf("expecting Network Serial=%d to be incremented by 1 and be equal to %d when adding new peer to account", serial, account.Network.CurrentSerial())
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestAccountManager_DeletePeer(t *testing.T) {
|
||||
@ -573,7 +567,6 @@ func TestAccountManager_DeletePeer(t *testing.T) {
|
||||
if account.Network.CurrentSerial() != 2 {
|
||||
t.Errorf("expecting Network Serial=%d to be incremented and be equal to 2 after adding and deleteing a peer", account.Network.CurrentSerial())
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestGetUsersFromAccount(t *testing.T) {
|
||||
@ -614,7 +607,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewManager(store, NewPeersUpdateManager(), nil), nil
|
||||
return BuildManager(store, NewPeersUpdateManager(), nil)
|
||||
}
|
||||
|
||||
func createStore(t *testing.T) (Store, error) {
|
||||
|
@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@ -18,10 +19,12 @@ const storeFileName = "store.json"
|
||||
// FileStore represents an account storage backed by a file persisted to disk
|
||||
type FileStore struct {
|
||||
Accounts map[string]*Account
|
||||
SetupKeyId2AccountId map[string]string `json:"-"`
|
||||
PeerKeyId2AccountId map[string]string `json:"-"`
|
||||
UserId2AccountId map[string]string `json:"-"`
|
||||
PrivateDomain2AccountId map[string]string `json:"-"`
|
||||
SetupKeyId2AccountId map[string]string `json:"-"`
|
||||
PeerKeyId2AccountId map[string]string `json:"-"`
|
||||
UserId2AccountId map[string]string `json:"-"`
|
||||
PrivateDomain2AccountId map[string]string `json:"-"`
|
||||
PeerKeyId2SrcRulesId map[string]map[string]struct{} `json:"-"`
|
||||
PeerKeyId2DstRulesId map[string]map[string]struct{} `json:"-"`
|
||||
|
||||
// mutex to synchronise Store read/write operations
|
||||
mux sync.Mutex `json:"-"`
|
||||
@ -47,6 +50,8 @@ func restore(file string) (*FileStore, error) {
|
||||
PeerKeyId2AccountId: make(map[string]string),
|
||||
UserId2AccountId: make(map[string]string),
|
||||
PrivateDomain2AccountId: make(map[string]string),
|
||||
PeerKeyId2SrcRulesId: make(map[string]map[string]struct{}),
|
||||
PeerKeyId2DstRulesId: make(map[string]map[string]struct{}),
|
||||
storeFile: file,
|
||||
}
|
||||
|
||||
@ -69,10 +74,39 @@ func restore(file string) (*FileStore, error) {
|
||||
store.PeerKeyId2AccountId = make(map[string]string)
|
||||
store.UserId2AccountId = make(map[string]string)
|
||||
store.PrivateDomain2AccountId = make(map[string]string)
|
||||
store.PeerKeyId2SrcRulesId = map[string]map[string]struct{}{}
|
||||
store.PeerKeyId2DstRulesId = map[string]map[string]struct{}{}
|
||||
|
||||
for accountId, account := range store.Accounts {
|
||||
for setupKeyId := range account.SetupKeys {
|
||||
store.SetupKeyId2AccountId[strings.ToUpper(setupKeyId)] = accountId
|
||||
}
|
||||
for _, rule := range account.Rules {
|
||||
for _, groupID := range rule.Source {
|
||||
if group, ok := account.Groups[groupID]; ok {
|
||||
for _, peerID := range group.Peers {
|
||||
rules := store.PeerKeyId2SrcRulesId[peerID]
|
||||
if rules == nil {
|
||||
rules = map[string]struct{}{}
|
||||
store.PeerKeyId2SrcRulesId[peerID] = rules
|
||||
}
|
||||
rules[rule.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, groupID := range rule.Destination {
|
||||
if group, ok := account.Groups[groupID]; ok {
|
||||
for _, peerID := range group.Peers {
|
||||
rules := store.PeerKeyId2DstRulesId[peerID]
|
||||
if rules == nil {
|
||||
rules = map[string]struct{}{}
|
||||
store.PeerKeyId2DstRulesId[peerID] = rules
|
||||
}
|
||||
rules[rule.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, peer := range account.Peers {
|
||||
store.PeerKeyId2AccountId[peer.Key] = accountId
|
||||
}
|
||||
@ -82,7 +116,8 @@ func restore(file string) (*FileStore, error) {
|
||||
for _, user := range account.Users {
|
||||
store.UserId2AccountId[user.Id] = accountId
|
||||
}
|
||||
if account.Domain != "" && account.DomainCategory == PrivateCategory && account.IsDomainPrimaryAccount {
|
||||
if account.Domain != "" && account.DomainCategory == PrivateCategory &&
|
||||
account.IsDomainPrimaryAccount {
|
||||
store.PrivateDomain2AccountId[account.Domain] = accountId
|
||||
}
|
||||
}
|
||||
@ -106,6 +141,24 @@ func (s *FileStore) SavePeer(accountId string, peer *Peer) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// if it is new peer, add it to default 'All' group
|
||||
allGroup, err := account.GetGroupAll()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ind := -1
|
||||
for i, pid := range allGroup.Peers {
|
||||
if pid == peer.Key {
|
||||
ind = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if ind < 0 {
|
||||
allGroup.Peers = append(allGroup.Peers, peer.Key)
|
||||
}
|
||||
|
||||
account.Peers[peer.Key] = peer
|
||||
return s.persist(s.storeFile)
|
||||
}
|
||||
@ -176,6 +229,29 @@ func (s *FileStore) SaveAccount(account *Account) error {
|
||||
s.PeerKeyId2AccountId[peer.Key] = account.Id
|
||||
}
|
||||
|
||||
for _, rule := range account.Rules {
|
||||
for _, gid := range rule.Source {
|
||||
for _, pid := range account.Groups[gid].Peers {
|
||||
rules := s.PeerKeyId2SrcRulesId[pid]
|
||||
if rules == nil {
|
||||
rules = map[string]struct{}{}
|
||||
s.PeerKeyId2SrcRulesId[pid] = rules
|
||||
}
|
||||
rules[rule.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
for _, gid := range rule.Destination {
|
||||
for _, pid := range account.Groups[gid].Peers {
|
||||
rules := s.PeerKeyId2DstRulesId[pid]
|
||||
if rules == nil {
|
||||
rules = map[string]struct{}{}
|
||||
s.PeerKeyId2DstRulesId[pid] = rules
|
||||
}
|
||||
rules[rule.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, user := range account.Users {
|
||||
s.UserId2AccountId[user.Id] = account.Id
|
||||
}
|
||||
@ -190,7 +266,10 @@ func (s *FileStore) SaveAccount(account *Account) error {
|
||||
func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) {
|
||||
accountId, accountIdFound := s.PrivateDomain2AccountId[strings.ToLower(domain)]
|
||||
if !accountIdFound {
|
||||
return nil, status.Errorf(codes.NotFound, "provided domain is not registered or is not private")
|
||||
return nil, status.Errorf(
|
||||
codes.NotFound,
|
||||
"provided domain is not registered or is not private",
|
||||
)
|
||||
}
|
||||
|
||||
account, err := s.GetAccount(accountId)
|
||||
@ -232,6 +311,14 @@ func (s *FileStore) GetAccountPeers(accountId string) ([]*Peer, error) {
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAllAccounts() (all []*Account) {
|
||||
for _, a := range s.Accounts {
|
||||
all = append(all, a)
|
||||
}
|
||||
|
||||
return all
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccount(accountId string) (*Account, error) {
|
||||
account, accountFound := s.Accounts[accountId]
|
||||
if !accountFound {
|
||||
@ -265,18 +352,52 @@ func (s *FileStore) GetPeerAccount(peerKey string) (*Account, error) {
|
||||
return s.GetAccount(accountId)
|
||||
}
|
||||
|
||||
func (s *FileStore) GetGroup(groupID string) (*Group, error) {
|
||||
return nil, nil
|
||||
func (s *FileStore) GetPeerSrcRules(accountId, peerKey string) ([]*Rule, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
account, err := s.GetAccount(accountId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ruleIDs, ok := s.PeerKeyId2SrcRulesId[peerKey]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no rules for peer: %v", ruleIDs)
|
||||
}
|
||||
|
||||
rules := []*Rule{}
|
||||
for id := range ruleIDs {
|
||||
rule, ok := account.Rules[id]
|
||||
if ok {
|
||||
rules = append(rules, rule)
|
||||
}
|
||||
}
|
||||
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
func (s *FileStore) SaveGroup(group *Group) error {
|
||||
return nil
|
||||
}
|
||||
func (s *FileStore) GetPeerDstRules(accountId, peerKey string) ([]*Rule, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
func (s *FileStore) DeleteGroup(groupID string) error {
|
||||
return nil
|
||||
}
|
||||
account, err := s.GetAccount(accountId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (s *FileStore) ListGroups() ([]*Group, error) {
|
||||
return nil, nil
|
||||
ruleIDs, ok := s.PeerKeyId2DstRulesId[peerKey]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no rules for peer: %v", ruleIDs)
|
||||
}
|
||||
|
||||
rules := []*Rule{}
|
||||
for id := range ruleIDs {
|
||||
rule, ok := account.Rules[id]
|
||||
if ok {
|
||||
rules = append(rules, rule)
|
||||
}
|
||||
}
|
||||
|
||||
return rules, nil
|
||||
}
|
||||
|
@ -17,6 +17,14 @@ type Group struct {
|
||||
Peers []string
|
||||
}
|
||||
|
||||
func (g *Group) Copy() *Group {
|
||||
return &Group{
|
||||
ID: g.ID,
|
||||
Name: g.Name,
|
||||
Peers: g.Peers[:],
|
||||
}
|
||||
}
|
||||
|
||||
// GetGroup object of the peers
|
||||
func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, error) {
|
||||
am.mux.Lock()
|
||||
|
@ -3,11 +3,12 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
|
||||
"github.com/golang/protobuf/ptypes/timestamp"
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
@ -64,7 +65,6 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
|
||||
}
|
||||
|
||||
func (s *Server) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) {
|
||||
|
||||
// todo introduce something more meaningful with the key expiration/rotation
|
||||
now := time.Now().Add(24 * time.Hour)
|
||||
secs := int64(now.Second())
|
||||
@ -77,10 +77,9 @@ func (s *Server) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.Ser
|
||||
}, nil
|
||||
}
|
||||
|
||||
//Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
|
||||
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
|
||||
// notifies the connected peer of any updates (e.g. new peers under the same account)
|
||||
func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
|
||||
|
||||
log.Debugf("Sync request from peer %s", req.WgPubKey)
|
||||
|
||||
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
||||
@ -155,7 +154,6 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
}
|
||||
|
||||
func (s *Server) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest) (*Peer, error) {
|
||||
|
||||
var (
|
||||
reqSetupKey string
|
||||
userId string
|
||||
@ -209,7 +207,7 @@ func (s *Server) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest) (*Pe
|
||||
return nil, status.Errorf(codes.NotFound, "provided setup key doesn't exists")
|
||||
}
|
||||
|
||||
//todo move to DefaultAccountManager the code below
|
||||
// todo move to DefaultAccountManager the code below
|
||||
networkMap, err := s.accountManager.GetNetworkMap(peer.Key)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "unable to fetch network map after registering peer, error: %v", err)
|
||||
@ -240,7 +238,6 @@ func (s *Server) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest) (*Pe
|
||||
// In case it isn't, the endpoint checks whether setup key is provided within the request and tries to register a peer.
|
||||
// In case of the successful registration login is also successful
|
||||
func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||
|
||||
log.Debugf("Login request from peer %s", req.WgPubKey)
|
||||
|
||||
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
||||
@ -252,18 +249,18 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
|
||||
peer, err := s.accountManager.GetPeer(peerKey.String())
|
||||
if err != nil {
|
||||
if errStatus, ok := status.FromError(err); ok && errStatus.Code() == codes.NotFound {
|
||||
//peer doesn't exist -> check if setup key was provided
|
||||
// peer doesn't exist -> check if setup key was provided
|
||||
loginReq := &proto.LoginRequest{}
|
||||
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, loginReq)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid request message")
|
||||
}
|
||||
if loginReq.GetJwtToken() == "" && loginReq.GetSetupKey() == "" {
|
||||
//absent setup key -> permission denied
|
||||
// absent setup key -> permission denied
|
||||
return nil, status.Errorf(codes.PermissionDenied, "provided peer with the key wgPubKey %s is not registered and no setup key or jwt was provided", peerKey.String())
|
||||
}
|
||||
|
||||
//setup key or jwt is present -> try normal registration flow
|
||||
// setup key or jwt is present -> try normal registration flow
|
||||
peer, err = s.registerPeer(peerKey, loginReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -303,13 +300,12 @@ func ToResponseProto(configProto Protocol) proto.HostConfig_Protocol {
|
||||
case TCP:
|
||||
return proto.HostConfig_TCP
|
||||
default:
|
||||
//mbragin: todo something better?
|
||||
// mbragin: todo something better?
|
||||
panic(fmt.Errorf("unexpected config protocol type %v", configProto))
|
||||
}
|
||||
}
|
||||
|
||||
func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *proto.WiretrusteeConfig {
|
||||
|
||||
var stuns []*proto.HostConfig
|
||||
for _, stun := range config.Stuns {
|
||||
stuns = append(stuns, &proto.HostConfig{
|
||||
@ -350,26 +346,23 @@ func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *prot
|
||||
|
||||
func toPeerConfig(peer *Peer) *proto.PeerConfig {
|
||||
return &proto.PeerConfig{
|
||||
Address: peer.IP.String() + "/24", //todo make it explicit
|
||||
Address: peer.IP.String() + "/24", // todo make it explicit
|
||||
}
|
||||
}
|
||||
|
||||
func toRemotePeerConfig(peers []*Peer) []*proto.RemotePeerConfig {
|
||||
|
||||
remotePeers := []*proto.RemotePeerConfig{}
|
||||
for _, rPeer := range peers {
|
||||
remotePeers = append(remotePeers, &proto.RemotePeerConfig{
|
||||
WgPubKey: rPeer.Key,
|
||||
AllowedIps: []string{fmt.Sprintf(AllowedIPsFormat, rPeer.IP)}, //todo /32
|
||||
AllowedIps: []string{fmt.Sprintf(AllowedIPsFormat, rPeer.IP)}, // todo /32
|
||||
})
|
||||
}
|
||||
|
||||
return remotePeers
|
||||
|
||||
}
|
||||
|
||||
func toSyncResponse(config *Config, peer *Peer, peers []*Peer, turnCredentials *TURNCredentials, serial uint64) *proto.SyncResponse {
|
||||
|
||||
wtConfig := toWiretrusteeConfig(config, turnCredentials)
|
||||
|
||||
pConfig := toPeerConfig(peer)
|
||||
@ -397,7 +390,6 @@ func (s *Server) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Empty,
|
||||
|
||||
// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization
|
||||
func (s *Server) sendInitialSync(peerKey wgtypes.Key, peer *Peer, srv proto.ManagementService_SyncServer) error {
|
||||
|
||||
networkMap, err := s.accountManager.GetNetworkMap(peer.Key)
|
||||
if err != nil {
|
||||
log.Warnf("error getting a list of peers for a peer %s", peer.Key)
|
||||
@ -436,7 +428,6 @@ func (s *Server) sendInitialSync(peerKey wgtypes.Key, peer *Peer, srv proto.Mana
|
||||
// This is used for initiating an Oauth 2 device authorization grant flow
|
||||
// which will be used by our clients to Login
|
||||
func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||
|
||||
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
||||
if err != nil {
|
||||
errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetDeviceAuthorizationFlow request.", req.WgPubKey)
|
||||
|
@ -13,20 +13,33 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Groups is a handler that returns groups of the account
|
||||
type Groups struct {
|
||||
accountManager server.AccountManager
|
||||
authAudience string
|
||||
jwtExtractor jwtclaims.ClaimsExtractor
|
||||
}
|
||||
|
||||
// GroupResponse is a response sent to the client
|
||||
type GroupResponse struct {
|
||||
ID string
|
||||
Name string
|
||||
Peers []GroupPeerResponse `json:",omitempty"`
|
||||
}
|
||||
|
||||
// GroupPeerResponse is a response sent to the client
|
||||
type GroupPeerResponse struct {
|
||||
Key string
|
||||
Name string
|
||||
}
|
||||
|
||||
// GroupRequest to create or update group
|
||||
type GroupRequest struct {
|
||||
ID string
|
||||
Name string
|
||||
Peers []string
|
||||
}
|
||||
|
||||
// Groups is a handler that returns groups of the account
|
||||
type Groups struct {
|
||||
jwtExtractor jwtclaims.ClaimsExtractor
|
||||
accountManager server.AccountManager
|
||||
authAudience string
|
||||
}
|
||||
|
||||
func NewGroups(accountManager server.AccountManager, authAudience string) *Groups {
|
||||
return &Groups{
|
||||
accountManager: accountManager,
|
||||
@ -44,7 +57,12 @@ func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
writeJSONObject(w, account.Groups)
|
||||
var groups []*GroupResponse
|
||||
for _, g := range account.Groups {
|
||||
groups = append(groups, toGroupResponse(account, g))
|
||||
}
|
||||
|
||||
writeJSONObject(w, groups)
|
||||
}
|
||||
|
||||
func (h *Groups) CreateOrUpdateGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||
@ -54,7 +72,7 @@ func (h *Groups) CreateOrUpdateGroupHandler(w http.ResponseWriter, r *http.Reque
|
||||
return
|
||||
}
|
||||
|
||||
var req server.Group
|
||||
var req GroupRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
@ -64,13 +82,19 @@ func (h *Groups) CreateOrUpdateGroupHandler(w http.ResponseWriter, r *http.Reque
|
||||
req.ID = xid.New().String()
|
||||
}
|
||||
|
||||
if err := h.accountManager.SaveGroup(account.Id, &req); err != nil {
|
||||
group := server.Group{
|
||||
ID: req.ID,
|
||||
Name: req.Name,
|
||||
Peers: req.Peers,
|
||||
}
|
||||
|
||||
if err := h.accountManager.SaveGroup(account.Id, &group); err != nil {
|
||||
log.Errorf("failed updating group %s under account %s %v", req.ID, account.Id, err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSONObject(w, &req)
|
||||
writeJSONObject(w, toGroupResponse(account, &group))
|
||||
}
|
||||
|
||||
func (h *Groups) DeleteGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||
@ -117,7 +141,7 @@ func (h *Groups) GetGroupHandler(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
writeJSONObject(w, group)
|
||||
writeJSONObject(w, toGroupResponse(account, group))
|
||||
default:
|
||||
http.Error(w, "", http.StatusNotFound)
|
||||
}
|
||||
@ -133,3 +157,29 @@ func (h *Groups) getGroupAccount(r *http.Request) (*server.Account, error) {
|
||||
|
||||
return account, nil
|
||||
}
|
||||
|
||||
func toGroupResponse(account *server.Account, group *server.Group) *GroupResponse {
|
||||
cache := make(map[string]GroupPeerResponse)
|
||||
gr := GroupResponse{
|
||||
ID: group.ID,
|
||||
Name: group.Name,
|
||||
}
|
||||
|
||||
for _, pid := range group.Peers {
|
||||
peerResp, ok := cache[pid]
|
||||
if !ok {
|
||||
peer, ok := account.Peers[pid]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
peerResp = GroupPeerResponse{
|
||||
Key: peer.Key,
|
||||
Name: peer.Name,
|
||||
}
|
||||
cache[pid] = peerResp
|
||||
}
|
||||
gr.Peers = append(gr.Peers, peerResp)
|
||||
}
|
||||
|
||||
return &gr
|
||||
}
|
||||
|
211
management/server/http/handler/rules.go
Normal file
211
management/server/http/handler/rules.go
Normal file
@ -0,0 +1,211 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const FlowBidirectString = "bidirect"
|
||||
|
||||
// RuleResponse is a response sent to the client
|
||||
type RuleResponse struct {
|
||||
ID string
|
||||
Name string
|
||||
Source []RuleGroupResponse
|
||||
Destination []RuleGroupResponse
|
||||
Flow string
|
||||
}
|
||||
|
||||
// RuleGroupResponse is a response sent to the client
|
||||
type RuleGroupResponse struct {
|
||||
ID string
|
||||
Name string
|
||||
PeersCount int
|
||||
}
|
||||
|
||||
// RuleRequest to create or update rule
|
||||
type RuleRequest struct {
|
||||
ID string
|
||||
Name string
|
||||
Source []string
|
||||
Destination []string
|
||||
Flow string
|
||||
}
|
||||
|
||||
// Rules is a handler that returns rules of the account
|
||||
type Rules struct {
|
||||
jwtExtractor jwtclaims.ClaimsExtractor
|
||||
accountManager server.AccountManager
|
||||
authAudience string
|
||||
}
|
||||
|
||||
func NewRules(accountManager server.AccountManager, authAudience string) *Rules {
|
||||
return &Rules{
|
||||
accountManager: accountManager,
|
||||
authAudience: authAudience,
|
||||
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil),
|
||||
}
|
||||
}
|
||||
|
||||
// GetAllRulesHandler list for the account
|
||||
func (h *Rules) GetAllRulesHandler(w http.ResponseWriter, r *http.Request) {
|
||||
account, err := h.getRuleAccount(r)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
var rules []*RuleResponse
|
||||
for _, r := range account.Rules {
|
||||
rules = append(rules, toRuleResponse(account, r))
|
||||
}
|
||||
|
||||
writeJSONObject(w, rules)
|
||||
}
|
||||
|
||||
func (h *Rules) CreateOrUpdateRuleHandler(w http.ResponseWriter, r *http.Request) {
|
||||
account, err := h.getRuleAccount(r)
|
||||
if err != nil {
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
var req RuleRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if r.Method == http.MethodPost {
|
||||
req.ID = xid.New().String()
|
||||
}
|
||||
|
||||
rule := server.Rule{
|
||||
ID: req.ID,
|
||||
Name: req.Name,
|
||||
Source: req.Source,
|
||||
Destination: req.Destination,
|
||||
}
|
||||
|
||||
switch req.Flow {
|
||||
case FlowBidirectString:
|
||||
rule.Flow = server.TrafficFlowBidirect
|
||||
default:
|
||||
http.Error(w, "unknown flow type", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.accountManager.SaveRule(account.Id, &rule); err != nil {
|
||||
log.Errorf("failed updating rule %s under account %s %v", req.ID, account.Id, err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSONObject(w, &req)
|
||||
}
|
||||
|
||||
func (h *Rules) DeleteRuleHandler(w http.ResponseWriter, r *http.Request) {
|
||||
account, err := h.getRuleAccount(r)
|
||||
if err != nil {
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
aID := account.Id
|
||||
|
||||
rID := mux.Vars(r)["id"]
|
||||
if len(rID) == 0 {
|
||||
http.Error(w, "invalid rule ID", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.accountManager.DeleteRule(aID, rID); err != nil {
|
||||
log.Errorf("failed delete rule %s under account %s %v", rID, aID, err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSONObject(w, "")
|
||||
}
|
||||
|
||||
func (h *Rules) GetRuleHandler(w http.ResponseWriter, r *http.Request) {
|
||||
account, err := h.getRuleAccount(r)
|
||||
if err != nil {
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
ruleID := mux.Vars(r)["id"]
|
||||
if len(ruleID) == 0 {
|
||||
http.Error(w, "invalid rule ID", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
rule, err := h.accountManager.GetRule(account.Id, ruleID)
|
||||
if err != nil {
|
||||
http.Error(w, "rule not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSONObject(w, toRuleResponse(account, rule))
|
||||
default:
|
||||
http.Error(w, "", http.StatusNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Rules) getRuleAccount(r *http.Request) (*server.Account, error) {
|
||||
jwtClaims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
|
||||
|
||||
account, err := h.accountManager.GetAccountWithAuthorizationClaims(jwtClaims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed getting account of a user %s: %v", jwtClaims.UserId, err)
|
||||
}
|
||||
|
||||
return account, nil
|
||||
}
|
||||
|
||||
func toRuleResponse(account *server.Account, rule *server.Rule) *RuleResponse {
|
||||
gr := RuleResponse{
|
||||
ID: rule.ID,
|
||||
Name: rule.Name,
|
||||
}
|
||||
|
||||
switch rule.Flow {
|
||||
case server.TrafficFlowBidirect:
|
||||
gr.Flow = FlowBidirectString
|
||||
default:
|
||||
gr.Flow = "unknown"
|
||||
}
|
||||
|
||||
for _, gid := range rule.Source {
|
||||
if group, ok := account.Groups[gid]; ok {
|
||||
gr.Source = append(gr.Source, RuleGroupResponse{
|
||||
ID: group.ID,
|
||||
Name: group.Name,
|
||||
PeersCount: len(group.Peers),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for _, gid := range rule.Destination {
|
||||
if group, ok := account.Groups[gid]; ok {
|
||||
gr.Destination = append(gr.Destination, RuleGroupResponse{
|
||||
ID: group.ID,
|
||||
Name: group.Name,
|
||||
PeersCount: len(group.Peers),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return &gr
|
||||
}
|
211
management/server/http/handler/rules_test.go
Normal file
211
management/server/http/handler/rules_test.go
Normal file
@ -0,0 +1,211 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
|
||||
"github.com/magiconair/properties/assert"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
)
|
||||
|
||||
func initRulesTestData(rules ...*server.Rule) *Rules {
|
||||
return &Rules{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
SaveRuleFunc: func(_ string, rule *server.Rule) error {
|
||||
if !strings.HasPrefix(rule.ID, "id-") {
|
||||
rule.ID = "id-was-set"
|
||||
}
|
||||
return nil
|
||||
},
|
||||
GetRuleFunc: func(_, ruleID string) (*server.Rule, error) {
|
||||
if ruleID != "idoftherule" {
|
||||
return nil, fmt.Errorf("not found")
|
||||
}
|
||||
return &server.Rule{
|
||||
ID: "idoftherule",
|
||||
Name: "Rule",
|
||||
Source: []string{"idofsrcrule"},
|
||||
Destination: []string{"idofdestrule"},
|
||||
Flow: server.TrafficFlowBidirect,
|
||||
}, nil
|
||||
},
|
||||
GetAccountWithAuthorizationClaimsFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
|
||||
return &server.Account{
|
||||
Id: claims.AccountId,
|
||||
Domain: "hotmail.com",
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
authAudience: "",
|
||||
jwtExtractor: jwtclaims.ClaimsExtractor{
|
||||
ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims {
|
||||
return jwtclaims.AuthorizationClaims{
|
||||
UserId: "test_user",
|
||||
Domain: "hotmail.com",
|
||||
AccountId: "test_id",
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestRulesGetRule(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
expectedStatus int
|
||||
expectedBody bool
|
||||
requestType string
|
||||
requestPath string
|
||||
requestBody io.Reader
|
||||
}{
|
||||
{
|
||||
name: "GetRule OK",
|
||||
expectedBody: true,
|
||||
requestType: http.MethodGet,
|
||||
requestPath: "/api/rules/idoftherule",
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "GetRule not found",
|
||||
requestType: http.MethodGet,
|
||||
requestPath: "/api/rules/notexists",
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
rule := &server.Rule{
|
||||
ID: "idoftherule",
|
||||
Name: "Rule",
|
||||
}
|
||||
|
||||
p := initRulesTestData(rule)
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/rules/{id}", p.GetRuleHandler).Methods("GET")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if status := recorder.Code; status != tc.expectedStatus {
|
||||
t.Errorf("handler returned wrong status code: got %v want %v",
|
||||
status, tc.expectedStatus)
|
||||
return
|
||||
}
|
||||
|
||||
if !tc.expectedBody {
|
||||
return
|
||||
}
|
||||
|
||||
content, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("I don't know what I expected; %v", err)
|
||||
}
|
||||
|
||||
var got RuleResponse
|
||||
if err = json.Unmarshal(content, &got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, got.ID, rule.ID)
|
||||
assert.Equal(t, got.Name, rule.Name)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRulesSaveRule(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
expectedStatus int
|
||||
expectedBody bool
|
||||
expectedRule *server.Rule
|
||||
requestType string
|
||||
requestPath string
|
||||
requestBody io.Reader
|
||||
}{
|
||||
{
|
||||
name: "SaveRule POST OK",
|
||||
requestType: http.MethodPost,
|
||||
requestPath: "/api/rules",
|
||||
requestBody: bytes.NewBuffer(
|
||||
[]byte(`{"Name":"Default POSTed Rule","Flow":"bidirect"}`)),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
expectedRule: &server.Rule{
|
||||
ID: "id-was-set",
|
||||
Name: "Default POSTed Rule",
|
||||
Flow: server.TrafficFlowBidirect,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SaveRule PUT OK",
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/rules",
|
||||
requestBody: bytes.NewBuffer(
|
||||
[]byte(`{"ID":"id-existed","Name":"Default POSTed Rule","Flow":"bidirect"}`)),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedRule: &server.Rule{
|
||||
ID: "id-existed",
|
||||
Name: "Default POSTed Rule",
|
||||
Flow: server.TrafficFlowBidirect,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
p := initRulesTestData()
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/rules", p.CreateOrUpdateRuleHandler).Methods("PUT", "POST")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
content, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("I don't know what I expected; %v", err)
|
||||
}
|
||||
|
||||
if status := recorder.Code; status != tc.expectedStatus {
|
||||
t.Errorf("handler returned wrong status code: got %v want %v, content: %s",
|
||||
status, tc.expectedStatus, string(content))
|
||||
return
|
||||
}
|
||||
|
||||
if !tc.expectedBody {
|
||||
return
|
||||
}
|
||||
|
||||
got := &RuleRequest{}
|
||||
if err = json.Unmarshal(content, &got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
|
||||
if tc.requestType != http.MethodPost {
|
||||
assert.Equal(t, got.ID, tc.expectedRule.ID)
|
||||
}
|
||||
assert.Equal(t, got.Name, tc.expectedRule.Name)
|
||||
assert.Equal(t, got.Flow, "bidirect")
|
||||
})
|
||||
}
|
||||
}
|
@ -96,6 +96,7 @@ func (s *Server) Start() error {
|
||||
r.Use(jwtMiddleware.Handler, corsMiddleware.Handler)
|
||||
|
||||
groupsHandler := handler.NewGroups(s.accountManager, s.config.AuthAudience)
|
||||
rulesHandler := handler.NewRules(s.accountManager, s.config.AuthAudience)
|
||||
peersHandler := handler.NewPeers(s.accountManager, s.config.AuthAudience)
|
||||
keysHandler := handler.NewSetupKeysHandler(s.accountManager, s.config.AuthAudience)
|
||||
r.HandleFunc("/api/peers", peersHandler.GetPeers).Methods("GET", "OPTIONS")
|
||||
@ -112,6 +113,12 @@ func (s *Server) Start() error {
|
||||
r.HandleFunc("/api/setup-keys/{id}", keysHandler.HandleKey).
|
||||
Methods("GET", "PUT", "DELETE", "OPTIONS")
|
||||
|
||||
r.HandleFunc("/api/rules", rulesHandler.GetAllRulesHandler).Methods("GET", "OPTIONS")
|
||||
r.HandleFunc("/api/rules", rulesHandler.CreateOrUpdateRuleHandler).
|
||||
Methods("POST", "PUT", "OPTIONS")
|
||||
r.HandleFunc("/api/rules/{id}", rulesHandler.GetRuleHandler).Methods("GET", "OPTIONS")
|
||||
r.HandleFunc("/api/rules/{id}", rulesHandler.DeleteRuleHandler).Methods("DELETE", "OPTIONS")
|
||||
|
||||
r.HandleFunc("/api/groups", groupsHandler.GetAllGroupsHandler).Methods("GET", "OPTIONS")
|
||||
r.HandleFunc("/api/groups", groupsHandler.CreateOrUpdateGroupHandler).
|
||||
Methods("POST", "PUT", "OPTIONS")
|
||||
|
@ -3,6 +3,13 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
@ -11,12 +18,6 @@ import (
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -39,8 +40,7 @@ const (
|
||||
|
||||
// registerPeers registers peersNum peers on the management service and returns their Wireguard keys
|
||||
func registerPeers(peersNum int, client mgmtProto.ManagementServiceClient) ([]*wgtypes.Key, error) {
|
||||
|
||||
var peers = []*wgtypes.Key{}
|
||||
peers := []*wgtypes.Key{}
|
||||
for i := 0; i < peersNum; i++ {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
@ -60,7 +60,6 @@ func registerPeers(peersNum int, client mgmtProto.ManagementServiceClient) ([]*w
|
||||
|
||||
// getServerKey gets Management Service Wireguard public key
|
||||
func getServerKey(client mgmtProto.ManagementServiceClient) (*wgtypes.Key, error) {
|
||||
|
||||
keyResp, err := client.GetServerKey(context.TODO(), &mgmtProto.Empty{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -75,7 +74,6 @@ func getServerKey(client mgmtProto.ManagementServiceClient) (*wgtypes.Key, error
|
||||
}
|
||||
|
||||
func Test_SyncProtocol(t *testing.T) {
|
||||
|
||||
dir := t.TempDir()
|
||||
err := util.CopyFileContents("testdata/store.json", filepath.Join(dir, "store.json"))
|
||||
if err != nil {
|
||||
@ -263,7 +261,6 @@ func Test_SyncProtocol(t *testing.T) {
|
||||
}
|
||||
|
||||
func loginPeerWithValidSetupKey(key wgtypes.Key, client mgmtProto.ManagementServiceClient) (*mgmtProto.LoginResponse, error) {
|
||||
|
||||
serverKey, err := getServerKey(client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -298,11 +295,9 @@ func loginPeerWithValidSetupKey(key wgtypes.Key, client mgmtProto.ManagementServ
|
||||
}
|
||||
|
||||
return loginResp, nil
|
||||
|
||||
}
|
||||
|
||||
func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
|
||||
|
||||
testingServerKey, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Errorf("unable to generate server wg key for testing GetDeviceAuthorizationFlow, error: %v", err)
|
||||
@ -362,7 +357,6 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
mgmtServer := &Server{
|
||||
wgKey: testingServerKey,
|
||||
config: &Config{
|
||||
@ -397,7 +391,6 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
|
||||
}
|
||||
|
||||
func startManagement(t *testing.T, port int, config *Config) (*grpc.Server, error) {
|
||||
|
||||
lis, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -408,7 +401,10 @@ func startManagement(t *testing.T, port int, config *Config) (*grpc.Server, erro
|
||||
return nil, err
|
||||
}
|
||||
peersUpdateManager := NewPeersUpdateManager()
|
||||
accountManager := NewManager(store, peersUpdateManager, nil)
|
||||
accountManager, err := BuildManager(store, peersUpdateManager, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
turnManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||
mgmtServer, err := NewServer(config, accountManager, peersUpdateManager, turnManager)
|
||||
if err != nil {
|
||||
|
@ -2,8 +2,6 @@ package server_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
server "github.com/netbirdio/netbird/management/server"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"net"
|
||||
@ -13,6 +11,9 @@ import (
|
||||
sync2 "sync"
|
||||
"time"
|
||||
|
||||
server "github.com/netbirdio/netbird/management/server"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
pb "github.com/golang/protobuf/proto" //nolint
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@ -31,7 +32,6 @@ const (
|
||||
)
|
||||
|
||||
var _ = Describe("Management service", func() {
|
||||
|
||||
var (
|
||||
addr string
|
||||
s *grpc.Server
|
||||
@ -66,7 +66,6 @@ var _ = Describe("Management service", func() {
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
serverPubKey, err = wgtypes.ParseKey(resp.Key)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
@ -78,7 +77,6 @@ var _ = Describe("Management service", func() {
|
||||
|
||||
Context("when calling IsHealthy endpoint", func() {
|
||||
Specify("a non-error result is returned", func() {
|
||||
|
||||
healthy, err := client.IsHealthy(context.TODO(), &mgmtProto.Empty{})
|
||||
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
@ -87,7 +85,6 @@ var _ = Describe("Management service", func() {
|
||||
})
|
||||
|
||||
Context("when calling Sync endpoint", func() {
|
||||
|
||||
Context("when there is a new peer registered", func() {
|
||||
Specify("a proper configuration is returned", func() {
|
||||
key, _ := wgtypes.GenerateKey()
|
||||
@ -168,7 +165,6 @@ var _ = Describe("Management service", func() {
|
||||
Expect(resp.GetRemotePeers()).To(HaveLen(2))
|
||||
peers := []string{resp.GetRemotePeers()[0].WgPubKey, resp.GetRemotePeers()[1].WgPubKey}
|
||||
Expect(peers).To(ContainElements(key1.PublicKey().String(), key2.PublicKey().String()))
|
||||
|
||||
})
|
||||
})
|
||||
|
||||
@ -211,7 +207,6 @@ var _ = Describe("Management service", func() {
|
||||
resp = &mgmtProto.SyncResponse{}
|
||||
err = pb.Unmarshal(decryptedBytes, resp)
|
||||
wg.Done()
|
||||
|
||||
}()
|
||||
|
||||
// register a new peer
|
||||
@ -229,7 +224,6 @@ var _ = Describe("Management service", func() {
|
||||
|
||||
Context("when calling GetServerKey endpoint", func() {
|
||||
Specify("a public Wireguard key of the service is returned", func() {
|
||||
|
||||
resp, err := client.GetServerKey(context.TODO(), &mgmtProto.Empty{})
|
||||
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
@ -237,19 +231,16 @@ var _ = Describe("Management service", func() {
|
||||
Expect(resp.Key).ToNot(BeNil())
|
||||
Expect(resp.ExpiresAt).ToNot(BeNil())
|
||||
|
||||
//check if the key is a valid Wireguard key
|
||||
// check if the key is a valid Wireguard key
|
||||
key, err := wgtypes.ParseKey(resp.Key)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(key).ToNot(BeNil())
|
||||
|
||||
})
|
||||
})
|
||||
|
||||
Context("when calling Login endpoint", func() {
|
||||
|
||||
Context("with an invalid setup key", func() {
|
||||
Specify("an error is returned", func() {
|
||||
|
||||
key, _ := wgtypes.GenerateKey()
|
||||
message, err := encryption.EncryptMessage(serverPubKey, key, &mgmtProto.LoginRequest{SetupKey: "invalid setup key"})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
@ -261,24 +252,20 @@ var _ = Describe("Management service", func() {
|
||||
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(resp).To(BeNil())
|
||||
|
||||
})
|
||||
})
|
||||
|
||||
Context("with a valid setup key", func() {
|
||||
It("a non error result is returned", func() {
|
||||
|
||||
key, _ := wgtypes.GenerateKey()
|
||||
resp := loginPeerWithValidSetupKey(serverPubKey, key, client)
|
||||
|
||||
Expect(resp).ToNot(BeNil())
|
||||
|
||||
})
|
||||
})
|
||||
|
||||
Context("with a registered peer", func() {
|
||||
It("a non error result is returned", func() {
|
||||
|
||||
key, _ := wgtypes.GenerateKey()
|
||||
regResp := loginPeerWithValidSetupKey(serverPubKey, key, client)
|
||||
Expect(regResp).NotTo(BeNil())
|
||||
@ -324,7 +311,6 @@ var _ = Describe("Management service", func() {
|
||||
Context("when there are 50 peers registered under one account", func() {
|
||||
Context("when there are 10 more peers registered under the same account", func() {
|
||||
Specify("all of the 50 peers will get updates of 10 newly registered peers", func() {
|
||||
|
||||
initialPeers := 20
|
||||
additionalPeers := 10
|
||||
|
||||
@ -369,7 +355,7 @@ var _ = Describe("Management service", func() {
|
||||
err = pb.Unmarshal(decryptedBytes, resp)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
if len(resp.GetRemotePeers()) > 0 {
|
||||
//only consider peer updates
|
||||
// only consider peer updates
|
||||
wg.Done()
|
||||
}
|
||||
}
|
||||
@ -397,7 +383,6 @@ var _ = Describe("Management service", func() {
|
||||
|
||||
Context("when there are peers registered under one account concurrently", func() {
|
||||
Specify("then there are no duplicate IPs", func() {
|
||||
|
||||
initialPeers := 30
|
||||
|
||||
ipChannel := make(chan string, 20)
|
||||
@ -423,7 +408,6 @@ var _ = Describe("Management service", func() {
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
ipChannel <- resp.GetPeerConfig().Address
|
||||
|
||||
}()
|
||||
}
|
||||
|
||||
@ -443,6 +427,7 @@ var _ = Describe("Management service", func() {
|
||||
})
|
||||
|
||||
func loginPeerWithValidSetupKey(serverPubKey wgtypes.Key, key wgtypes.Key, client mgmtProto.ManagementServiceClient) *mgmtProto.LoginResponse {
|
||||
defer GinkgoRecover()
|
||||
|
||||
meta := &mgmtProto.PeerSystemMeta{
|
||||
Hostname: key.PublicKey().String(),
|
||||
@ -467,7 +452,6 @@ func loginPeerWithValidSetupKey(serverPubKey wgtypes.Key, key wgtypes.Key, clien
|
||||
err = encryption.DecryptMessage(serverPubKey, key, resp.Body, loginResp)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
return loginResp
|
||||
|
||||
}
|
||||
|
||||
func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.ClientConn) {
|
||||
@ -496,7 +480,10 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) {
|
||||
log.Fatalf("failed creating a store: %s: %v", config.Datadir, err)
|
||||
}
|
||||
peersUpdateManager := server.NewPeersUpdateManager()
|
||||
accountManager := server.NewManager(store, peersUpdateManager, nil)
|
||||
accountManager, err := server.BuildManager(store, peersUpdateManager, nil)
|
||||
if err != nil {
|
||||
log.Fatalf("failed creating a manager: %v", err)
|
||||
}
|
||||
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
@ -33,6 +33,10 @@ type MockAccountManager struct {
|
||||
GroupAddPeerFunc func(accountID, groupID, peerKey string) error
|
||||
GroupDeletePeerFunc func(accountID, groupID, peerKey string) error
|
||||
GroupListPeersFunc func(accountID, groupID string) ([]*server.Peer, error)
|
||||
GetRuleFunc func(accountID, ruleID string) (*server.Rule, error)
|
||||
SaveRuleFunc func(accountID string, rule *server.Rule) error
|
||||
DeleteRuleFunc func(accountID, ruleID string) error
|
||||
ListRulesFunc func(accountID string) ([]*server.Rule, error)
|
||||
GetUsersFromAccountFunc func(accountID string) ([]*server.UserInfo, error)
|
||||
}
|
||||
|
||||
@ -41,7 +45,6 @@ func (am *MockAccountManager) GetUsersFromAccount(accountID string) ([]*server.U
|
||||
return am.GetUsersFromAccountFunc(accountID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetUsersFromAccount not implemented")
|
||||
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) GetOrCreateAccountByUser(
|
||||
@ -207,7 +210,7 @@ func (am *MockAccountManager) SaveGroup(accountID string, group *server.Group) e
|
||||
if am.SaveGroupFunc != nil {
|
||||
return am.SaveGroupFunc(accountID, group)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method UpdateGroup not implemented")
|
||||
return status.Errorf(codes.Unimplemented, "method SaveGroup not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) DeleteGroup(accountID, groupID string) error {
|
||||
@ -244,3 +247,31 @@ func (am *MockAccountManager) GroupListPeers(accountID, groupID string) ([]*serv
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GroupListPeers not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) GetRule(accountID, ruleID string) (*server.Rule, error) {
|
||||
if am.GetRuleFunc != nil {
|
||||
return am.GetRuleFunc(accountID, ruleID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetRule not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) SaveRule(accountID string, rule *server.Rule) error {
|
||||
if am.SaveRuleFunc != nil {
|
||||
return am.SaveRuleFunc(accountID, rule)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method SaveRule not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) DeleteRule(accountID, ruleID string) error {
|
||||
if am.DeleteRuleFunc != nil {
|
||||
return am.DeleteRuleFunc(accountID, ruleID)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method DeleteRule not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) ListRules(accountID string) ([]*server.Rule, error) {
|
||||
if am.ListRulesFunc != nil {
|
||||
return am.ListRulesFunc(accountID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method ListRules not implemented")
|
||||
}
|
||||
|
@ -1,12 +1,13 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// PeerSystemMeta is a metadata of a Peer machine system
|
||||
@ -21,31 +22,31 @@ type PeerSystemMeta struct {
|
||||
}
|
||||
|
||||
type PeerStatus struct {
|
||||
//LastSeen is the last time peer was connected to the management service
|
||||
// LastSeen is the last time peer was connected to the management service
|
||||
LastSeen time.Time
|
||||
//Connected indicates whether peer is connected to the management service or not
|
||||
// Connected indicates whether peer is connected to the management service or not
|
||||
Connected bool
|
||||
}
|
||||
|
||||
//Peer represents a machine connected to the network.
|
||||
//The Peer is a Wireguard peer identified by a public key
|
||||
// Peer represents a machine connected to the network.
|
||||
// The Peer is a Wireguard peer identified by a public key
|
||||
type Peer struct {
|
||||
//Wireguard public key
|
||||
// Wireguard public key
|
||||
Key string
|
||||
//A setup key this peer was registered with
|
||||
// A setup key this peer was registered with
|
||||
SetupKey string
|
||||
//IP address of the Peer
|
||||
// IP address of the Peer
|
||||
IP net.IP
|
||||
//Meta is a Peer system meta data
|
||||
// Meta is a Peer system meta data
|
||||
Meta PeerSystemMeta
|
||||
//Name is peer's name (machine name)
|
||||
// Name is peer's name (machine name)
|
||||
Name string
|
||||
Status *PeerStatus
|
||||
//The user ID that registered the peer
|
||||
// The user ID that registered the peer
|
||||
UserID string
|
||||
}
|
||||
|
||||
//Copy copies Peer object
|
||||
// Copy copies Peer object
|
||||
func (p *Peer) Copy() *Peer {
|
||||
return &Peer{
|
||||
Key: p.Key,
|
||||
@ -58,7 +59,7 @@ func (p *Peer) Copy() *Peer {
|
||||
}
|
||||
}
|
||||
|
||||
//GetPeer returns a peer from a Store
|
||||
// GetPeer returns a peer from a Store
|
||||
func (am *DefaultAccountManager) GetPeer(peerKey string) (*Peer, error) {
|
||||
am.mux.Lock()
|
||||
defer am.mux.Unlock()
|
||||
@ -71,7 +72,7 @@ func (am *DefaultAccountManager) GetPeer(peerKey string) (*Peer, error) {
|
||||
return peer, nil
|
||||
}
|
||||
|
||||
//MarkPeerConnected marks peer as connected (true) or disconnected (false)
|
||||
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
|
||||
func (am *DefaultAccountManager) MarkPeerConnected(peerKey string, connected bool) error {
|
||||
am.mux.Lock()
|
||||
defer am.mux.Unlock()
|
||||
@ -96,8 +97,12 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerKey string, connected boo
|
||||
return nil
|
||||
}
|
||||
|
||||
//RenamePeer changes peer's name
|
||||
func (am *DefaultAccountManager) RenamePeer(accountId string, peerKey string, newName string) (*Peer, error) {
|
||||
// RenamePeer changes peer's name
|
||||
func (am *DefaultAccountManager) RenamePeer(
|
||||
accountId string,
|
||||
peerKey string,
|
||||
newName string,
|
||||
) (*Peer, error) {
|
||||
am.mux.Lock()
|
||||
defer am.mux.Unlock()
|
||||
|
||||
@ -116,7 +121,7 @@ func (am *DefaultAccountManager) RenamePeer(accountId string, peerKey string, ne
|
||||
return peerCopy, nil
|
||||
}
|
||||
|
||||
//DeletePeer removes peer from the account by it's IP
|
||||
// DeletePeer removes peer from the account by it's IP
|
||||
func (am *DefaultAccountManager) DeletePeer(accountId string, peerKey string) (*Peer, error) {
|
||||
am.mux.Lock()
|
||||
defer am.mux.Unlock()
|
||||
@ -149,12 +154,13 @@ func (am *DefaultAccountManager) DeletePeer(accountId string, peerKey string) (*
|
||||
RemotePeers: []*proto.RemotePeerConfig{},
|
||||
RemotePeersIsEmpty: true,
|
||||
},
|
||||
}})
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
//notify other peers of the change
|
||||
// notify other peers of the change
|
||||
peers, err := am.Store.GetAccountPeers(accountId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -180,7 +186,8 @@ func (am *DefaultAccountManager) DeletePeer(accountId string, peerKey string) (*
|
||||
RemotePeers: update,
|
||||
RemotePeersIsEmpty: len(update) == 0,
|
||||
},
|
||||
}})
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -190,7 +197,7 @@ func (am *DefaultAccountManager) DeletePeer(accountId string, peerKey string) (*
|
||||
return peer, nil
|
||||
}
|
||||
|
||||
//GetPeerByIP returns peer by it's IP
|
||||
// GetPeerByIP returns peer by it's IP
|
||||
func (am *DefaultAccountManager) GetPeerByIP(accountId string, peerIP string) (*Peer, error) {
|
||||
am.mux.Lock()
|
||||
defer am.mux.Unlock()
|
||||
@ -220,10 +227,46 @@ func (am *DefaultAccountManager) GetNetworkMap(peerKey string) (*NetworkMap, err
|
||||
}
|
||||
|
||||
var res []*Peer
|
||||
for _, peer := range account.Peers {
|
||||
// exclude original peer
|
||||
if peer.Key != peerKey {
|
||||
res = append(res, peer.Copy())
|
||||
srcRules, err := am.Store.GetPeerSrcRules(account.Id, peerKey)
|
||||
if err != nil {
|
||||
return &NetworkMap{
|
||||
Peers: res,
|
||||
Network: account.Network.Copy(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
dstRules, err := am.Store.GetPeerDstRules(account.Id, peerKey)
|
||||
if err != nil {
|
||||
return &NetworkMap{
|
||||
Peers: res,
|
||||
Network: account.Network.Copy(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
groups := map[string]*Group{}
|
||||
for _, r := range srcRules {
|
||||
if r.Flow == TrafficFlowBidirect {
|
||||
for _, gid := range r.Destination {
|
||||
groups[gid] = account.Groups[gid]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, r := range dstRules {
|
||||
if r.Flow == TrafficFlowBidirect {
|
||||
for _, gid := range r.Source {
|
||||
groups[gid] = account.Groups[gid]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, g := range groups {
|
||||
for _, pid := range g.Peers {
|
||||
peer := account.Peers[pid]
|
||||
// exclude original peer
|
||||
if peer.Key != peerKey {
|
||||
res = append(res, peer.Copy())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -240,7 +283,11 @@ func (am *DefaultAccountManager) GetNetworkMap(peerKey string) (*NetworkMap, err
|
||||
// to it. We also add the User ID to the peer metadata to identify registrant.
|
||||
// Each new Peer will be assigned a new next net.IP from the Account.Network and Account.Network.LastIP will be updated (IP's are not reused).
|
||||
// The peer property is just a placeholder for the Peer properties to pass further
|
||||
func (am *DefaultAccountManager) AddPeer(setupKey string, userID string, peer *Peer) (*Peer, error) {
|
||||
func (am *DefaultAccountManager) AddPeer(
|
||||
setupKey string,
|
||||
userID string,
|
||||
peer *Peer,
|
||||
) (*Peer, error) {
|
||||
am.mux.Lock()
|
||||
defer am.mux.Unlock()
|
||||
|
||||
@ -252,17 +299,28 @@ func (am *DefaultAccountManager) AddPeer(setupKey string, userID string, peer *P
|
||||
if len(upperKey) != 0 {
|
||||
account, err = am.Store.GetAccountBySetupKey(upperKey)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "unable to register peer, unable to find account with setupKey %s", upperKey)
|
||||
return nil, status.Errorf(
|
||||
codes.NotFound,
|
||||
"unable to register peer, unable to find account with setupKey %s",
|
||||
upperKey,
|
||||
)
|
||||
}
|
||||
|
||||
sk = getAccountSetupKeyByKey(account, upperKey)
|
||||
if sk == nil {
|
||||
// shouldn't happen actually
|
||||
return nil, status.Errorf(codes.NotFound, "unable to register peer, unknown setupKey %s", upperKey)
|
||||
return nil, status.Errorf(
|
||||
codes.NotFound,
|
||||
"unable to register peer, unknown setupKey %s",
|
||||
upperKey,
|
||||
)
|
||||
}
|
||||
|
||||
if !sk.IsValid() {
|
||||
return nil, status.Errorf(codes.FailedPrecondition, "unable to register peer, its setup key is invalid (expired, overused or revoked)")
|
||||
return nil, status.Errorf(
|
||||
codes.FailedPrecondition,
|
||||
"unable to register peer, its setup key is invalid (expired, overused or revoked)",
|
||||
)
|
||||
}
|
||||
|
||||
} else if len(userID) != 0 {
|
||||
@ -293,6 +351,13 @@ func (am *DefaultAccountManager) AddPeer(setupKey string, userID string, peer *P
|
||||
Status: &PeerStatus{Connected: false, LastSeen: time.Now()},
|
||||
}
|
||||
|
||||
// add peer to 'All' group
|
||||
group, err := account.GetGroupAll()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
group.Peers = append(group.Peers, newPeer.Key)
|
||||
|
||||
account.Peers[newPeer.Key] = newPeer
|
||||
if len(upperKey) != 0 {
|
||||
account.SetupKeys[sk.Key] = sk.IncrementUsage()
|
||||
@ -305,5 +370,4 @@ func (am *DefaultAccountManager) AddPeer(setupKey string, userID string, peer *P
|
||||
}
|
||||
|
||||
return newPeer, nil
|
||||
|
||||
}
|
||||
|
@ -1,8 +1,10 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"testing"
|
||||
|
||||
"github.com/rs/xid"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
func TestAccountManager_GetNetworkMap(t *testing.T) {
|
||||
@ -70,7 +72,151 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
|
||||
}
|
||||
|
||||
if networkMap.Peers[0].Key != peerKey2.PublicKey().String() {
|
||||
t.Errorf("expecting Account NetworkMap to have peer with a key %s, got %s", peerKey2.PublicKey().String(), networkMap.Peers[0].Key)
|
||||
t.Errorf(
|
||||
"expecting Account NetworkMap to have peer with a key %s, got %s",
|
||||
peerKey2.PublicKey().String(),
|
||||
networkMap.Peers[0].Key,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountManager_GetNetworkMapWithRule(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
expectedId := "test_account"
|
||||
userId := "account_creator"
|
||||
account, err := manager.AddAccount(expectedId, userId, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var setupKey *SetupKey
|
||||
for _, key := range account.SetupKeys {
|
||||
if key.Type == SetupKeyReusable {
|
||||
setupKey = key
|
||||
}
|
||||
}
|
||||
|
||||
peerKey1, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = manager.AddPeer(setupKey.Key, "", &Peer{
|
||||
Key: peerKey1.PublicKey().String(),
|
||||
Meta: PeerSystemMeta{},
|
||||
Name: "test-peer-2",
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("expecting peer to be added, got failure %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
peerKey2, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
_, err = manager.AddPeer(setupKey.Key, "", &Peer{
|
||||
Key: peerKey2.PublicKey().String(),
|
||||
Meta: PeerSystemMeta{},
|
||||
Name: "test-peer-2",
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("expecting peer to be added, got failure %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
rules, err := manager.ListRules(account.Id)
|
||||
if err != nil {
|
||||
t.Errorf("expecting to get a list of rules, got failure %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = manager.DeleteRule(account.Id, rules[0].ID)
|
||||
if err != nil {
|
||||
t.Errorf("expecting to delete 1 group, got failure %v", err)
|
||||
return
|
||||
}
|
||||
var (
|
||||
group1 Group
|
||||
group2 Group
|
||||
rule Rule
|
||||
)
|
||||
|
||||
group1.ID = xid.New().String()
|
||||
group2.ID = xid.New().String()
|
||||
group1.Name = "src"
|
||||
group2.Name = "dst"
|
||||
rule.ID = xid.New().String()
|
||||
group1.Peers = append(group1.Peers, peerKey1.PublicKey().String())
|
||||
group2.Peers = append(group2.Peers, peerKey2.PublicKey().String())
|
||||
|
||||
err = manager.SaveGroup(account.Id, &group1)
|
||||
if err != nil {
|
||||
t.Errorf("expecting group1 to be added, got failure %v", err)
|
||||
return
|
||||
}
|
||||
err = manager.SaveGroup(account.Id, &group2)
|
||||
if err != nil {
|
||||
t.Errorf("expecting group2 to be added, got failure %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
rule.Name = "test"
|
||||
rule.Source = append(rule.Source, group1.ID)
|
||||
rule.Destination = append(rule.Destination, group2.ID)
|
||||
rule.Flow = TrafficFlowBidirect
|
||||
err = manager.SaveRule(account.Id, &rule)
|
||||
if err != nil {
|
||||
t.Errorf("expecting rule to be added, got failure %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
networkMap1, err := manager.GetNetworkMap(peerKey1.PublicKey().String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(networkMap1.Peers) != 1 {
|
||||
t.Errorf(
|
||||
"expecting Account NetworkMap to have 1 peers, got %v: %v",
|
||||
len(networkMap1.Peers),
|
||||
networkMap1.Peers,
|
||||
)
|
||||
}
|
||||
|
||||
if networkMap1.Peers[0].Key != peerKey2.PublicKey().String() {
|
||||
t.Errorf(
|
||||
"expecting Account NetworkMap to have peer with a key %s, got %s",
|
||||
peerKey2.PublicKey().String(),
|
||||
networkMap1.Peers[0].Key,
|
||||
)
|
||||
}
|
||||
|
||||
networkMap2, err := manager.GetNetworkMap(peerKey2.PublicKey().String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(networkMap2.Peers) != 1 {
|
||||
t.Errorf("expecting Account NetworkMap to have 1 peers, got %v", len(networkMap2.Peers))
|
||||
}
|
||||
|
||||
if len(networkMap2.Peers) > 0 && networkMap2.Peers[0].Key != peerKey1.PublicKey().String() {
|
||||
t.Errorf(
|
||||
"expecting Account NetworkMap to have peer with a key %s, got %s",
|
||||
peerKey1.PublicKey().String(),
|
||||
networkMap2.Peers[0].Key,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
107
management/server/rule.go
Normal file
107
management/server/rule.go
Normal file
@ -0,0 +1,107 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// TrafficFlowType defines allowed direction of the traffic in the rule
|
||||
type TrafficFlowType int
|
||||
|
||||
const (
|
||||
// TrafficFlowBidirect allows traffic to both direction
|
||||
TrafficFlowBidirect TrafficFlowType = iota
|
||||
)
|
||||
|
||||
// Rule of ACL for groups
|
||||
type Rule struct {
|
||||
// ID of the rule
|
||||
ID string
|
||||
|
||||
// Name of the rule visible in the UI
|
||||
Name string
|
||||
|
||||
// Source list of groups IDs of peers
|
||||
Source []string
|
||||
|
||||
// Destination list of groups IDs of peers
|
||||
Destination []string
|
||||
|
||||
// Flow of the traffic allowed by the rule
|
||||
Flow TrafficFlowType
|
||||
}
|
||||
|
||||
func (r *Rule) Copy() *Rule {
|
||||
return &Rule{
|
||||
ID: r.ID,
|
||||
Name: r.Name,
|
||||
Source: r.Source[:],
|
||||
Destination: r.Destination[:],
|
||||
Flow: r.Flow,
|
||||
}
|
||||
}
|
||||
|
||||
// GetRule of ACL from the store
|
||||
func (am *DefaultAccountManager) GetRule(accountID, ruleID string) (*Rule, error) {
|
||||
am.mux.Lock()
|
||||
defer am.mux.Unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||
}
|
||||
|
||||
rule, ok := account.Rules[ruleID]
|
||||
if ok {
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
return nil, status.Errorf(codes.NotFound, "rule with ID %s not found", ruleID)
|
||||
}
|
||||
|
||||
// SaveRule of ACL in the store
|
||||
func (am *DefaultAccountManager) SaveRule(accountID string, rule *Rule) error {
|
||||
am.mux.Lock()
|
||||
defer am.mux.Unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.NotFound, "account not found")
|
||||
}
|
||||
|
||||
account.Rules[rule.ID] = rule
|
||||
return am.Store.SaveAccount(account)
|
||||
}
|
||||
|
||||
// DeleteRule of ACL from the store
|
||||
func (am *DefaultAccountManager) DeleteRule(accountID, ruleID string) error {
|
||||
am.mux.Lock()
|
||||
defer am.mux.Unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.NotFound, "account not found")
|
||||
}
|
||||
|
||||
delete(account.Rules, ruleID)
|
||||
|
||||
return am.Store.SaveAccount(account)
|
||||
}
|
||||
|
||||
// ListRules of ACL from the store
|
||||
func (am *DefaultAccountManager) ListRules(accountID string) ([]*Rule, error) {
|
||||
am.mux.Lock()
|
||||
defer am.mux.Unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||
}
|
||||
|
||||
rules := make([]*Rule, 0, len(account.Rules))
|
||||
for _, item := range account.Rules {
|
||||
rules = append(rules, item)
|
||||
}
|
||||
|
||||
return rules, nil
|
||||
}
|
@ -4,10 +4,13 @@ type Store interface {
|
||||
GetPeer(peerKey string) (*Peer, error)
|
||||
DeletePeer(accountId string, peerKey string) (*Peer, error)
|
||||
SavePeer(accountId string, peer *Peer) error
|
||||
GetAllAccounts() []*Account
|
||||
GetAccount(accountId string) (*Account, error)
|
||||
GetUserAccount(userId string) (*Account, error)
|
||||
GetAccountPeers(accountId string) ([]*Peer, error)
|
||||
GetPeerAccount(peerKey string) (*Account, error)
|
||||
GetPeerSrcRules(accountId, peerKey string) ([]*Rule, error)
|
||||
GetPeerDstRules(accountId, peerKey string) ([]*Rule, error)
|
||||
GetAccountBySetupKey(setupKey string) (*Account, error)
|
||||
GetAccountByPrivateDomain(domain string) (*Account, error)
|
||||
SaveAccount(account *Account) error
|
||||
|
@ -58,6 +58,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userId, domain string)
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
account = NewAccount(userId, lowerDomain)
|
||||
account.Users[userId] = NewAdminUser(userId)
|
||||
am.addAllGroup(account)
|
||||
err = am.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed creating account")
|
||||
|
Loading…
Reference in New Issue
Block a user