Add rules for ACL (#306)

Add rules HTTP endpoint for frontend - CRUD operations.
Add Default rule - allow all.
Send network map to peers based on rules.
This commit is contained in:
Givi Khojanashvili 2022-05-21 17:21:39 +04:00 committed by GitHub
parent 11a3863c28
commit 3ce3ccc39a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 1197 additions and 190 deletions

View File

@ -68,7 +68,10 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
}
peersUpdateManager := mgmt.NewPeersUpdateManager()
accountManager := mgmt.NewManager(store, peersUpdateManager, nil)
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil)
if err != nil {
t.Fatal(err)
}
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager)
if err != nil {

View File

@ -455,7 +455,10 @@ func startManagement(port int, dataDir string) (*grpc.Server, error) {
log.Fatalf("failed creating a store: %s: %v", config.Datadir, err)
}
peersUpdateManager := server.NewPeersUpdateManager()
accountManager := server.NewManager(store, peersUpdateManager, nil)
accountManager, err := server.BuildManager(store, peersUpdateManager, nil)
if err != nil {
return nil, err
}
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager)
if err != nil {

View File

@ -28,7 +28,6 @@ import (
const ValidKey = "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
level, _ := log.ParseLevel("debug")
log.SetLevel(level)
@ -56,7 +55,10 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
}
peersUpdateManager := mgmt.NewPeersUpdateManager()
accountManager := mgmt.NewManager(store, peersUpdateManager, nil)
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil)
if err != nil {
t.Fatal(err)
}
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager)
if err != nil {
@ -256,6 +258,7 @@ func TestClient_Sync(t *testing.T) {
}
if len(resp.GetRemotePeers()) != 1 {
t.Errorf("expecting RemotePeers size %d got %d", 1, len(resp.GetRemotePeers()))
return
}
if resp.GetRemotePeersIsEmpty() == true {
t.Error("expecting RemotePeers property to be false, got true")
@ -295,37 +298,36 @@ func Test_SystemMetaDataFromClient(t *testing.T) {
var wg sync.WaitGroup
wg.Add(1)
mgmtMockServer.LoginFunc =
func(ctx context.Context, msg *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
peerKey, err := wgtypes.ParseKey(msg.GetWgPubKey())
if err != nil {
log.Warnf("error while parsing peer's Wireguard public key %s on Sync request.", msg.WgPubKey)
return nil, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", msg.WgPubKey)
}
loginReq := &proto.LoginRequest{}
err = encryption.DecryptMessage(peerKey, serverKey, msg.Body, loginReq)
if err != nil {
log.Fatal(err)
}
actualMeta = loginReq.GetMeta()
actualValidKey = loginReq.GetSetupKey()
wg.Done()
loginResp := &proto.LoginResponse{}
encryptedResp, err := encryption.EncryptMessage(peerKey, serverKey, loginResp)
if err != nil {
return nil, err
}
return &mgmtProto.EncryptedMessage{
WgPubKey: serverKey.PublicKey().String(),
Body: encryptedResp,
Version: 0,
}, nil
mgmtMockServer.LoginFunc = func(ctx context.Context, msg *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
peerKey, err := wgtypes.ParseKey(msg.GetWgPubKey())
if err != nil {
log.Warnf("error while parsing peer's Wireguard public key %s on Sync request.", msg.WgPubKey)
return nil, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", msg.WgPubKey)
}
loginReq := &proto.LoginRequest{}
err = encryption.DecryptMessage(peerKey, serverKey, msg.Body, loginReq)
if err != nil {
log.Fatal(err)
}
actualMeta = loginReq.GetMeta()
actualValidKey = loginReq.GetSetupKey()
wg.Done()
loginResp := &proto.LoginResponse{}
encryptedResp, err := encryption.EncryptMessage(peerKey, serverKey, loginResp)
if err != nil {
return nil, err
}
return &mgmtProto.EncryptedMessage{
WgPubKey: serverKey.PublicKey().String(),
Body: encryptedResp,
Version: 0,
}, nil
}
info := system.GetInfo()
_, err = testClient.Register(*key, ValidKey, "", info)
if err != nil {
@ -370,21 +372,19 @@ func Test_GetDeviceAuthorizationFlow(t *testing.T) {
ProviderConfig: &proto.ProviderConfig{ClientID: "client"},
}
mgmtMockServer.GetDeviceAuthorizationFlowFunc =
func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*proto.EncryptedMessage, error) {
encryptedResp, err := encryption.EncryptMessage(serverKey, client.key, expectedFlowInfo)
if err != nil {
return nil, err
}
return &mgmtProto.EncryptedMessage{
WgPubKey: serverKey.PublicKey().String(),
Body: encryptedResp,
Version: 0,
}, nil
mgmtMockServer.GetDeviceAuthorizationFlowFunc = func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*proto.EncryptedMessage, error) {
encryptedResp, err := encryption.EncryptMessage(serverKey, client.key, expectedFlowInfo)
if err != nil {
return nil, err
}
return &mgmtProto.EncryptedMessage{
WgPubKey: serverKey.PublicKey().String(),
Body: encryptedResp,
Version: 0,
}, nil
}
flowInfo, err := client.GetDeviceAuthorizationFlow(serverKey)
if err != nil {
t.Error("error while retrieving device auth flow information")

View File

@ -108,20 +108,23 @@ var (
}
}
accountManager := server.NewManager(store, peersUpdateManager, idpManager)
accountManager, err := server.BuildManager(store, peersUpdateManager, idpManager)
if err != nil {
log.Fatalln("failed build default manager: ", err)
}
var opts []grpc.ServerOption
var httpServer *http.Server
if config.HttpConfig.LetsEncryptDomain != "" {
//automatically generate a new certificate with Let's Encrypt
// automatically generate a new certificate with Let's Encrypt
certManager := encryption.CreateCertManager(config.Datadir, config.HttpConfig.LetsEncryptDomain)
transportCredentials := credentials.NewTLS(certManager.TLSConfig())
opts = append(opts, grpc.Creds(transportCredentials))
httpServer = http.NewHttpsServer(config.HttpConfig, certManager, accountManager)
} else if config.HttpConfig.CertFile != "" && config.HttpConfig.CertKey != "" {
//use provided certificate
// use provided certificate
tlsConfig, err := loadTLSConfig(config.HttpConfig.CertFile, config.HttpConfig.CertKey)
if err != nil {
log.Fatal("cannot load TLS credentials: ", err)
@ -130,7 +133,7 @@ var (
opts = append(opts, grpc.Creds(transportCredentials))
httpServer = http.NewHttpsServerWithTLSConfig(config.HttpConfig, tlsConfig, accountManager)
} else {
//start server without SSL
// start server without SSL
httpServer = http.NewHttpServer(config.HttpConfig, accountManager)
}
@ -309,5 +312,4 @@ func init() {
mgmtCmd.Flags().StringVar(&certFile, "cert-file", "", "Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
mgmtCmd.Flags().StringVar(&certKey, "cert-key", "", "Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
rootCmd.MarkFlagRequired("config") //nolint
}

View File

@ -24,7 +24,12 @@ const (
type AccountManager interface {
GetOrCreateAccountByUser(userId, domain string) (*Account, error)
GetAccountByUser(userId string) (*Account, error)
AddSetupKey(accountId string, keyName string, keyType SetupKeyType, expiresIn *util.Duration) (*SetupKey, error)
AddSetupKey(
accountId string,
keyName string,
keyType SetupKeyType,
expiresIn *util.Duration,
) (*SetupKey, error)
RevokeSetupKey(accountId string, keyId string) (*SetupKey, error)
RenameSetupKey(accountId string, keyId string, newName string) (*SetupKey, error)
GetAccountById(accountId string) (*Account, error)
@ -47,6 +52,10 @@ type AccountManager interface {
GroupAddPeer(accountId, groupID, peerKey string) error
GroupDeletePeer(accountId, groupID, peerKey string) error
GroupListPeers(accountId, groupID string) ([]*Peer, error)
GetRule(accountId, ruleID string) (*Rule, error)
SaveRule(accountID string, rule *Rule) error
DeleteRule(accountId, ruleID string) error
ListRules(accountId string) ([]*Rule, error)
}
type DefaultAccountManager struct {
@ -70,6 +79,7 @@ type Account struct {
Peers map[string]*Peer
Users map[string]*User
Groups map[string]*Group
Rules map[string]*Rule
}
type UserInfo struct {
@ -101,6 +111,16 @@ func (a *Account) Copy() *Account {
setupKeys[id] = key.Copy()
}
groups := map[string]*Group{}
for id, group := range a.Groups {
groups[id] = group.Copy()
}
rules := map[string]*Rule{}
for id, rule := range a.Rules {
rules[id] = rule.Copy()
}
return &Account{
Id: a.Id,
CreatedBy: a.CreatedBy,
@ -108,17 +128,43 @@ func (a *Account) Copy() *Account {
Network: a.Network.Copy(),
Peers: peers,
Users: users,
Groups: groups,
Rules: rules,
}
}
// NewManager creates a new DefaultAccountManager with a provided Store
func NewManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager) *DefaultAccountManager {
return &DefaultAccountManager{
func (a *Account) GetGroupAll() (*Group, error) {
for _, g := range a.Groups {
if g.Name == "All" {
return g, nil
}
}
return nil, fmt.Errorf("no group ALL found")
}
// BuildManager creates a new DefaultAccountManager with a provided Store
func BuildManager(
store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager,
) (*DefaultAccountManager, error) {
dam := &DefaultAccountManager{
Store: store,
mux: sync.Mutex{},
peersUpdateManager: peersUpdateManager,
idpManager: idpManager,
}
// if account has not default account
// we build 'all' group and add all peers into it
// also we create default rule with source an destination
// groups 'all'
for _, account := range store.GetAllAccounts() {
dam.addAllGroup(account)
if err := store.SaveAccount(account); err != nil {
return nil, err
}
}
return dam, nil
}
// AddSetupKey generates a new setup key with a given name and type, and adds it to the specified account
@ -223,7 +269,9 @@ func (am *DefaultAccountManager) GetAccountById(accountId string) (*Account, err
// GetAccountByUserOrAccountId look for an account by user or account Id, if no account is provided and
// user id doesn't have an account associated with it, one account is created
func (am *DefaultAccountManager) GetAccountByUserOrAccountId(userId, accountId, domain string) (*Account, error) {
func (am *DefaultAccountManager) GetAccountByUserOrAccountId(
userId, accountId, domain string,
) (*Account, error) {
if accountId != "" {
return am.GetAccountById(accountId)
} else if userId != "" {
@ -490,6 +538,8 @@ func (am *DefaultAccountManager) AddAccount(accountId, userId, domain string) (*
func (am *DefaultAccountManager) createAccount(accountId, userId, domain string) (*Account, error) {
account := newAccountWithId(accountId, userId, domain)
am.addAllGroup(account)
err := am.Store.SaveAccount(account)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed creating account")
@ -498,6 +548,28 @@ func (am *DefaultAccountManager) createAccount(accountId, userId, domain string)
return account, nil
}
// addAllGroup to account object it it doesn't exists
func (am *DefaultAccountManager) addAllGroup(account *Account) {
if len(account.Groups) == 0 {
allGroup := &Group{
ID: xid.New().String(),
Name: "All",
}
for _, peer := range account.Peers {
allGroup.Peers = append(allGroup.Peers, peer.Key)
}
account.Groups = map[string]*Group{allGroup.ID: allGroup}
defaultRule := &Rule{
ID: xid.New().String(),
Name: "Default",
Source: []string{allGroup.ID},
Destination: []string{allGroup.ID},
}
account.Rules = map[string]*Rule{defaultRule.ID: defaultRule}
}
}
// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id
func newAccountWithId(accountId, userId, domain string) *Account {
log.Debugf("creating new account")

View File

@ -37,7 +37,6 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
}
func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
type initUserParams jwtclaims.AuthorizationClaims
type test struct {
@ -165,7 +164,6 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
}
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5, testCase6} {
t.Run(testCase.name, func(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
@ -346,7 +344,6 @@ func TestAccountManager_AccountExists(t *testing.T) {
if !*exists {
t.Errorf("expected account to exist after creation, got false")
}
}
func TestAccountManager_GetAccount(t *testing.T) {
@ -363,7 +360,7 @@ func TestAccountManager_GetAccount(t *testing.T) {
t.Fatal(err)
}
//AddAccount has been already tested so we can assume it is correct and compare results
// AddAccount has been already tested so we can assume it is correct and compare results
getAccount, err := manager.GetAccountById(expectedId)
if err != nil {
t.Fatal(err)
@ -385,7 +382,6 @@ func TestAccountManager_GetAccount(t *testing.T) {
t.Errorf("expected account to have setup key %s, not found", key.Key)
}
}
}
func TestAccountManager_AddPeer(t *testing.T) {
@ -400,7 +396,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
t.Fatal(err)
}
serial := account.Network.CurrentSerial() //should be 0
serial := account.Network.CurrentSerial() // should be 0
var setupKey *SetupKey
for _, key := range account.SetupKeys {
@ -457,7 +453,6 @@ func TestAccountManager_AddPeer(t *testing.T) {
if account.Network.CurrentSerial() != 1 {
t.Errorf("expecting Network Serial=%d to be incremented by 1 and be equal to %d when adding new peer to account", serial, account.Network.CurrentSerial())
}
}
func TestAccountManager_AddPeerWithUserID(t *testing.T) {
@ -474,7 +469,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
t.Fatal(err)
}
serial := account.Network.CurrentSerial() //should be 0
serial := account.Network.CurrentSerial() // should be 0
if account.Network.Serial != 0 {
t.Errorf("expecting account network to have an initial Serial=0")
@ -521,7 +516,6 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
if account.Network.CurrentSerial() != 1 {
t.Errorf("expecting Network Serial=%d to be incremented by 1 and be equal to %d when adding new peer to account", serial, account.Network.CurrentSerial())
}
}
func TestAccountManager_DeletePeer(t *testing.T) {
@ -573,7 +567,6 @@ func TestAccountManager_DeletePeer(t *testing.T) {
if account.Network.CurrentSerial() != 2 {
t.Errorf("expecting Network Serial=%d to be incremented and be equal to 2 after adding and deleteing a peer", account.Network.CurrentSerial())
}
}
func TestGetUsersFromAccount(t *testing.T) {
@ -614,7 +607,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) {
if err != nil {
return nil, err
}
return NewManager(store, NewPeersUpdateManager(), nil), nil
return BuildManager(store, NewPeersUpdateManager(), nil)
}
func createStore(t *testing.T) (Store, error) {

View File

@ -1,6 +1,7 @@
package server
import (
"fmt"
"os"
"path/filepath"
"strings"
@ -18,10 +19,12 @@ const storeFileName = "store.json"
// FileStore represents an account storage backed by a file persisted to disk
type FileStore struct {
Accounts map[string]*Account
SetupKeyId2AccountId map[string]string `json:"-"`
PeerKeyId2AccountId map[string]string `json:"-"`
UserId2AccountId map[string]string `json:"-"`
PrivateDomain2AccountId map[string]string `json:"-"`
SetupKeyId2AccountId map[string]string `json:"-"`
PeerKeyId2AccountId map[string]string `json:"-"`
UserId2AccountId map[string]string `json:"-"`
PrivateDomain2AccountId map[string]string `json:"-"`
PeerKeyId2SrcRulesId map[string]map[string]struct{} `json:"-"`
PeerKeyId2DstRulesId map[string]map[string]struct{} `json:"-"`
// mutex to synchronise Store read/write operations
mux sync.Mutex `json:"-"`
@ -47,6 +50,8 @@ func restore(file string) (*FileStore, error) {
PeerKeyId2AccountId: make(map[string]string),
UserId2AccountId: make(map[string]string),
PrivateDomain2AccountId: make(map[string]string),
PeerKeyId2SrcRulesId: make(map[string]map[string]struct{}),
PeerKeyId2DstRulesId: make(map[string]map[string]struct{}),
storeFile: file,
}
@ -69,10 +74,39 @@ func restore(file string) (*FileStore, error) {
store.PeerKeyId2AccountId = make(map[string]string)
store.UserId2AccountId = make(map[string]string)
store.PrivateDomain2AccountId = make(map[string]string)
store.PeerKeyId2SrcRulesId = map[string]map[string]struct{}{}
store.PeerKeyId2DstRulesId = map[string]map[string]struct{}{}
for accountId, account := range store.Accounts {
for setupKeyId := range account.SetupKeys {
store.SetupKeyId2AccountId[strings.ToUpper(setupKeyId)] = accountId
}
for _, rule := range account.Rules {
for _, groupID := range rule.Source {
if group, ok := account.Groups[groupID]; ok {
for _, peerID := range group.Peers {
rules := store.PeerKeyId2SrcRulesId[peerID]
if rules == nil {
rules = map[string]struct{}{}
store.PeerKeyId2SrcRulesId[peerID] = rules
}
rules[rule.ID] = struct{}{}
}
}
}
for _, groupID := range rule.Destination {
if group, ok := account.Groups[groupID]; ok {
for _, peerID := range group.Peers {
rules := store.PeerKeyId2DstRulesId[peerID]
if rules == nil {
rules = map[string]struct{}{}
store.PeerKeyId2DstRulesId[peerID] = rules
}
rules[rule.ID] = struct{}{}
}
}
}
}
for _, peer := range account.Peers {
store.PeerKeyId2AccountId[peer.Key] = accountId
}
@ -82,7 +116,8 @@ func restore(file string) (*FileStore, error) {
for _, user := range account.Users {
store.UserId2AccountId[user.Id] = accountId
}
if account.Domain != "" && account.DomainCategory == PrivateCategory && account.IsDomainPrimaryAccount {
if account.Domain != "" && account.DomainCategory == PrivateCategory &&
account.IsDomainPrimaryAccount {
store.PrivateDomain2AccountId[account.Domain] = accountId
}
}
@ -106,6 +141,24 @@ func (s *FileStore) SavePeer(accountId string, peer *Peer) error {
return err
}
// if it is new peer, add it to default 'All' group
allGroup, err := account.GetGroupAll()
if err != nil {
return err
}
ind := -1
for i, pid := range allGroup.Peers {
if pid == peer.Key {
ind = i
break
}
}
if ind < 0 {
allGroup.Peers = append(allGroup.Peers, peer.Key)
}
account.Peers[peer.Key] = peer
return s.persist(s.storeFile)
}
@ -176,6 +229,29 @@ func (s *FileStore) SaveAccount(account *Account) error {
s.PeerKeyId2AccountId[peer.Key] = account.Id
}
for _, rule := range account.Rules {
for _, gid := range rule.Source {
for _, pid := range account.Groups[gid].Peers {
rules := s.PeerKeyId2SrcRulesId[pid]
if rules == nil {
rules = map[string]struct{}{}
s.PeerKeyId2SrcRulesId[pid] = rules
}
rules[rule.ID] = struct{}{}
}
}
for _, gid := range rule.Destination {
for _, pid := range account.Groups[gid].Peers {
rules := s.PeerKeyId2DstRulesId[pid]
if rules == nil {
rules = map[string]struct{}{}
s.PeerKeyId2DstRulesId[pid] = rules
}
rules[rule.ID] = struct{}{}
}
}
}
for _, user := range account.Users {
s.UserId2AccountId[user.Id] = account.Id
}
@ -190,7 +266,10 @@ func (s *FileStore) SaveAccount(account *Account) error {
func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) {
accountId, accountIdFound := s.PrivateDomain2AccountId[strings.ToLower(domain)]
if !accountIdFound {
return nil, status.Errorf(codes.NotFound, "provided domain is not registered or is not private")
return nil, status.Errorf(
codes.NotFound,
"provided domain is not registered or is not private",
)
}
account, err := s.GetAccount(accountId)
@ -232,6 +311,14 @@ func (s *FileStore) GetAccountPeers(accountId string) ([]*Peer, error) {
return peers, nil
}
func (s *FileStore) GetAllAccounts() (all []*Account) {
for _, a := range s.Accounts {
all = append(all, a)
}
return all
}
func (s *FileStore) GetAccount(accountId string) (*Account, error) {
account, accountFound := s.Accounts[accountId]
if !accountFound {
@ -265,18 +352,52 @@ func (s *FileStore) GetPeerAccount(peerKey string) (*Account, error) {
return s.GetAccount(accountId)
}
func (s *FileStore) GetGroup(groupID string) (*Group, error) {
return nil, nil
func (s *FileStore) GetPeerSrcRules(accountId, peerKey string) ([]*Rule, error) {
s.mux.Lock()
defer s.mux.Unlock()
account, err := s.GetAccount(accountId)
if err != nil {
return nil, err
}
ruleIDs, ok := s.PeerKeyId2SrcRulesId[peerKey]
if !ok {
return nil, fmt.Errorf("no rules for peer: %v", ruleIDs)
}
rules := []*Rule{}
for id := range ruleIDs {
rule, ok := account.Rules[id]
if ok {
rules = append(rules, rule)
}
}
return rules, nil
}
func (s *FileStore) SaveGroup(group *Group) error {
return nil
}
func (s *FileStore) GetPeerDstRules(accountId, peerKey string) ([]*Rule, error) {
s.mux.Lock()
defer s.mux.Unlock()
func (s *FileStore) DeleteGroup(groupID string) error {
return nil
}
account, err := s.GetAccount(accountId)
if err != nil {
return nil, err
}
func (s *FileStore) ListGroups() ([]*Group, error) {
return nil, nil
ruleIDs, ok := s.PeerKeyId2DstRulesId[peerKey]
if !ok {
return nil, fmt.Errorf("no rules for peer: %v", ruleIDs)
}
rules := []*Rule{}
for id := range ruleIDs {
rule, ok := account.Rules[id]
if ok {
rules = append(rules, rule)
}
}
return rules, nil
}

View File

@ -17,6 +17,14 @@ type Group struct {
Peers []string
}
func (g *Group) Copy() *Group {
return &Group{
ID: g.ID,
Name: g.Name,
Peers: g.Peers[:],
}
}
// GetGroup object of the peers
func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, error) {
am.mux.Lock()

View File

@ -3,11 +3,12 @@ package server
import (
"context"
"fmt"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"strings"
"time"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/golang/protobuf/ptypes/timestamp"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/proto"
@ -64,7 +65,6 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
}
func (s *Server) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) {
// todo introduce something more meaningful with the key expiration/rotation
now := time.Now().Add(24 * time.Hour)
secs := int64(now.Second())
@ -77,10 +77,9 @@ func (s *Server) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.Ser
}, nil
}
//Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
// notifies the connected peer of any updates (e.g. new peers under the same account)
func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
log.Debugf("Sync request from peer %s", req.WgPubKey)
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
@ -155,7 +154,6 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
}
func (s *Server) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest) (*Peer, error) {
var (
reqSetupKey string
userId string
@ -209,7 +207,7 @@ func (s *Server) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest) (*Pe
return nil, status.Errorf(codes.NotFound, "provided setup key doesn't exists")
}
//todo move to DefaultAccountManager the code below
// todo move to DefaultAccountManager the code below
networkMap, err := s.accountManager.GetNetworkMap(peer.Key)
if err != nil {
return nil, status.Errorf(codes.Internal, "unable to fetch network map after registering peer, error: %v", err)
@ -240,7 +238,6 @@ func (s *Server) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest) (*Pe
// In case it isn't, the endpoint checks whether setup key is provided within the request and tries to register a peer.
// In case of the successful registration login is also successful
func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
log.Debugf("Login request from peer %s", req.WgPubKey)
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
@ -252,18 +249,18 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
peer, err := s.accountManager.GetPeer(peerKey.String())
if err != nil {
if errStatus, ok := status.FromError(err); ok && errStatus.Code() == codes.NotFound {
//peer doesn't exist -> check if setup key was provided
// peer doesn't exist -> check if setup key was provided
loginReq := &proto.LoginRequest{}
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, loginReq)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid request message")
}
if loginReq.GetJwtToken() == "" && loginReq.GetSetupKey() == "" {
//absent setup key -> permission denied
// absent setup key -> permission denied
return nil, status.Errorf(codes.PermissionDenied, "provided peer with the key wgPubKey %s is not registered and no setup key or jwt was provided", peerKey.String())
}
//setup key or jwt is present -> try normal registration flow
// setup key or jwt is present -> try normal registration flow
peer, err = s.registerPeer(peerKey, loginReq)
if err != nil {
return nil, err
@ -303,13 +300,12 @@ func ToResponseProto(configProto Protocol) proto.HostConfig_Protocol {
case TCP:
return proto.HostConfig_TCP
default:
//mbragin: todo something better?
// mbragin: todo something better?
panic(fmt.Errorf("unexpected config protocol type %v", configProto))
}
}
func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *proto.WiretrusteeConfig {
var stuns []*proto.HostConfig
for _, stun := range config.Stuns {
stuns = append(stuns, &proto.HostConfig{
@ -350,26 +346,23 @@ func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *prot
func toPeerConfig(peer *Peer) *proto.PeerConfig {
return &proto.PeerConfig{
Address: peer.IP.String() + "/24", //todo make it explicit
Address: peer.IP.String() + "/24", // todo make it explicit
}
}
func toRemotePeerConfig(peers []*Peer) []*proto.RemotePeerConfig {
remotePeers := []*proto.RemotePeerConfig{}
for _, rPeer := range peers {
remotePeers = append(remotePeers, &proto.RemotePeerConfig{
WgPubKey: rPeer.Key,
AllowedIps: []string{fmt.Sprintf(AllowedIPsFormat, rPeer.IP)}, //todo /32
AllowedIps: []string{fmt.Sprintf(AllowedIPsFormat, rPeer.IP)}, // todo /32
})
}
return remotePeers
}
func toSyncResponse(config *Config, peer *Peer, peers []*Peer, turnCredentials *TURNCredentials, serial uint64) *proto.SyncResponse {
wtConfig := toWiretrusteeConfig(config, turnCredentials)
pConfig := toPeerConfig(peer)
@ -397,7 +390,6 @@ func (s *Server) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Empty,
// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization
func (s *Server) sendInitialSync(peerKey wgtypes.Key, peer *Peer, srv proto.ManagementService_SyncServer) error {
networkMap, err := s.accountManager.GetNetworkMap(peer.Key)
if err != nil {
log.Warnf("error getting a list of peers for a peer %s", peer.Key)
@ -436,7 +428,6 @@ func (s *Server) sendInitialSync(peerKey wgtypes.Key, peer *Peer, srv proto.Mana
// This is used for initiating an Oauth 2 device authorization grant flow
// which will be used by our clients to Login
func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil {
errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetDeviceAuthorizationFlow request.", req.WgPubKey)

View File

@ -13,20 +13,33 @@ import (
log "github.com/sirupsen/logrus"
)
// Groups is a handler that returns groups of the account
type Groups struct {
accountManager server.AccountManager
authAudience string
jwtExtractor jwtclaims.ClaimsExtractor
}
// GroupResponse is a response sent to the client
type GroupResponse struct {
ID string
Name string
Peers []GroupPeerResponse `json:",omitempty"`
}
// GroupPeerResponse is a response sent to the client
type GroupPeerResponse struct {
Key string
Name string
}
// GroupRequest to create or update group
type GroupRequest struct {
ID string
Name string
Peers []string
}
// Groups is a handler that returns groups of the account
type Groups struct {
jwtExtractor jwtclaims.ClaimsExtractor
accountManager server.AccountManager
authAudience string
}
func NewGroups(accountManager server.AccountManager, authAudience string) *Groups {
return &Groups{
accountManager: accountManager,
@ -44,7 +57,12 @@ func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) {
return
}
writeJSONObject(w, account.Groups)
var groups []*GroupResponse
for _, g := range account.Groups {
groups = append(groups, toGroupResponse(account, g))
}
writeJSONObject(w, groups)
}
func (h *Groups) CreateOrUpdateGroupHandler(w http.ResponseWriter, r *http.Request) {
@ -54,7 +72,7 @@ func (h *Groups) CreateOrUpdateGroupHandler(w http.ResponseWriter, r *http.Reque
return
}
var req server.Group
var req GroupRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
@ -64,13 +82,19 @@ func (h *Groups) CreateOrUpdateGroupHandler(w http.ResponseWriter, r *http.Reque
req.ID = xid.New().String()
}
if err := h.accountManager.SaveGroup(account.Id, &req); err != nil {
group := server.Group{
ID: req.ID,
Name: req.Name,
Peers: req.Peers,
}
if err := h.accountManager.SaveGroup(account.Id, &group); err != nil {
log.Errorf("failed updating group %s under account %s %v", req.ID, account.Id, err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
writeJSONObject(w, &req)
writeJSONObject(w, toGroupResponse(account, &group))
}
func (h *Groups) DeleteGroupHandler(w http.ResponseWriter, r *http.Request) {
@ -117,7 +141,7 @@ func (h *Groups) GetGroupHandler(w http.ResponseWriter, r *http.Request) {
return
}
writeJSONObject(w, group)
writeJSONObject(w, toGroupResponse(account, group))
default:
http.Error(w, "", http.StatusNotFound)
}
@ -133,3 +157,29 @@ func (h *Groups) getGroupAccount(r *http.Request) (*server.Account, error) {
return account, nil
}
func toGroupResponse(account *server.Account, group *server.Group) *GroupResponse {
cache := make(map[string]GroupPeerResponse)
gr := GroupResponse{
ID: group.ID,
Name: group.Name,
}
for _, pid := range group.Peers {
peerResp, ok := cache[pid]
if !ok {
peer, ok := account.Peers[pid]
if !ok {
continue
}
peerResp = GroupPeerResponse{
Key: peer.Key,
Name: peer.Name,
}
cache[pid] = peerResp
}
gr.Peers = append(gr.Peers, peerResp)
}
return &gr
}

View File

@ -0,0 +1,211 @@
package handler
import (
"encoding/json"
"fmt"
"net/http"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/rs/xid"
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
)
const FlowBidirectString = "bidirect"
// RuleResponse is a response sent to the client
type RuleResponse struct {
ID string
Name string
Source []RuleGroupResponse
Destination []RuleGroupResponse
Flow string
}
// RuleGroupResponse is a response sent to the client
type RuleGroupResponse struct {
ID string
Name string
PeersCount int
}
// RuleRequest to create or update rule
type RuleRequest struct {
ID string
Name string
Source []string
Destination []string
Flow string
}
// Rules is a handler that returns rules of the account
type Rules struct {
jwtExtractor jwtclaims.ClaimsExtractor
accountManager server.AccountManager
authAudience string
}
func NewRules(accountManager server.AccountManager, authAudience string) *Rules {
return &Rules{
accountManager: accountManager,
authAudience: authAudience,
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil),
}
}
// GetAllRulesHandler list for the account
func (h *Rules) GetAllRulesHandler(w http.ResponseWriter, r *http.Request) {
account, err := h.getRuleAccount(r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
var rules []*RuleResponse
for _, r := range account.Rules {
rules = append(rules, toRuleResponse(account, r))
}
writeJSONObject(w, rules)
}
func (h *Rules) CreateOrUpdateRuleHandler(w http.ResponseWriter, r *http.Request) {
account, err := h.getRuleAccount(r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
var req RuleRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if r.Method == http.MethodPost {
req.ID = xid.New().String()
}
rule := server.Rule{
ID: req.ID,
Name: req.Name,
Source: req.Source,
Destination: req.Destination,
}
switch req.Flow {
case FlowBidirectString:
rule.Flow = server.TrafficFlowBidirect
default:
http.Error(w, "unknown flow type", http.StatusBadRequest)
return
}
if err := h.accountManager.SaveRule(account.Id, &rule); err != nil {
log.Errorf("failed updating rule %s under account %s %v", req.ID, account.Id, err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
writeJSONObject(w, &req)
}
func (h *Rules) DeleteRuleHandler(w http.ResponseWriter, r *http.Request) {
account, err := h.getRuleAccount(r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
aID := account.Id
rID := mux.Vars(r)["id"]
if len(rID) == 0 {
http.Error(w, "invalid rule ID", http.StatusBadRequest)
return
}
if err := h.accountManager.DeleteRule(aID, rID); err != nil {
log.Errorf("failed delete rule %s under account %s %v", rID, aID, err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
writeJSONObject(w, "")
}
func (h *Rules) GetRuleHandler(w http.ResponseWriter, r *http.Request) {
account, err := h.getRuleAccount(r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
switch r.Method {
case http.MethodGet:
ruleID := mux.Vars(r)["id"]
if len(ruleID) == 0 {
http.Error(w, "invalid rule ID", http.StatusBadRequest)
return
}
rule, err := h.accountManager.GetRule(account.Id, ruleID)
if err != nil {
http.Error(w, "rule not found", http.StatusNotFound)
return
}
writeJSONObject(w, toRuleResponse(account, rule))
default:
http.Error(w, "", http.StatusNotFound)
}
}
func (h *Rules) getRuleAccount(r *http.Request) (*server.Account, error) {
jwtClaims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience)
account, err := h.accountManager.GetAccountWithAuthorizationClaims(jwtClaims)
if err != nil {
return nil, fmt.Errorf("failed getting account of a user %s: %v", jwtClaims.UserId, err)
}
return account, nil
}
func toRuleResponse(account *server.Account, rule *server.Rule) *RuleResponse {
gr := RuleResponse{
ID: rule.ID,
Name: rule.Name,
}
switch rule.Flow {
case server.TrafficFlowBidirect:
gr.Flow = FlowBidirectString
default:
gr.Flow = "unknown"
}
for _, gid := range rule.Source {
if group, ok := account.Groups[gid]; ok {
gr.Source = append(gr.Source, RuleGroupResponse{
ID: group.ID,
Name: group.Name,
PeersCount: len(group.Peers),
})
}
}
for _, gid := range rule.Destination {
if group, ok := account.Groups[gid]; ok {
gr.Destination = append(gr.Destination, RuleGroupResponse{
ID: group.ID,
Name: group.Name,
PeersCount: len(group.Peers),
})
}
}
return &gr
}

View File

@ -0,0 +1,211 @@
package handler
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/magiconair/properties/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/mock_server"
)
func initRulesTestData(rules ...*server.Rule) *Rules {
return &Rules{
accountManager: &mock_server.MockAccountManager{
SaveRuleFunc: func(_ string, rule *server.Rule) error {
if !strings.HasPrefix(rule.ID, "id-") {
rule.ID = "id-was-set"
}
return nil
},
GetRuleFunc: func(_, ruleID string) (*server.Rule, error) {
if ruleID != "idoftherule" {
return nil, fmt.Errorf("not found")
}
return &server.Rule{
ID: "idoftherule",
Name: "Rule",
Source: []string{"idofsrcrule"},
Destination: []string{"idofdestrule"},
Flow: server.TrafficFlowBidirect,
}, nil
},
GetAccountWithAuthorizationClaimsFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
return &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",
}, nil
},
},
authAudience: "",
jwtExtractor: jwtclaims.ClaimsExtractor{
ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
}
},
},
}
}
func TestRulesGetRule(t *testing.T) {
tt := []struct {
name string
expectedStatus int
expectedBody bool
requestType string
requestPath string
requestBody io.Reader
}{
{
name: "GetRule OK",
expectedBody: true,
requestType: http.MethodGet,
requestPath: "/api/rules/idoftherule",
expectedStatus: http.StatusOK,
},
{
name: "GetRule not found",
requestType: http.MethodGet,
requestPath: "/api/rules/notexists",
expectedStatus: http.StatusNotFound,
},
}
rule := &server.Rule{
ID: "idoftherule",
Name: "Rule",
}
p := initRulesTestData(rule)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter()
router.HandleFunc("/api/rules/{id}", p.GetRuleHandler).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
if status := recorder.Code; status != tc.expectedStatus {
t.Errorf("handler returned wrong status code: got %v want %v",
status, tc.expectedStatus)
return
}
if !tc.expectedBody {
return
}
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("I don't know what I expected; %v", err)
}
var got RuleResponse
if err = json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, got.ID, rule.ID)
assert.Equal(t, got.Name, rule.Name)
})
}
}
func TestRulesSaveRule(t *testing.T) {
tt := []struct {
name string
expectedStatus int
expectedBody bool
expectedRule *server.Rule
requestType string
requestPath string
requestBody io.Reader
}{
{
name: "SaveRule POST OK",
requestType: http.MethodPost,
requestPath: "/api/rules",
requestBody: bytes.NewBuffer(
[]byte(`{"Name":"Default POSTed Rule","Flow":"bidirect"}`)),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedRule: &server.Rule{
ID: "id-was-set",
Name: "Default POSTed Rule",
Flow: server.TrafficFlowBidirect,
},
},
{
name: "SaveRule PUT OK",
requestType: http.MethodPut,
requestPath: "/api/rules",
requestBody: bytes.NewBuffer(
[]byte(`{"ID":"id-existed","Name":"Default POSTed Rule","Flow":"bidirect"}`)),
expectedStatus: http.StatusOK,
expectedRule: &server.Rule{
ID: "id-existed",
Name: "Default POSTed Rule",
Flow: server.TrafficFlowBidirect,
},
},
}
p := initRulesTestData()
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter()
router.HandleFunc("/api/rules", p.CreateOrUpdateRuleHandler).Methods("PUT", "POST")
router.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("I don't know what I expected; %v", err)
}
if status := recorder.Code; status != tc.expectedStatus {
t.Errorf("handler returned wrong status code: got %v want %v, content: %s",
status, tc.expectedStatus, string(content))
return
}
if !tc.expectedBody {
return
}
got := &RuleRequest{}
if err = json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
if tc.requestType != http.MethodPost {
assert.Equal(t, got.ID, tc.expectedRule.ID)
}
assert.Equal(t, got.Name, tc.expectedRule.Name)
assert.Equal(t, got.Flow, "bidirect")
})
}
}

View File

@ -96,6 +96,7 @@ func (s *Server) Start() error {
r.Use(jwtMiddleware.Handler, corsMiddleware.Handler)
groupsHandler := handler.NewGroups(s.accountManager, s.config.AuthAudience)
rulesHandler := handler.NewRules(s.accountManager, s.config.AuthAudience)
peersHandler := handler.NewPeers(s.accountManager, s.config.AuthAudience)
keysHandler := handler.NewSetupKeysHandler(s.accountManager, s.config.AuthAudience)
r.HandleFunc("/api/peers", peersHandler.GetPeers).Methods("GET", "OPTIONS")
@ -112,6 +113,12 @@ func (s *Server) Start() error {
r.HandleFunc("/api/setup-keys/{id}", keysHandler.HandleKey).
Methods("GET", "PUT", "DELETE", "OPTIONS")
r.HandleFunc("/api/rules", rulesHandler.GetAllRulesHandler).Methods("GET", "OPTIONS")
r.HandleFunc("/api/rules", rulesHandler.CreateOrUpdateRuleHandler).
Methods("POST", "PUT", "OPTIONS")
r.HandleFunc("/api/rules/{id}", rulesHandler.GetRuleHandler).Methods("GET", "OPTIONS")
r.HandleFunc("/api/rules/{id}", rulesHandler.DeleteRuleHandler).Methods("DELETE", "OPTIONS")
r.HandleFunc("/api/groups", groupsHandler.GetAllGroupsHandler).Methods("GET", "OPTIONS")
r.HandleFunc("/api/groups", groupsHandler.CreateOrUpdateGroupHandler).
Methods("POST", "PUT", "OPTIONS")

View File

@ -3,6 +3,13 @@ package server
import (
"context"
"fmt"
"net"
"os"
"path/filepath"
"runtime"
"testing"
"time"
"github.com/netbirdio/netbird/encryption"
mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/util"
@ -11,12 +18,6 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
"net"
"os"
"path/filepath"
"runtime"
"testing"
"time"
)
var (
@ -39,8 +40,7 @@ const (
// registerPeers registers peersNum peers on the management service and returns their Wireguard keys
func registerPeers(peersNum int, client mgmtProto.ManagementServiceClient) ([]*wgtypes.Key, error) {
var peers = []*wgtypes.Key{}
peers := []*wgtypes.Key{}
for i := 0; i < peersNum; i++ {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
@ -60,7 +60,6 @@ func registerPeers(peersNum int, client mgmtProto.ManagementServiceClient) ([]*w
// getServerKey gets Management Service Wireguard public key
func getServerKey(client mgmtProto.ManagementServiceClient) (*wgtypes.Key, error) {
keyResp, err := client.GetServerKey(context.TODO(), &mgmtProto.Empty{})
if err != nil {
return nil, err
@ -75,7 +74,6 @@ func getServerKey(client mgmtProto.ManagementServiceClient) (*wgtypes.Key, error
}
func Test_SyncProtocol(t *testing.T) {
dir := t.TempDir()
err := util.CopyFileContents("testdata/store.json", filepath.Join(dir, "store.json"))
if err != nil {
@ -263,7 +261,6 @@ func Test_SyncProtocol(t *testing.T) {
}
func loginPeerWithValidSetupKey(key wgtypes.Key, client mgmtProto.ManagementServiceClient) (*mgmtProto.LoginResponse, error) {
serverKey, err := getServerKey(client)
if err != nil {
return nil, err
@ -298,11 +295,9 @@ func loginPeerWithValidSetupKey(key wgtypes.Key, client mgmtProto.ManagementServ
}
return loginResp, nil
}
func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
testingServerKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Errorf("unable to generate server wg key for testing GetDeviceAuthorizationFlow, error: %v", err)
@ -362,7 +357,6 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
mgmtServer := &Server{
wgKey: testingServerKey,
config: &Config{
@ -397,7 +391,6 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
}
func startManagement(t *testing.T, port int, config *Config) (*grpc.Server, error) {
lis, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", port))
if err != nil {
return nil, err
@ -408,7 +401,10 @@ func startManagement(t *testing.T, port int, config *Config) (*grpc.Server, erro
return nil, err
}
peersUpdateManager := NewPeersUpdateManager()
accountManager := NewManager(store, peersUpdateManager, nil)
accountManager, err := BuildManager(store, peersUpdateManager, nil)
if err != nil {
return nil, err
}
turnManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := NewServer(config, accountManager, peersUpdateManager, turnManager)
if err != nil {

View File

@ -2,8 +2,6 @@ package server_test
import (
"context"
server "github.com/netbirdio/netbird/management/server"
"google.golang.org/grpc/credentials/insecure"
"io/ioutil"
"math/rand"
"net"
@ -13,6 +11,9 @@ import (
sync2 "sync"
"time"
server "github.com/netbirdio/netbird/management/server"
"google.golang.org/grpc/credentials/insecure"
pb "github.com/golang/protobuf/proto" //nolint
"github.com/netbirdio/netbird/encryption"
log "github.com/sirupsen/logrus"
@ -31,7 +32,6 @@ const (
)
var _ = Describe("Management service", func() {
var (
addr string
s *grpc.Server
@ -66,7 +66,6 @@ var _ = Describe("Management service", func() {
Expect(err).NotTo(HaveOccurred())
serverPubKey, err = wgtypes.ParseKey(resp.Key)
Expect(err).NotTo(HaveOccurred())
})
AfterEach(func() {
@ -78,7 +77,6 @@ var _ = Describe("Management service", func() {
Context("when calling IsHealthy endpoint", func() {
Specify("a non-error result is returned", func() {
healthy, err := client.IsHealthy(context.TODO(), &mgmtProto.Empty{})
Expect(err).NotTo(HaveOccurred())
@ -87,7 +85,6 @@ var _ = Describe("Management service", func() {
})
Context("when calling Sync endpoint", func() {
Context("when there is a new peer registered", func() {
Specify("a proper configuration is returned", func() {
key, _ := wgtypes.GenerateKey()
@ -168,7 +165,6 @@ var _ = Describe("Management service", func() {
Expect(resp.GetRemotePeers()).To(HaveLen(2))
peers := []string{resp.GetRemotePeers()[0].WgPubKey, resp.GetRemotePeers()[1].WgPubKey}
Expect(peers).To(ContainElements(key1.PublicKey().String(), key2.PublicKey().String()))
})
})
@ -211,7 +207,6 @@ var _ = Describe("Management service", func() {
resp = &mgmtProto.SyncResponse{}
err = pb.Unmarshal(decryptedBytes, resp)
wg.Done()
}()
// register a new peer
@ -229,7 +224,6 @@ var _ = Describe("Management service", func() {
Context("when calling GetServerKey endpoint", func() {
Specify("a public Wireguard key of the service is returned", func() {
resp, err := client.GetServerKey(context.TODO(), &mgmtProto.Empty{})
Expect(err).NotTo(HaveOccurred())
@ -237,19 +231,16 @@ var _ = Describe("Management service", func() {
Expect(resp.Key).ToNot(BeNil())
Expect(resp.ExpiresAt).ToNot(BeNil())
//check if the key is a valid Wireguard key
// check if the key is a valid Wireguard key
key, err := wgtypes.ParseKey(resp.Key)
Expect(err).NotTo(HaveOccurred())
Expect(key).ToNot(BeNil())
})
})
Context("when calling Login endpoint", func() {
Context("with an invalid setup key", func() {
Specify("an error is returned", func() {
key, _ := wgtypes.GenerateKey()
message, err := encryption.EncryptMessage(serverPubKey, key, &mgmtProto.LoginRequest{SetupKey: "invalid setup key"})
Expect(err).NotTo(HaveOccurred())
@ -261,24 +252,20 @@ var _ = Describe("Management service", func() {
Expect(err).To(HaveOccurred())
Expect(resp).To(BeNil())
})
})
Context("with a valid setup key", func() {
It("a non error result is returned", func() {
key, _ := wgtypes.GenerateKey()
resp := loginPeerWithValidSetupKey(serverPubKey, key, client)
Expect(resp).ToNot(BeNil())
})
})
Context("with a registered peer", func() {
It("a non error result is returned", func() {
key, _ := wgtypes.GenerateKey()
regResp := loginPeerWithValidSetupKey(serverPubKey, key, client)
Expect(regResp).NotTo(BeNil())
@ -324,7 +311,6 @@ var _ = Describe("Management service", func() {
Context("when there are 50 peers registered under one account", func() {
Context("when there are 10 more peers registered under the same account", func() {
Specify("all of the 50 peers will get updates of 10 newly registered peers", func() {
initialPeers := 20
additionalPeers := 10
@ -369,7 +355,7 @@ var _ = Describe("Management service", func() {
err = pb.Unmarshal(decryptedBytes, resp)
Expect(err).NotTo(HaveOccurred())
if len(resp.GetRemotePeers()) > 0 {
//only consider peer updates
// only consider peer updates
wg.Done()
}
}
@ -397,7 +383,6 @@ var _ = Describe("Management service", func() {
Context("when there are peers registered under one account concurrently", func() {
Specify("then there are no duplicate IPs", func() {
initialPeers := 30
ipChannel := make(chan string, 20)
@ -423,7 +408,6 @@ var _ = Describe("Management service", func() {
Expect(err).NotTo(HaveOccurred())
ipChannel <- resp.GetPeerConfig().Address
}()
}
@ -443,6 +427,7 @@ var _ = Describe("Management service", func() {
})
func loginPeerWithValidSetupKey(serverPubKey wgtypes.Key, key wgtypes.Key, client mgmtProto.ManagementServiceClient) *mgmtProto.LoginResponse {
defer GinkgoRecover()
meta := &mgmtProto.PeerSystemMeta{
Hostname: key.PublicKey().String(),
@ -467,7 +452,6 @@ func loginPeerWithValidSetupKey(serverPubKey wgtypes.Key, key wgtypes.Key, clien
err = encryption.DecryptMessage(serverPubKey, key, resp.Body, loginResp)
Expect(err).NotTo(HaveOccurred())
return loginResp
}
func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.ClientConn) {
@ -496,7 +480,10 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) {
log.Fatalf("failed creating a store: %s: %v", config.Datadir, err)
}
peersUpdateManager := server.NewPeersUpdateManager()
accountManager := server.NewManager(store, peersUpdateManager, nil)
accountManager, err := server.BuildManager(store, peersUpdateManager, nil)
if err != nil {
log.Fatalf("failed creating a manager: %v", err)
}
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager)
Expect(err).NotTo(HaveOccurred())

View File

@ -33,6 +33,10 @@ type MockAccountManager struct {
GroupAddPeerFunc func(accountID, groupID, peerKey string) error
GroupDeletePeerFunc func(accountID, groupID, peerKey string) error
GroupListPeersFunc func(accountID, groupID string) ([]*server.Peer, error)
GetRuleFunc func(accountID, ruleID string) (*server.Rule, error)
SaveRuleFunc func(accountID string, rule *server.Rule) error
DeleteRuleFunc func(accountID, ruleID string) error
ListRulesFunc func(accountID string) ([]*server.Rule, error)
GetUsersFromAccountFunc func(accountID string) ([]*server.UserInfo, error)
}
@ -41,7 +45,6 @@ func (am *MockAccountManager) GetUsersFromAccount(accountID string) ([]*server.U
return am.GetUsersFromAccountFunc(accountID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetUsersFromAccount not implemented")
}
func (am *MockAccountManager) GetOrCreateAccountByUser(
@ -207,7 +210,7 @@ func (am *MockAccountManager) SaveGroup(accountID string, group *server.Group) e
if am.SaveGroupFunc != nil {
return am.SaveGroupFunc(accountID, group)
}
return status.Errorf(codes.Unimplemented, "method UpdateGroup not implemented")
return status.Errorf(codes.Unimplemented, "method SaveGroup not implemented")
}
func (am *MockAccountManager) DeleteGroup(accountID, groupID string) error {
@ -244,3 +247,31 @@ func (am *MockAccountManager) GroupListPeers(accountID, groupID string) ([]*serv
}
return nil, status.Errorf(codes.Unimplemented, "method GroupListPeers not implemented")
}
func (am *MockAccountManager) GetRule(accountID, ruleID string) (*server.Rule, error) {
if am.GetRuleFunc != nil {
return am.GetRuleFunc(accountID, ruleID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetRule not implemented")
}
func (am *MockAccountManager) SaveRule(accountID string, rule *server.Rule) error {
if am.SaveRuleFunc != nil {
return am.SaveRuleFunc(accountID, rule)
}
return status.Errorf(codes.Unimplemented, "method SaveRule not implemented")
}
func (am *MockAccountManager) DeleteRule(accountID, ruleID string) error {
if am.DeleteRuleFunc != nil {
return am.DeleteRuleFunc(accountID, ruleID)
}
return status.Errorf(codes.Unimplemented, "method DeleteRule not implemented")
}
func (am *MockAccountManager) ListRules(accountID string) ([]*server.Rule, error) {
if am.ListRulesFunc != nil {
return am.ListRulesFunc(accountID)
}
return nil, status.Errorf(codes.Unimplemented, "method ListRules not implemented")
}

View File

@ -1,12 +1,13 @@
package server
import (
"github.com/netbirdio/netbird/management/proto"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"net"
"strings"
"time"
"github.com/netbirdio/netbird/management/proto"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// PeerSystemMeta is a metadata of a Peer machine system
@ -21,31 +22,31 @@ type PeerSystemMeta struct {
}
type PeerStatus struct {
//LastSeen is the last time peer was connected to the management service
// LastSeen is the last time peer was connected to the management service
LastSeen time.Time
//Connected indicates whether peer is connected to the management service or not
// Connected indicates whether peer is connected to the management service or not
Connected bool
}
//Peer represents a machine connected to the network.
//The Peer is a Wireguard peer identified by a public key
// Peer represents a machine connected to the network.
// The Peer is a Wireguard peer identified by a public key
type Peer struct {
//Wireguard public key
// Wireguard public key
Key string
//A setup key this peer was registered with
// A setup key this peer was registered with
SetupKey string
//IP address of the Peer
// IP address of the Peer
IP net.IP
//Meta is a Peer system meta data
// Meta is a Peer system meta data
Meta PeerSystemMeta
//Name is peer's name (machine name)
// Name is peer's name (machine name)
Name string
Status *PeerStatus
//The user ID that registered the peer
// The user ID that registered the peer
UserID string
}
//Copy copies Peer object
// Copy copies Peer object
func (p *Peer) Copy() *Peer {
return &Peer{
Key: p.Key,
@ -58,7 +59,7 @@ func (p *Peer) Copy() *Peer {
}
}
//GetPeer returns a peer from a Store
// GetPeer returns a peer from a Store
func (am *DefaultAccountManager) GetPeer(peerKey string) (*Peer, error) {
am.mux.Lock()
defer am.mux.Unlock()
@ -71,7 +72,7 @@ func (am *DefaultAccountManager) GetPeer(peerKey string) (*Peer, error) {
return peer, nil
}
//MarkPeerConnected marks peer as connected (true) or disconnected (false)
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
func (am *DefaultAccountManager) MarkPeerConnected(peerKey string, connected bool) error {
am.mux.Lock()
defer am.mux.Unlock()
@ -96,8 +97,12 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerKey string, connected boo
return nil
}
//RenamePeer changes peer's name
func (am *DefaultAccountManager) RenamePeer(accountId string, peerKey string, newName string) (*Peer, error) {
// RenamePeer changes peer's name
func (am *DefaultAccountManager) RenamePeer(
accountId string,
peerKey string,
newName string,
) (*Peer, error) {
am.mux.Lock()
defer am.mux.Unlock()
@ -116,7 +121,7 @@ func (am *DefaultAccountManager) RenamePeer(accountId string, peerKey string, ne
return peerCopy, nil
}
//DeletePeer removes peer from the account by it's IP
// DeletePeer removes peer from the account by it's IP
func (am *DefaultAccountManager) DeletePeer(accountId string, peerKey string) (*Peer, error) {
am.mux.Lock()
defer am.mux.Unlock()
@ -149,12 +154,13 @@ func (am *DefaultAccountManager) DeletePeer(accountId string, peerKey string) (*
RemotePeers: []*proto.RemotePeerConfig{},
RemotePeersIsEmpty: true,
},
}})
},
})
if err != nil {
return nil, err
}
//notify other peers of the change
// notify other peers of the change
peers, err := am.Store.GetAccountPeers(accountId)
if err != nil {
return nil, err
@ -180,7 +186,8 @@ func (am *DefaultAccountManager) DeletePeer(accountId string, peerKey string) (*
RemotePeers: update,
RemotePeersIsEmpty: len(update) == 0,
},
}})
},
})
if err != nil {
return nil, err
}
@ -190,7 +197,7 @@ func (am *DefaultAccountManager) DeletePeer(accountId string, peerKey string) (*
return peer, nil
}
//GetPeerByIP returns peer by it's IP
// GetPeerByIP returns peer by it's IP
func (am *DefaultAccountManager) GetPeerByIP(accountId string, peerIP string) (*Peer, error) {
am.mux.Lock()
defer am.mux.Unlock()
@ -220,10 +227,46 @@ func (am *DefaultAccountManager) GetNetworkMap(peerKey string) (*NetworkMap, err
}
var res []*Peer
for _, peer := range account.Peers {
// exclude original peer
if peer.Key != peerKey {
res = append(res, peer.Copy())
srcRules, err := am.Store.GetPeerSrcRules(account.Id, peerKey)
if err != nil {
return &NetworkMap{
Peers: res,
Network: account.Network.Copy(),
}, nil
}
dstRules, err := am.Store.GetPeerDstRules(account.Id, peerKey)
if err != nil {
return &NetworkMap{
Peers: res,
Network: account.Network.Copy(),
}, nil
}
groups := map[string]*Group{}
for _, r := range srcRules {
if r.Flow == TrafficFlowBidirect {
for _, gid := range r.Destination {
groups[gid] = account.Groups[gid]
}
}
}
for _, r := range dstRules {
if r.Flow == TrafficFlowBidirect {
for _, gid := range r.Source {
groups[gid] = account.Groups[gid]
}
}
}
for _, g := range groups {
for _, pid := range g.Peers {
peer := account.Peers[pid]
// exclude original peer
if peer.Key != peerKey {
res = append(res, peer.Copy())
}
}
}
@ -240,7 +283,11 @@ func (am *DefaultAccountManager) GetNetworkMap(peerKey string) (*NetworkMap, err
// to it. We also add the User ID to the peer metadata to identify registrant.
// Each new Peer will be assigned a new next net.IP from the Account.Network and Account.Network.LastIP will be updated (IP's are not reused).
// The peer property is just a placeholder for the Peer properties to pass further
func (am *DefaultAccountManager) AddPeer(setupKey string, userID string, peer *Peer) (*Peer, error) {
func (am *DefaultAccountManager) AddPeer(
setupKey string,
userID string,
peer *Peer,
) (*Peer, error) {
am.mux.Lock()
defer am.mux.Unlock()
@ -252,17 +299,28 @@ func (am *DefaultAccountManager) AddPeer(setupKey string, userID string, peer *P
if len(upperKey) != 0 {
account, err = am.Store.GetAccountBySetupKey(upperKey)
if err != nil {
return nil, status.Errorf(codes.NotFound, "unable to register peer, unable to find account with setupKey %s", upperKey)
return nil, status.Errorf(
codes.NotFound,
"unable to register peer, unable to find account with setupKey %s",
upperKey,
)
}
sk = getAccountSetupKeyByKey(account, upperKey)
if sk == nil {
// shouldn't happen actually
return nil, status.Errorf(codes.NotFound, "unable to register peer, unknown setupKey %s", upperKey)
return nil, status.Errorf(
codes.NotFound,
"unable to register peer, unknown setupKey %s",
upperKey,
)
}
if !sk.IsValid() {
return nil, status.Errorf(codes.FailedPrecondition, "unable to register peer, its setup key is invalid (expired, overused or revoked)")
return nil, status.Errorf(
codes.FailedPrecondition,
"unable to register peer, its setup key is invalid (expired, overused or revoked)",
)
}
} else if len(userID) != 0 {
@ -293,6 +351,13 @@ func (am *DefaultAccountManager) AddPeer(setupKey string, userID string, peer *P
Status: &PeerStatus{Connected: false, LastSeen: time.Now()},
}
// add peer to 'All' group
group, err := account.GetGroupAll()
if err != nil {
return nil, err
}
group.Peers = append(group.Peers, newPeer.Key)
account.Peers[newPeer.Key] = newPeer
if len(upperKey) != 0 {
account.SetupKeys[sk.Key] = sk.IncrementUsage()
@ -305,5 +370,4 @@ func (am *DefaultAccountManager) AddPeer(setupKey string, userID string, peer *P
}
return newPeer, nil
}

View File

@ -1,8 +1,10 @@
package server
import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"testing"
"github.com/rs/xid"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
func TestAccountManager_GetNetworkMap(t *testing.T) {
@ -70,7 +72,151 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
}
if networkMap.Peers[0].Key != peerKey2.PublicKey().String() {
t.Errorf("expecting Account NetworkMap to have peer with a key %s, got %s", peerKey2.PublicKey().String(), networkMap.Peers[0].Key)
t.Errorf(
"expecting Account NetworkMap to have peer with a key %s, got %s",
peerKey2.PublicKey().String(),
networkMap.Peers[0].Key,
)
}
}
func TestAccountManager_GetNetworkMapWithRule(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
return
}
expectedId := "test_account"
userId := "account_creator"
account, err := manager.AddAccount(expectedId, userId, "")
if err != nil {
t.Fatal(err)
}
var setupKey *SetupKey
for _, key := range account.SetupKeys {
if key.Type == SetupKeyReusable {
setupKey = key
}
}
peerKey1, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
return
}
_, err = manager.AddPeer(setupKey.Key, "", &Peer{
Key: peerKey1.PublicKey().String(),
Meta: PeerSystemMeta{},
Name: "test-peer-2",
})
if err != nil {
t.Errorf("expecting peer to be added, got failure %v", err)
return
}
peerKey2, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
return
}
_, err = manager.AddPeer(setupKey.Key, "", &Peer{
Key: peerKey2.PublicKey().String(),
Meta: PeerSystemMeta{},
Name: "test-peer-2",
})
if err != nil {
t.Errorf("expecting peer to be added, got failure %v", err)
return
}
rules, err := manager.ListRules(account.Id)
if err != nil {
t.Errorf("expecting to get a list of rules, got failure %v", err)
return
}
err = manager.DeleteRule(account.Id, rules[0].ID)
if err != nil {
t.Errorf("expecting to delete 1 group, got failure %v", err)
return
}
var (
group1 Group
group2 Group
rule Rule
)
group1.ID = xid.New().String()
group2.ID = xid.New().String()
group1.Name = "src"
group2.Name = "dst"
rule.ID = xid.New().String()
group1.Peers = append(group1.Peers, peerKey1.PublicKey().String())
group2.Peers = append(group2.Peers, peerKey2.PublicKey().String())
err = manager.SaveGroup(account.Id, &group1)
if err != nil {
t.Errorf("expecting group1 to be added, got failure %v", err)
return
}
err = manager.SaveGroup(account.Id, &group2)
if err != nil {
t.Errorf("expecting group2 to be added, got failure %v", err)
return
}
rule.Name = "test"
rule.Source = append(rule.Source, group1.ID)
rule.Destination = append(rule.Destination, group2.ID)
rule.Flow = TrafficFlowBidirect
err = manager.SaveRule(account.Id, &rule)
if err != nil {
t.Errorf("expecting rule to be added, got failure %v", err)
return
}
networkMap1, err := manager.GetNetworkMap(peerKey1.PublicKey().String())
if err != nil {
t.Fatal(err)
return
}
if len(networkMap1.Peers) != 1 {
t.Errorf(
"expecting Account NetworkMap to have 1 peers, got %v: %v",
len(networkMap1.Peers),
networkMap1.Peers,
)
}
if networkMap1.Peers[0].Key != peerKey2.PublicKey().String() {
t.Errorf(
"expecting Account NetworkMap to have peer with a key %s, got %s",
peerKey2.PublicKey().String(),
networkMap1.Peers[0].Key,
)
}
networkMap2, err := manager.GetNetworkMap(peerKey2.PublicKey().String())
if err != nil {
t.Fatal(err)
return
}
if len(networkMap2.Peers) != 1 {
t.Errorf("expecting Account NetworkMap to have 1 peers, got %v", len(networkMap2.Peers))
}
if len(networkMap2.Peers) > 0 && networkMap2.Peers[0].Key != peerKey1.PublicKey().String() {
t.Errorf(
"expecting Account NetworkMap to have peer with a key %s, got %s",
peerKey1.PublicKey().String(),
networkMap2.Peers[0].Key,
)
}
}

107
management/server/rule.go Normal file
View File

@ -0,0 +1,107 @@
package server
import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// TrafficFlowType defines allowed direction of the traffic in the rule
type TrafficFlowType int
const (
// TrafficFlowBidirect allows traffic to both direction
TrafficFlowBidirect TrafficFlowType = iota
)
// Rule of ACL for groups
type Rule struct {
// ID of the rule
ID string
// Name of the rule visible in the UI
Name string
// Source list of groups IDs of peers
Source []string
// Destination list of groups IDs of peers
Destination []string
// Flow of the traffic allowed by the rule
Flow TrafficFlowType
}
func (r *Rule) Copy() *Rule {
return &Rule{
ID: r.ID,
Name: r.Name,
Source: r.Source[:],
Destination: r.Destination[:],
Flow: r.Flow,
}
}
// GetRule of ACL from the store
func (am *DefaultAccountManager) GetRule(accountID, ruleID string) (*Rule, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
rule, ok := account.Rules[ruleID]
if ok {
return rule, nil
}
return nil, status.Errorf(codes.NotFound, "rule with ID %s not found", ruleID)
}
// SaveRule of ACL in the store
func (am *DefaultAccountManager) SaveRule(accountID string, rule *Rule) error {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return status.Errorf(codes.NotFound, "account not found")
}
account.Rules[rule.ID] = rule
return am.Store.SaveAccount(account)
}
// DeleteRule of ACL from the store
func (am *DefaultAccountManager) DeleteRule(accountID, ruleID string) error {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return status.Errorf(codes.NotFound, "account not found")
}
delete(account.Rules, ruleID)
return am.Store.SaveAccount(account)
}
// ListRules of ACL from the store
func (am *DefaultAccountManager) ListRules(accountID string) ([]*Rule, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
rules := make([]*Rule, 0, len(account.Rules))
for _, item := range account.Rules {
rules = append(rules, item)
}
return rules, nil
}

View File

@ -4,10 +4,13 @@ type Store interface {
GetPeer(peerKey string) (*Peer, error)
DeletePeer(accountId string, peerKey string) (*Peer, error)
SavePeer(accountId string, peer *Peer) error
GetAllAccounts() []*Account
GetAccount(accountId string) (*Account, error)
GetUserAccount(userId string) (*Account, error)
GetAccountPeers(accountId string) ([]*Peer, error)
GetPeerAccount(peerKey string) (*Account, error)
GetPeerSrcRules(accountId, peerKey string) ([]*Rule, error)
GetPeerDstRules(accountId, peerKey string) ([]*Rule, error)
GetAccountBySetupKey(setupKey string) (*Account, error)
GetAccountByPrivateDomain(domain string) (*Account, error)
SaveAccount(account *Account) error

View File

@ -58,6 +58,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userId, domain string)
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
account = NewAccount(userId, lowerDomain)
account.Users[userId] = NewAdminUser(userId)
am.addAllGroup(account)
err = am.Store.SaveAccount(account)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed creating account")