mirror of
https://github.com/netbirdio/netbird.git
synced 2025-04-16 07:28:32 +02:00
Simplified Store Interface (#545)
This PR simplifies Store and FileStore by keeping just the Get and Save account methods. The AccountManager operates mostly around a single account, so it makes sense to fetch the whole account object from the store.
This commit is contained in:
parent
4321b71984
commit
d0c6d88971
@ -185,8 +185,6 @@ var (
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Println("metrics ", disableMetrics)
|
|
||||||
|
|
||||||
if !disableMetrics {
|
if !disableMetrics {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
@ -15,6 +15,7 @@ import (
|
|||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
|
"net/netip"
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
@ -58,7 +59,6 @@ type AccountManager interface {
|
|||||||
GetPeer(peerKey string) (*Peer, error)
|
GetPeer(peerKey string) (*Peer, error)
|
||||||
GetPeers(accountID, userID string) ([]*Peer, error)
|
GetPeers(accountID, userID string) ([]*Peer, error)
|
||||||
MarkPeerConnected(peerKey string, connected bool) error
|
MarkPeerConnected(peerKey string, connected bool) error
|
||||||
RenamePeer(accountId string, peerKey string, newName string) (*Peer, error)
|
|
||||||
DeletePeer(accountId string, peerKey string) (*Peer, error)
|
DeletePeer(accountId string, peerKey string) (*Peer, error)
|
||||||
GetPeerByIP(accountId string, peerIP string) (*Peer, error)
|
GetPeerByIP(accountId string, peerIP string) (*Peer, error)
|
||||||
UpdatePeer(accountID string, peer *Peer) (*Peer, error)
|
UpdatePeer(accountID string, peer *Peer) (*Peer, error)
|
||||||
@ -143,6 +143,132 @@ type UserInfo struct {
|
|||||||
Status string `json:"-"`
|
Status string `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetPeersRoutes returns all active routes of provided peers
|
||||||
|
func (a *Account) GetPeersRoutes(givenPeers []*Peer) []*route.Route {
|
||||||
|
//TODO Peer.ID migration: we will need to replace search by Peer.ID here
|
||||||
|
routes := make([]*route.Route, 0)
|
||||||
|
for _, peer := range givenPeers {
|
||||||
|
peerRoutes := a.GetPeerRoutes(peer.Key)
|
||||||
|
activeRoutes := make([]*route.Route, 0)
|
||||||
|
for _, pr := range peerRoutes {
|
||||||
|
if pr.Enabled {
|
||||||
|
activeRoutes = append(activeRoutes, pr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(activeRoutes) > 0 {
|
||||||
|
routes = append(routes, activeRoutes...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return routes
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPeerRoutes returns a list of routes of a given peer
|
||||||
|
func (a *Account) GetPeerRoutes(peerPubKey string) []*route.Route {
|
||||||
|
//TODO Peer.ID migration: we will need to replace search by Peer.ID here
|
||||||
|
var routes []*route.Route
|
||||||
|
for _, r := range a.Routes {
|
||||||
|
if r.Peer == peerPubKey {
|
||||||
|
routes = append(routes, r)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return routes
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRoutesByPrefix return list of routes by account and route prefix
|
||||||
|
func (a *Account) GetRoutesByPrefix(prefix netip.Prefix) []*route.Route {
|
||||||
|
|
||||||
|
var routes []*route.Route
|
||||||
|
for _, r := range a.Routes {
|
||||||
|
if r.Network.String() == prefix.String() {
|
||||||
|
routes = append(routes, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return routes
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPeerRules returns a list of source or destination rules of a given peer.
|
||||||
|
func (a *Account) GetPeerRules(peerPubKey string) (srcRules []*Rule, dstRules []*Rule) {
|
||||||
|
|
||||||
|
// Rules are group based so there is no direct access to peers.
|
||||||
|
// First, find all groups that the given peer belongs to
|
||||||
|
peerGroups := make(map[string]struct{})
|
||||||
|
|
||||||
|
for s, group := range a.Groups {
|
||||||
|
for _, peer := range group.Peers {
|
||||||
|
if peerPubKey == peer {
|
||||||
|
peerGroups[s] = struct{}{}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second, find all rules that have discovered source and destination groups
|
||||||
|
srcRulesMap := make(map[string]*Rule)
|
||||||
|
dstRulesMap := make(map[string]*Rule)
|
||||||
|
for _, rule := range a.Rules {
|
||||||
|
for _, g := range rule.Source {
|
||||||
|
if _, ok := peerGroups[g]; ok && srcRulesMap[rule.ID] == nil {
|
||||||
|
srcRules = append(srcRules, rule)
|
||||||
|
srcRulesMap[rule.ID] = rule
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, g := range rule.Destination {
|
||||||
|
if _, ok := peerGroups[g]; ok && dstRulesMap[rule.ID] == nil {
|
||||||
|
dstRules = append(dstRules, rule)
|
||||||
|
dstRulesMap[rule.ID] = rule
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return srcRules, dstRules
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPeers returns a list of all Account peers
|
||||||
|
func (a *Account) GetPeers() []*Peer {
|
||||||
|
var peers []*Peer
|
||||||
|
for _, peer := range a.Peers {
|
||||||
|
peers = append(peers, peer)
|
||||||
|
}
|
||||||
|
return peers
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatePeer saves new or replaces existing peer
|
||||||
|
func (a *Account) UpdatePeer(update *Peer) {
|
||||||
|
//TODO Peer.ID migration: we will need to replace search by Peer.ID here
|
||||||
|
a.Peers[update.Key] = update
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletePeer deletes peer from the account cleaning up all the references
|
||||||
|
func (a *Account) DeletePeer(peerPubKey string) {
|
||||||
|
// TODO Peer.ID migration: we will need to replace search by Peer.ID here
|
||||||
|
|
||||||
|
// delete peer from groups
|
||||||
|
for _, g := range a.Groups {
|
||||||
|
for i, pk := range g.Peers {
|
||||||
|
if pk == peerPubKey {
|
||||||
|
g.Peers = append(g.Peers[:i], g.Peers[i+1:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
delete(a.Peers, peerPubKey)
|
||||||
|
a.Network.IncSerial()
|
||||||
|
}
|
||||||
|
|
||||||
|
// FindPeerByPubKey looks for a Peer by provided WireGuard public key in the Account or returns error if it wasn't found.
|
||||||
|
// It will return an object copy of the peer.
|
||||||
|
func (a *Account) FindPeerByPubKey(peerPubKey string) (*Peer, error) {
|
||||||
|
for _, peer := range a.Peers {
|
||||||
|
if peer.Key == peerPubKey {
|
||||||
|
return peer.Copy(), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, status.Errorf(codes.NotFound, "peer with the public key %s not found", peerPubKey)
|
||||||
|
}
|
||||||
|
|
||||||
// FindUser looks for a given user in the Account or returns error if user wasn't found.
|
// FindUser looks for a given user in the Account or returns error if user wasn't found.
|
||||||
func (a *Account) FindUser(userID string) (*User, error) {
|
func (a *Account) FindUser(userID string) (*User, error) {
|
||||||
user := a.Users[userID]
|
user := a.Users[userID]
|
||||||
@ -190,16 +316,19 @@ func (a *Account) Copy() *Account {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &Account{
|
return &Account{
|
||||||
Id: a.Id,
|
Id: a.Id,
|
||||||
CreatedBy: a.CreatedBy,
|
CreatedBy: a.CreatedBy,
|
||||||
SetupKeys: setupKeys,
|
Domain: a.Domain,
|
||||||
Network: a.Network.Copy(),
|
DomainCategory: a.DomainCategory,
|
||||||
Peers: peers,
|
IsDomainPrimaryAccount: a.IsDomainPrimaryAccount,
|
||||||
Users: users,
|
SetupKeys: setupKeys,
|
||||||
Groups: groups,
|
Network: a.Network.Copy(),
|
||||||
Rules: rules,
|
Peers: peers,
|
||||||
Routes: routes,
|
Users: users,
|
||||||
NameServerGroups: nsGroups,
|
Groups: groups,
|
||||||
|
Rules: rules,
|
||||||
|
Routes: routes,
|
||||||
|
NameServerGroups: nsGroups,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -699,7 +828,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := am.Store.GetUserAccount(claims.UserId)
|
account, err := am.Store.GetAccountByUser(claims.UserId)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
err = am.handleExistingUserAccount(account, domainAccount, claims)
|
err = am.handleExistingUserAccount(account, domainAccount, claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
@ -957,6 +958,167 @@ func TestAccountManager_UpdatePeerMeta(t *testing.T) {
|
|||||||
assert.Equal(t, newMeta, p.Meta)
|
assert.Equal(t, newMeta, p.Meta)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAccount_GetPeerRules(t *testing.T) {
|
||||||
|
|
||||||
|
groups := map[string]*Group{
|
||||||
|
"group_1": {
|
||||||
|
ID: "group_1",
|
||||||
|
Name: "group_1",
|
||||||
|
Peers: []string{"peer-1", "peer-2"},
|
||||||
|
},
|
||||||
|
"group_2": {
|
||||||
|
ID: "group_2",
|
||||||
|
Name: "group_2",
|
||||||
|
Peers: []string{"peer-2", "peer-3"},
|
||||||
|
},
|
||||||
|
"group_3": {
|
||||||
|
ID: "group_3",
|
||||||
|
Name: "group_3",
|
||||||
|
Peers: []string{"peer-4"},
|
||||||
|
},
|
||||||
|
"group_4": {
|
||||||
|
ID: "group_4",
|
||||||
|
Name: "group_4",
|
||||||
|
Peers: []string{"peer-1"},
|
||||||
|
},
|
||||||
|
"group_5": {
|
||||||
|
ID: "group_5",
|
||||||
|
Name: "group_5",
|
||||||
|
Peers: []string{"peer-1"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
rules := map[string]*Rule{
|
||||||
|
"rule-1": {
|
||||||
|
ID: "rule-1",
|
||||||
|
Name: "rule-1",
|
||||||
|
Description: "rule-1",
|
||||||
|
Disabled: false,
|
||||||
|
Source: []string{"group_1", "group_5"},
|
||||||
|
Destination: []string{"group_2"},
|
||||||
|
Flow: 0,
|
||||||
|
},
|
||||||
|
"rule-2": {
|
||||||
|
ID: "rule-2",
|
||||||
|
Name: "rule-2",
|
||||||
|
Description: "rule-2",
|
||||||
|
Disabled: false,
|
||||||
|
Source: []string{"group_1"},
|
||||||
|
Destination: []string{"group_1"},
|
||||||
|
Flow: 0,
|
||||||
|
},
|
||||||
|
"rule-3": {
|
||||||
|
ID: "rule-3",
|
||||||
|
Name: "rule-3",
|
||||||
|
Description: "rule-3",
|
||||||
|
Disabled: false,
|
||||||
|
Source: []string{"group_3"},
|
||||||
|
Destination: []string{"group_3"},
|
||||||
|
Flow: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
Groups: groups,
|
||||||
|
Rules: rules,
|
||||||
|
}
|
||||||
|
|
||||||
|
srcRules, dstRules := account.GetPeerRules("peer-1")
|
||||||
|
|
||||||
|
assert.Equal(t, 2, len(srcRules))
|
||||||
|
assert.Equal(t, 1, len(dstRules))
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileStore_GetRoutesByPrefix(t *testing.T) {
|
||||||
|
_, prefix, err := route.ParseNetwork("192.168.64.0/24")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
account := &Account{
|
||||||
|
Routes: map[string]*route.Route{
|
||||||
|
"route-1": {
|
||||||
|
ID: "route-1",
|
||||||
|
Network: prefix,
|
||||||
|
NetID: "network-1",
|
||||||
|
Description: "network-1",
|
||||||
|
Peer: "peer-1",
|
||||||
|
NetworkType: 0,
|
||||||
|
Masquerade: false,
|
||||||
|
Metric: 999,
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
"route-2": {
|
||||||
|
ID: "route-2",
|
||||||
|
Network: prefix,
|
||||||
|
NetID: "network-1",
|
||||||
|
Description: "network-1",
|
||||||
|
Peer: "peer-2",
|
||||||
|
NetworkType: 0,
|
||||||
|
Masquerade: false,
|
||||||
|
Metric: 999,
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
routes := account.GetRoutesByPrefix(prefix)
|
||||||
|
|
||||||
|
assert.Len(t, routes, 2)
|
||||||
|
routeIDs := make(map[string]struct{}, 2)
|
||||||
|
for _, r := range routes {
|
||||||
|
routeIDs[r.ID] = struct{}{}
|
||||||
|
}
|
||||||
|
assert.Contains(t, routeIDs, "route-1")
|
||||||
|
assert.Contains(t, routeIDs, "route-2")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccount_GetPeersRoutes(t *testing.T) {
|
||||||
|
_, prefix, err := route.ParseNetwork("192.168.64.0/24")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
account := &Account{
|
||||||
|
Peers: map[string]*Peer{
|
||||||
|
"peer-1": {Key: "peer-1"}, "peer-2": {Key: "peer-2"}, "peer-3": {Key: "peer-1"},
|
||||||
|
},
|
||||||
|
Routes: map[string]*route.Route{
|
||||||
|
"route-1": {
|
||||||
|
ID: "route-1",
|
||||||
|
Network: prefix,
|
||||||
|
NetID: "network-1",
|
||||||
|
Description: "network-1",
|
||||||
|
Peer: "peer-1",
|
||||||
|
NetworkType: 0,
|
||||||
|
Masquerade: false,
|
||||||
|
Metric: 999,
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
"route-2": {
|
||||||
|
ID: "route-2",
|
||||||
|
Network: prefix,
|
||||||
|
NetID: "network-1",
|
||||||
|
Description: "network-1",
|
||||||
|
Peer: "peer-2",
|
||||||
|
NetworkType: 0,
|
||||||
|
Masquerade: false,
|
||||||
|
Metric: 999,
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
routes := account.GetPeersRoutes([]*Peer{{Key: "peer-1"}, {Key: "peer-2"}, {Key: "non-existing-peer"}})
|
||||||
|
|
||||||
|
assert.Len(t, routes, 2)
|
||||||
|
routeIDs := make(map[string]struct{}, 2)
|
||||||
|
for _, r := range routes {
|
||||||
|
routeIDs[r.ID] = struct{}{}
|
||||||
|
}
|
||||||
|
assert.Contains(t, routeIDs, "route-1")
|
||||||
|
assert.Contains(t, routeIDs, "route-2")
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
func createManager(t *testing.T) (*DefaultAccountManager, error) {
|
func createManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||||
store, err := createStore(t)
|
store, err := createStore(t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -1,9 +1,6 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"github.com/netbirdio/netbird/route"
|
|
||||||
"net/netip"
|
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
@ -21,14 +18,10 @@ const storeFileName = "store.json"
|
|||||||
// FileStore represents an account storage backed by a file persisted to disk
|
// FileStore represents an account storage backed by a file persisted to disk
|
||||||
type FileStore struct {
|
type FileStore struct {
|
||||||
Accounts map[string]*Account
|
Accounts map[string]*Account
|
||||||
SetupKeyId2AccountId map[string]string `json:"-"`
|
SetupKeyID2AccountID map[string]string `json:"-"`
|
||||||
PeerKeyId2AccountId map[string]string `json:"-"`
|
PeerKeyID2AccountID map[string]string `json:"-"`
|
||||||
UserId2AccountId map[string]string `json:"-"`
|
UserID2AccountID map[string]string `json:"-"`
|
||||||
PrivateDomain2AccountId map[string]string `json:"-"`
|
PrivateDomain2AccountID map[string]string `json:"-"`
|
||||||
PeerKeyId2SrcRulesId map[string]map[string]struct{} `json:"-"`
|
|
||||||
PeerKeyId2DstRulesId map[string]map[string]struct{} `json:"-"`
|
|
||||||
PeerKeyID2RouteIDs map[string]map[string]struct{} `json:"-"`
|
|
||||||
AccountPrefix2RouteIDs map[string]map[string][]string `json:"-"`
|
|
||||||
InstallationID string
|
InstallationID string
|
||||||
|
|
||||||
// mutex to synchronise Store read/write operations
|
// mutex to synchronise Store read/write operations
|
||||||
@ -43,7 +36,7 @@ func NewStore(dataDir string) (*FileStore, error) {
|
|||||||
return restore(filepath.Join(dataDir, storeFileName))
|
return restore(filepath.Join(dataDir, storeFileName))
|
||||||
}
|
}
|
||||||
|
|
||||||
// restore restores the state of the store from the file.
|
// restore the state of the store from the file.
|
||||||
// Creates a new empty store file if doesn't exist
|
// Creates a new empty store file if doesn't exist
|
||||||
func restore(file string) (*FileStore, error) {
|
func restore(file string) (*FileStore, error) {
|
||||||
if _, err := os.Stat(file); os.IsNotExist(err) {
|
if _, err := os.Stat(file); os.IsNotExist(err) {
|
||||||
@ -51,14 +44,10 @@ func restore(file string) (*FileStore, error) {
|
|||||||
s := &FileStore{
|
s := &FileStore{
|
||||||
Accounts: make(map[string]*Account),
|
Accounts: make(map[string]*Account),
|
||||||
mux: sync.Mutex{},
|
mux: sync.Mutex{},
|
||||||
SetupKeyId2AccountId: make(map[string]string),
|
SetupKeyID2AccountID: make(map[string]string),
|
||||||
PeerKeyId2AccountId: make(map[string]string),
|
PeerKeyID2AccountID: make(map[string]string),
|
||||||
UserId2AccountId: make(map[string]string),
|
UserID2AccountID: make(map[string]string),
|
||||||
PrivateDomain2AccountId: make(map[string]string),
|
PrivateDomain2AccountID: make(map[string]string),
|
||||||
PeerKeyId2SrcRulesId: make(map[string]map[string]struct{}),
|
|
||||||
PeerKeyID2RouteIDs: make(map[string]map[string]struct{}),
|
|
||||||
PeerKeyId2DstRulesId: make(map[string]map[string]struct{}),
|
|
||||||
AccountPrefix2RouteIDs: make(map[string]map[string][]string),
|
|
||||||
storeFile: file,
|
storeFile: file,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -77,287 +66,78 @@ func restore(file string) (*FileStore, error) {
|
|||||||
|
|
||||||
store := read.(*FileStore)
|
store := read.(*FileStore)
|
||||||
store.storeFile = file
|
store.storeFile = file
|
||||||
store.SetupKeyId2AccountId = make(map[string]string)
|
store.SetupKeyID2AccountID = make(map[string]string)
|
||||||
store.PeerKeyId2AccountId = make(map[string]string)
|
store.PeerKeyID2AccountID = make(map[string]string)
|
||||||
store.UserId2AccountId = make(map[string]string)
|
store.UserID2AccountID = make(map[string]string)
|
||||||
store.PrivateDomain2AccountId = make(map[string]string)
|
store.PrivateDomain2AccountID = make(map[string]string)
|
||||||
store.PeerKeyId2SrcRulesId = make(map[string]map[string]struct{})
|
|
||||||
store.PeerKeyId2DstRulesId = make(map[string]map[string]struct{})
|
|
||||||
store.PeerKeyID2RouteIDs = make(map[string]map[string]struct{})
|
|
||||||
store.AccountPrefix2RouteIDs = make(map[string]map[string][]string)
|
|
||||||
|
|
||||||
for accountId, account := range store.Accounts {
|
for accountID, account := range store.Accounts {
|
||||||
for setupKeyId := range account.SetupKeys {
|
for setupKeyId := range account.SetupKeys {
|
||||||
store.SetupKeyId2AccountId[strings.ToUpper(setupKeyId)] = accountId
|
store.SetupKeyID2AccountID[strings.ToUpper(setupKeyId)] = accountID
|
||||||
}
|
|
||||||
for _, rule := range account.Rules {
|
|
||||||
for _, groupID := range rule.Source {
|
|
||||||
if group, ok := account.Groups[groupID]; ok {
|
|
||||||
for _, peerID := range group.Peers {
|
|
||||||
rules := store.PeerKeyId2SrcRulesId[peerID]
|
|
||||||
if rules == nil {
|
|
||||||
rules = map[string]struct{}{}
|
|
||||||
store.PeerKeyId2SrcRulesId[peerID] = rules
|
|
||||||
}
|
|
||||||
rules[rule.ID] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, groupID := range rule.Destination {
|
|
||||||
if group, ok := account.Groups[groupID]; ok {
|
|
||||||
for _, peerID := range group.Peers {
|
|
||||||
rules := store.PeerKeyId2DstRulesId[peerID]
|
|
||||||
if rules == nil {
|
|
||||||
rules = map[string]struct{}{}
|
|
||||||
store.PeerKeyId2DstRulesId[peerID] = rules
|
|
||||||
}
|
|
||||||
rules[rule.ID] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, peer := range account.Peers {
|
for _, peer := range account.Peers {
|
||||||
store.PeerKeyId2AccountId[peer.Key] = accountId
|
store.PeerKeyID2AccountID[peer.Key] = accountID
|
||||||
|
// reset all peers to status = Disconnected
|
||||||
|
if peer.Status != nil && peer.Status.Connected {
|
||||||
|
peer.Status.Connected = false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
for _, user := range account.Users {
|
for _, user := range account.Users {
|
||||||
store.UserId2AccountId[user.Id] = accountId
|
store.UserID2AccountID[user.Id] = accountID
|
||||||
}
|
}
|
||||||
for _, user := range account.Users {
|
for _, user := range account.Users {
|
||||||
store.UserId2AccountId[user.Id] = accountId
|
store.UserID2AccountID[user.Id] = accountID
|
||||||
}
|
|
||||||
for _, route := range account.Routes {
|
|
||||||
if route.Peer == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if store.PeerKeyID2RouteIDs[route.Peer] == nil {
|
|
||||||
store.PeerKeyID2RouteIDs[route.Peer] = make(map[string]struct{})
|
|
||||||
}
|
|
||||||
store.PeerKeyID2RouteIDs[route.Peer][route.ID] = struct{}{}
|
|
||||||
if store.AccountPrefix2RouteIDs[account.Id] == nil {
|
|
||||||
store.AccountPrefix2RouteIDs[account.Id] = make(map[string][]string)
|
|
||||||
}
|
|
||||||
if _, ok := store.AccountPrefix2RouteIDs[account.Id][route.Network.String()]; !ok {
|
|
||||||
store.AccountPrefix2RouteIDs[account.Id][route.Network.String()] = make([]string, 0)
|
|
||||||
}
|
|
||||||
store.AccountPrefix2RouteIDs[account.Id][route.Network.String()] = append(
|
|
||||||
store.AccountPrefix2RouteIDs[account.Id][route.Network.String()],
|
|
||||||
route.ID,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if account.Domain != "" && account.DomainCategory == PrivateCategory &&
|
if account.Domain != "" && account.DomainCategory == PrivateCategory &&
|
||||||
account.IsDomainPrimaryAccount {
|
account.IsDomainPrimaryAccount {
|
||||||
store.PrivateDomain2AccountId[account.Domain] = accountId
|
store.PrivateDomain2AccountID[account.Domain] = accountID
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// we need this persist to apply changes we made to account.Peers (we set them to Disconnected)
|
||||||
|
err = store.persist(store.storeFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return store, nil
|
return store, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// persist persists account data to a file
|
// persist account data to a file
|
||||||
// It is recommended to call it with locking FileStore.mux
|
// It is recommended to call it with locking FileStore.mux
|
||||||
func (s *FileStore) persist(file string) error {
|
func (s *FileStore) persist(file string) error {
|
||||||
return util.WriteJson(file, s)
|
return util.WriteJson(file, s)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SavePeer saves updated peer
|
|
||||||
func (s *FileStore) SavePeer(accountId string, peer *Peer) error {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
account, err := s.GetAccount(accountId)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// if it is new peer, add it to default 'All' group
|
|
||||||
allGroup, err := account.GetGroupAll()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
ind := -1
|
|
||||||
for i, pid := range allGroup.Peers {
|
|
||||||
if pid == peer.Key {
|
|
||||||
ind = i
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if ind < 0 {
|
|
||||||
allGroup.Peers = append(allGroup.Peers, peer.Key)
|
|
||||||
}
|
|
||||||
|
|
||||||
account.Peers[peer.Key] = peer
|
|
||||||
return s.persist(s.storeFile)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeletePeer deletes peer from the Store
|
|
||||||
func (s *FileStore) DeletePeer(accountId string, peerKey string) (*Peer, error) {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
account, err := s.GetAccount(accountId)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
peer := account.Peers[peerKey]
|
|
||||||
if peer == nil {
|
|
||||||
return nil, status.Errorf(codes.NotFound, "peer not found")
|
|
||||||
}
|
|
||||||
peerRoutes := s.PeerKeyID2RouteIDs[peerKey]
|
|
||||||
delete(account.Peers, peerKey)
|
|
||||||
delete(s.PeerKeyId2AccountId, peerKey)
|
|
||||||
delete(s.PeerKeyId2DstRulesId, peerKey)
|
|
||||||
delete(s.PeerKeyId2SrcRulesId, peerKey)
|
|
||||||
delete(s.PeerKeyID2RouteIDs, peerKey)
|
|
||||||
|
|
||||||
// cleanup groups
|
|
||||||
for _, g := range account.Groups {
|
|
||||||
var peers []string
|
|
||||||
for _, p := range g.Peers {
|
|
||||||
if p != peerKey {
|
|
||||||
peers = append(peers, p)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
g.Peers = peers
|
|
||||||
}
|
|
||||||
|
|
||||||
for routeID := range peerRoutes {
|
|
||||||
account.Routes[routeID].Enabled = false
|
|
||||||
account.Routes[routeID].Peer = ""
|
|
||||||
}
|
|
||||||
|
|
||||||
err = s.persist(s.storeFile)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return peer, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPeer returns a peer from a Store
|
|
||||||
func (s *FileStore) GetPeer(peerKey string) (*Peer, error) {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
accountId, accountIdFound := s.PeerKeyId2AccountId[peerKey]
|
|
||||||
if !accountIdFound {
|
|
||||||
return nil, status.Errorf(codes.NotFound, "peer not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
account, err := s.GetAccount(accountId)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if peer, ok := account.Peers[peerKey]; ok {
|
|
||||||
return peer, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, status.Errorf(codes.NotFound, "peer not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
// SaveAccount updates an existing account or adds a new one
|
// SaveAccount updates an existing account or adds a new one
|
||||||
func (s *FileStore) SaveAccount(account *Account) error {
|
func (s *FileStore) SaveAccount(account *Account) error {
|
||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
defer s.mux.Unlock()
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
accountCopy := account.Copy()
|
||||||
|
|
||||||
// todo will override, handle existing keys
|
// todo will override, handle existing keys
|
||||||
s.Accounts[account.Id] = account
|
s.Accounts[accountCopy.Id] = accountCopy
|
||||||
|
|
||||||
// todo check that account.Id and keyId are not exist already
|
// todo check that account.Id and keyId are not exist already
|
||||||
// because if keyId exists for other accounts this can be bad
|
// because if keyId exists for other accounts this can be bad
|
||||||
for keyId := range account.SetupKeys {
|
for keyID := range accountCopy.SetupKeys {
|
||||||
s.SetupKeyId2AccountId[strings.ToUpper(keyId)] = account.Id
|
s.SetupKeyID2AccountID[strings.ToUpper(keyID)] = accountCopy.Id
|
||||||
}
|
}
|
||||||
|
|
||||||
// enforce peer to account index and delete peer to route indexes for rebuild
|
// enforce peer to account index and delete peer to route indexes for rebuild
|
||||||
for _, peer := range account.Peers {
|
for _, peer := range accountCopy.Peers {
|
||||||
s.PeerKeyId2AccountId[peer.Key] = account.Id
|
s.PeerKeyID2AccountID[peer.Key] = accountCopy.Id
|
||||||
delete(s.PeerKeyID2RouteIDs, peer.Key)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(s.AccountPrefix2RouteIDs, account.Id)
|
for _, user := range accountCopy.Users {
|
||||||
|
s.UserID2AccountID[user.Id] = accountCopy.Id
|
||||||
// remove all peers related to account from rules indexes
|
|
||||||
cleanIDs := make([]string, 0)
|
|
||||||
for key := range s.PeerKeyId2SrcRulesId {
|
|
||||||
if accountID, ok := s.PeerKeyId2AccountId[key]; ok && accountID == account.Id {
|
|
||||||
cleanIDs = append(cleanIDs, key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, key := range cleanIDs {
|
|
||||||
delete(s.PeerKeyId2SrcRulesId, key)
|
|
||||||
}
|
|
||||||
cleanIDs = cleanIDs[:0]
|
|
||||||
for key := range s.PeerKeyId2DstRulesId {
|
|
||||||
if accountID, ok := s.PeerKeyId2AccountId[key]; ok && accountID == account.Id {
|
|
||||||
cleanIDs = append(cleanIDs, key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, key := range cleanIDs {
|
|
||||||
delete(s.PeerKeyId2DstRulesId, key)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// rebuild rule indexes
|
if accountCopy.DomainCategory == PrivateCategory && accountCopy.IsDomainPrimaryAccount {
|
||||||
for _, rule := range account.Rules {
|
s.PrivateDomain2AccountID[accountCopy.Domain] = accountCopy.Id
|
||||||
for _, gid := range rule.Source {
|
|
||||||
g, ok := account.Groups[gid]
|
|
||||||
if !ok {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
for _, pid := range g.Peers {
|
|
||||||
rules := s.PeerKeyId2SrcRulesId[pid]
|
|
||||||
if rules == nil {
|
|
||||||
rules = map[string]struct{}{}
|
|
||||||
s.PeerKeyId2SrcRulesId[pid] = rules
|
|
||||||
}
|
|
||||||
rules[rule.ID] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, gid := range rule.Destination {
|
|
||||||
g, ok := account.Groups[gid]
|
|
||||||
if !ok {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
for _, pid := range g.Peers {
|
|
||||||
rules := s.PeerKeyId2DstRulesId[pid]
|
|
||||||
if rules == nil {
|
|
||||||
rules = map[string]struct{}{}
|
|
||||||
s.PeerKeyId2DstRulesId[pid] = rules
|
|
||||||
}
|
|
||||||
rules[rule.ID] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, route := range account.Routes {
|
|
||||||
if route.Peer == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if s.PeerKeyID2RouteIDs[route.Peer] == nil {
|
|
||||||
s.PeerKeyID2RouteIDs[route.Peer] = make(map[string]struct{})
|
|
||||||
}
|
|
||||||
s.PeerKeyID2RouteIDs[route.Peer][route.ID] = struct{}{}
|
|
||||||
if s.AccountPrefix2RouteIDs[account.Id] == nil {
|
|
||||||
s.AccountPrefix2RouteIDs[account.Id] = make(map[string][]string)
|
|
||||||
}
|
|
||||||
if _, ok := s.AccountPrefix2RouteIDs[account.Id][route.Network.String()]; !ok {
|
|
||||||
s.AccountPrefix2RouteIDs[account.Id][route.Network.String()] = make([]string, 0)
|
|
||||||
}
|
|
||||||
s.AccountPrefix2RouteIDs[account.Id][route.Network.String()] = append(
|
|
||||||
s.AccountPrefix2RouteIDs[account.Id][route.Network.String()],
|
|
||||||
route.ID,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, user := range account.Users {
|
|
||||||
s.UserId2AccountId[user.Id] = account.Id
|
|
||||||
}
|
|
||||||
|
|
||||||
if account.DomainCategory == PrivateCategory && account.IsDomainPrimaryAccount {
|
|
||||||
s.PrivateDomain2AccountId[account.Domain] = account.Id
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.persist(s.storeFile)
|
return s.persist(s.storeFile)
|
||||||
@ -365,53 +145,25 @@ func (s *FileStore) SaveAccount(account *Account) error {
|
|||||||
|
|
||||||
// GetAccountByPrivateDomain returns account by private domain
|
// GetAccountByPrivateDomain returns account by private domain
|
||||||
func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) {
|
func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) {
|
||||||
accountId, accountIdFound := s.PrivateDomain2AccountId[strings.ToLower(domain)]
|
accountID, accountIDFound := s.PrivateDomain2AccountID[strings.ToLower(domain)]
|
||||||
if !accountIdFound {
|
if !accountIDFound {
|
||||||
return nil, status.Errorf(
|
return nil, status.Errorf(
|
||||||
codes.NotFound,
|
codes.NotFound,
|
||||||
"provided domain is not registered or is not private",
|
"account not found: provided domain is not registered or is not private",
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := s.GetAccount(accountId)
|
return s.GetAccount(accountID)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return account, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountBySetupKey returns account by setup key id
|
// GetAccountBySetupKey returns account by setup key id
|
||||||
func (s *FileStore) GetAccountBySetupKey(setupKey string) (*Account, error) {
|
func (s *FileStore) GetAccountBySetupKey(setupKey string) (*Account, error) {
|
||||||
accountId, accountIdFound := s.SetupKeyId2AccountId[strings.ToUpper(setupKey)]
|
accountID, accountIDFound := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
|
||||||
if !accountIdFound {
|
if !accountIDFound {
|
||||||
return nil, status.Errorf(codes.NotFound, "provided setup key doesn't exists")
|
return nil, status.Errorf(codes.NotFound, "account not found: provided setup key doesn't exists")
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := s.GetAccount(accountId)
|
return s.GetAccount(accountID)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return account, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAccountPeers returns account peers
|
|
||||||
func (s *FileStore) GetAccountPeers(accountId string) ([]*Peer, error) {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
account, err := s.GetAccount(accountId)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var peers []*Peer
|
|
||||||
for _, peer := range account.Peers {
|
|
||||||
peers = append(peers, peer)
|
|
||||||
}
|
|
||||||
|
|
||||||
return peers, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllAccounts returns all accounts
|
// GetAllAccounts returns all accounts
|
||||||
@ -426,148 +178,39 @@ func (s *FileStore) GetAllAccounts() (all []*Account) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetAccount returns an account for id
|
// GetAccount returns an account for id
|
||||||
func (s *FileStore) GetAccount(accountId string) (*Account, error) {
|
func (s *FileStore) GetAccount(accountID string) (*Account, error) {
|
||||||
account, accountFound := s.Accounts[accountId]
|
account, accountFound := s.Accounts[accountID]
|
||||||
if !accountFound {
|
if !accountFound {
|
||||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
return account, nil
|
return account.Copy(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUserAccount returns a user account
|
// GetAccountByUser returns a user account
|
||||||
func (s *FileStore) GetUserAccount(userId string) (*Account, error) {
|
func (s *FileStore) GetAccountByUser(userID string) (*Account, error) {
|
||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
defer s.mux.Unlock()
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
accountId, accountIdFound := s.UserId2AccountId[userId]
|
accountID, accountIDFound := s.UserID2AccountID[userID]
|
||||||
if !accountIdFound {
|
if !accountIDFound {
|
||||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.GetAccount(accountId)
|
return s.GetAccount(accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *FileStore) getPeerAccount(peerKey string) (*Account, error) {
|
// GetAccountByPeerPubKey returns an account for a given peer WireGuard public key
|
||||||
accountId, accountIdFound := s.PeerKeyId2AccountId[peerKey]
|
func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
|
||||||
if !accountIdFound {
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
|
accountID, accountIDFound := s.PeerKeyID2AccountID[peerKey]
|
||||||
|
if !accountIDFound {
|
||||||
return nil, status.Errorf(codes.NotFound, "Provided peer key doesn't exists %s", peerKey)
|
return nil, status.Errorf(codes.NotFound, "Provided peer key doesn't exists %s", peerKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.GetAccount(accountId)
|
return s.GetAccount(accountID)
|
||||||
}
|
|
||||||
|
|
||||||
// GetPeerAccount returns user account if exists
|
|
||||||
func (s *FileStore) GetPeerAccount(peerKey string) (*Account, error) {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
return s.getPeerAccount(peerKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPeerSrcRules return list of source rules for peer
|
|
||||||
func (s *FileStore) GetPeerSrcRules(accountId, peerKey string) ([]*Rule, error) {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
account, err := s.GetAccount(accountId)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ruleIDs, ok := s.PeerKeyId2SrcRulesId[peerKey]
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("no rules for peer: %v", ruleIDs)
|
|
||||||
}
|
|
||||||
|
|
||||||
rules := []*Rule{}
|
|
||||||
for id := range ruleIDs {
|
|
||||||
rule, ok := account.Rules[id]
|
|
||||||
if ok {
|
|
||||||
rules = append(rules, rule)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return rules, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPeerDstRules return list of destination rules for peer
|
|
||||||
func (s *FileStore) GetPeerDstRules(accountId, peerKey string) ([]*Rule, error) {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
account, err := s.GetAccount(accountId)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ruleIDs, ok := s.PeerKeyId2DstRulesId[peerKey]
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("no rules for peer: %v", ruleIDs)
|
|
||||||
}
|
|
||||||
|
|
||||||
rules := []*Rule{}
|
|
||||||
for id := range ruleIDs {
|
|
||||||
rule, ok := account.Rules[id]
|
|
||||||
if ok {
|
|
||||||
rules = append(rules, rule)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return rules, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPeerRoutes return list of routes for peer
|
|
||||||
func (s *FileStore) GetPeerRoutes(peerKey string) ([]*route.Route, error) {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
account, err := s.getPeerAccount(peerKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var routes []*route.Route
|
|
||||||
|
|
||||||
routeIDs, ok := s.PeerKeyID2RouteIDs[peerKey]
|
|
||||||
if !ok {
|
|
||||||
return routes, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
for id := range routeIDs {
|
|
||||||
route, found := account.Routes[id]
|
|
||||||
if found {
|
|
||||||
routes = append(routes, route)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return routes, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetRoutesByPrefix return list of routes by account and route prefix
|
|
||||||
func (s *FileStore) GetRoutesByPrefix(accountID string, prefix netip.Prefix) ([]*route.Route, error) {
|
|
||||||
s.mux.Lock()
|
|
||||||
defer s.mux.Unlock()
|
|
||||||
|
|
||||||
account, err := s.GetAccount(accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
routeIDs, ok := s.AccountPrefix2RouteIDs[accountID][prefix.String()]
|
|
||||||
if !ok {
|
|
||||||
return nil, status.Errorf(codes.NotFound, "no routes for prefix: %v", prefix.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
var routes []*route.Route
|
|
||||||
for _, id := range routeIDs {
|
|
||||||
route, found := account.Routes[id]
|
|
||||||
if found {
|
|
||||||
routes = append(routes, route)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return routes, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetInstallationID returns the installation ID from the store
|
// GetInstallationID returns the installation ID from the store
|
||||||
|
@ -2,6 +2,7 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"net"
|
"net"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@ -9,6 +10,10 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type accounts struct {
|
||||||
|
Accounts map[string]*Account
|
||||||
|
}
|
||||||
|
|
||||||
func TestNewStore(t *testing.T) {
|
func TestNewStore(t *testing.T) {
|
||||||
store := newStore(t)
|
store := newStore(t)
|
||||||
|
|
||||||
@ -16,16 +21,16 @@ func TestNewStore(t *testing.T) {
|
|||||||
t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
|
t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
|
||||||
}
|
}
|
||||||
|
|
||||||
if store.SetupKeyId2AccountId == nil || len(store.SetupKeyId2AccountId) != 0 {
|
if store.SetupKeyID2AccountID == nil || len(store.SetupKeyID2AccountID) != 0 {
|
||||||
t.Errorf("expected to create a new empty SetupKeyId2AccountId map when creating a new FileStore")
|
t.Errorf("expected to create a new empty SetupKeyID2AccountID map when creating a new FileStore")
|
||||||
}
|
}
|
||||||
|
|
||||||
if store.PeerKeyId2AccountId == nil || len(store.PeerKeyId2AccountId) != 0 {
|
if store.PeerKeyID2AccountID == nil || len(store.PeerKeyID2AccountID) != 0 {
|
||||||
t.Errorf("expected to create a new empty PeerKeyId2AccountId map when creating a new FileStore")
|
t.Errorf("expected to create a new empty PeerKeyID2AccountID map when creating a new FileStore")
|
||||||
}
|
}
|
||||||
|
|
||||||
if store.UserId2AccountId == nil || len(store.UserId2AccountId) != 0 {
|
if store.UserID2AccountID == nil || len(store.UserID2AccountID) != 0 {
|
||||||
t.Errorf("expected to create a new empty UserId2AccountId map when creating a new FileStore")
|
t.Errorf("expected to create a new empty UserID2AccountID map when creating a new FileStore")
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@ -55,16 +60,16 @@ func TestSaveAccount(t *testing.T) {
|
|||||||
t.Errorf("expecting Account to be stored after SaveAccount()")
|
t.Errorf("expecting Account to be stored after SaveAccount()")
|
||||||
}
|
}
|
||||||
|
|
||||||
if store.PeerKeyId2AccountId["peerkey"] == "" {
|
if store.PeerKeyID2AccountID["peerkey"] == "" {
|
||||||
t.Errorf("expecting PeerKeyId2AccountId index updated after SaveAccount()")
|
t.Errorf("expecting PeerKeyID2AccountID index updated after SaveAccount()")
|
||||||
}
|
}
|
||||||
|
|
||||||
if store.UserId2AccountId["testuser"] == "" {
|
if store.UserID2AccountID["testuser"] == "" {
|
||||||
t.Errorf("expecting UserId2AccountId index updated after SaveAccount()")
|
t.Errorf("expecting UserID2AccountID index updated after SaveAccount()")
|
||||||
}
|
}
|
||||||
|
|
||||||
if store.SetupKeyId2AccountId[setupKey.Key] == "" {
|
if store.SetupKeyID2AccountID[setupKey.Key] == "" {
|
||||||
t.Errorf("expecting SetupKeyId2AccountId index updated after SaveAccount()")
|
t.Errorf("expecting SetupKeyID2AccountID index updated after SaveAccount()")
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@ -141,11 +146,11 @@ func TestRestore(t *testing.T) {
|
|||||||
|
|
||||||
require.NotNil(t, account.SetupKeys["A2C8E62B-38F5-4553-B31E-DD66C696CEBB"], "failed to restore a FileStore file - missing Account SetupKey A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
|
require.NotNil(t, account.SetupKeys["A2C8E62B-38F5-4553-B31E-DD66C696CEBB"], "failed to restore a FileStore file - missing Account SetupKey A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
|
||||||
|
|
||||||
require.Len(t, store.UserId2AccountId, 2, "failed to restore a FileStore wrong UserId2AccountId mapping length")
|
require.Len(t, store.UserID2AccountID, 2, "failed to restore a FileStore wrong UserID2AccountID mapping length")
|
||||||
|
|
||||||
require.Len(t, store.SetupKeyId2AccountId, 1, "failed to restore a FileStore wrong SetupKeyId2AccountId mapping length")
|
require.Len(t, store.SetupKeyID2AccountID, 1, "failed to restore a FileStore wrong SetupKeyID2AccountID mapping length")
|
||||||
|
|
||||||
require.Len(t, store.PrivateDomain2AccountId, 1, "failed to restore a FileStore wrong PrivateDomain2AccountId mapping length")
|
require.Len(t, store.PrivateDomain2AccountID, 1, "failed to restore a FileStore wrong PrivateDomain2AccountID mapping length")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetAccountByPrivateDomain(t *testing.T) {
|
func TestGetAccountByPrivateDomain(t *testing.T) {
|
||||||
@ -171,6 +176,48 @@ func TestGetAccountByPrivateDomain(t *testing.T) {
|
|||||||
require.Error(t, err, "should return error on domain lookup")
|
require.Error(t, err, "should return error on domain lookup")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFileStore_GetAccount(t *testing.T) {
|
||||||
|
storeDir := t.TempDir()
|
||||||
|
storeFile := filepath.Join(storeDir, "store.json")
|
||||||
|
err := util.CopyFileContents("testdata/store.json", storeFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
accounts := &accounts{}
|
||||||
|
_, err = util.ReadJson(storeFile, accounts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
store, err := NewStore(storeDir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"]
|
||||||
|
if expected == nil {
|
||||||
|
t.Fatalf("expected account doesn't exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
account, err := store.GetAccount(expected.Id)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, expected.IsDomainPrimaryAccount, account.IsDomainPrimaryAccount)
|
||||||
|
assert.Equal(t, expected.DomainCategory, account.DomainCategory)
|
||||||
|
assert.Equal(t, expected.Domain, account.Domain)
|
||||||
|
assert.Equal(t, expected.CreatedBy, account.CreatedBy)
|
||||||
|
assert.Equal(t, expected.Network.Id, account.Network.Id)
|
||||||
|
assert.Len(t, account.Peers, len(expected.Peers))
|
||||||
|
assert.Len(t, account.Users, len(expected.Users))
|
||||||
|
assert.Len(t, account.SetupKeys, len(expected.SetupKeys))
|
||||||
|
assert.Len(t, account.Rules, len(expected.Rules))
|
||||||
|
assert.Len(t, account.Routes, len(expected.Routes))
|
||||||
|
assert.Len(t, account.NameServerGroups, len(expected.NameServerGroups))
|
||||||
|
}
|
||||||
|
|
||||||
func newStore(t *testing.T) *FileStore {
|
func newStore(t *testing.T) *FileStore {
|
||||||
store, err := NewStore(t.TempDir())
|
store, err := NewStore(t.TempDir())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -21,8 +21,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var TestPeers = map[string]*server.Peer{
|
var TestPeers = map[string]*server.Peer{
|
||||||
"A": &server.Peer{Key: "A", IP: net.ParseIP("100.100.100.100")},
|
"A": {Key: "A", IP: net.ParseIP("100.100.100.100")},
|
||||||
"B": &server.Peer{Key: "B", IP: net.ParseIP("200.200.200.200")},
|
"B": {Key: "B", IP: net.ParseIP("200.200.200.200")},
|
||||||
}
|
}
|
||||||
|
|
||||||
func initGroupTestData(user *server.User, groups ...*server.Group) *Groups {
|
func initGroupTestData(user *server.User, groups ...*server.Group) *Groups {
|
||||||
|
@ -112,7 +112,7 @@ func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
// UpdateRouteHandler handles update to a route identified by a given ID
|
// UpdateRouteHandler handles update to a route identified by a given ID
|
||||||
func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
|
func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
@ -125,7 +125,7 @@ func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = h.accountManager.GetRoute(account.Id, routeID, "")
|
_, err = h.accountManager.GetRoute(account.Id, routeID, user.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, fmt.Sprintf("couldn't find route for ID %s", routeID), http.StatusNotFound)
|
http.Error(w, fmt.Sprintf("couldn't find route for ID %s", routeID), http.StatusNotFound)
|
||||||
return
|
return
|
||||||
@ -185,7 +185,7 @@ func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
// PatchRouteHandler handles patch updates to a route identified by a given ID
|
// PatchRouteHandler handles patch updates to a route identified by a given ID
|
||||||
func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
|
func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
@ -198,7 +198,7 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = h.accountManager.GetRoute(account.Id, routeID, "")
|
_, err = h.accountManager.GetRoute(account.Id, routeID, user.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
http.Error(w, fmt.Sprintf("couldn't find route ID %s", routeID), http.StatusNotFound)
|
http.Error(w, fmt.Sprintf("couldn't find route ID %s", routeID), http.StatusNotFound)
|
||||||
|
@ -1,13 +0,0 @@
|
|||||||
## Migration from Store v2 to Store v2
|
|
||||||
|
|
||||||
Previously Account.Id was an Auth0 user id.
|
|
||||||
Conversion moves user id to Account.CreatedBy and generates a new Account.Id using xid.
|
|
||||||
It also adds a User with id = old Account.Id with a role Admin.
|
|
||||||
|
|
||||||
To start a conversion simply run the command below providing your current Wiretrustee Management datadir (where store.json file is located)
|
|
||||||
and a new data directory location (where a converted store.js will be stored):
|
|
||||||
```shell
|
|
||||||
./migration --oldDir /var/wiretrustee/datadir --newDir /var/wiretrustee/newdatadir/
|
|
||||||
```
|
|
||||||
|
|
||||||
Afterwards you can run the Management service providing ```/var/wiretrustee/newdatadir/ ``` as a datadir.
|
|
@ -1,56 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"flag"
|
|
||||||
"fmt"
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
|
||||||
"github.com/rs/xid"
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
|
|
||||||
oldDir := flag.String("oldDir", "old store directory", "/var/wiretrustee/datadir")
|
|
||||||
newDir := flag.String("newDir", "new store directory", "/var/wiretrustee/newdatadir")
|
|
||||||
|
|
||||||
flag.Parse()
|
|
||||||
|
|
||||||
oldStore, err := server.NewStore(*oldDir)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
newStore, err := server.NewStore(*newDir)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = Convert(oldStore, newStore)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Println("successfully converted")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert converts old store ato a new store
|
|
||||||
// Previously Account.Id was an Auth0 user id
|
|
||||||
// Conversion moved user id to Account.CreatedBy and generated a new Account.Id using xid
|
|
||||||
// It also adds a User with id = old Account.Id with a role Admin
|
|
||||||
func Convert(oldStore *server.FileStore, newStore *server.FileStore) error {
|
|
||||||
for _, account := range oldStore.Accounts {
|
|
||||||
accountCopy := account.Copy()
|
|
||||||
accountCopy.Id = xid.New().String()
|
|
||||||
accountCopy.CreatedBy = account.Id
|
|
||||||
accountCopy.Users[account.Id] = &server.User{
|
|
||||||
Id: account.Id,
|
|
||||||
Role: server.UserRoleAdmin,
|
|
||||||
}
|
|
||||||
|
|
||||||
err := newStore.SaveAccount(accountCopy)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
@ -1,76 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
|
||||||
"github.com/netbirdio/netbird/util"
|
|
||||||
"path/filepath"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestConvertAccounts(t *testing.T) {
|
|
||||||
|
|
||||||
storeDir := t.TempDir()
|
|
||||||
|
|
||||||
err := util.CopyFileContents("../testdata/storev1.json", filepath.Join(storeDir, "store.json"))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
store, err := server.NewStore(storeDir)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
convertedStore, err := server.NewStore(filepath.Join(storeDir, "converted"))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = Convert(store, convertedStore)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(store.Accounts) != len(convertedStore.Accounts) {
|
|
||||||
t.Errorf("expecting the same number of accounts after conversion")
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, account := range store.Accounts {
|
|
||||||
convertedAccount, err := convertedStore.GetUserAccount(account.Id)
|
|
||||||
if err != nil || convertedAccount == nil {
|
|
||||||
t.Errorf("expecting Account %s to be converted", account.Id)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if convertedAccount.CreatedBy != account.Id {
|
|
||||||
t.Errorf("expecting converted Account.CreatedBy field to be equal to the old Account.Id")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if convertedAccount.Id == account.Id {
|
|
||||||
t.Errorf("expecting converted Account.Id to be different from Account.Id")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if len(convertedAccount.Users) != 1 {
|
|
||||||
t.Errorf("expecting converted Account.Users to be of size 1")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
user := convertedAccount.Users[account.Id]
|
|
||||||
if user == nil {
|
|
||||||
t.Errorf("expecting to find a user in converted Account.Users")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if user.Role != server.UserRoleAdmin {
|
|
||||||
t.Errorf("expecting to find a user in converted Account.Users with a role Admin")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for peerId := range account.Peers {
|
|
||||||
convertedPeer := convertedAccount.Peers[peerId]
|
|
||||||
if convertedPeer == nil {
|
|
||||||
t.Errorf("expecting Account Peer of StoreV1 to be found in StoreV2")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
@ -22,7 +22,6 @@ type MockAccountManager struct {
|
|||||||
GetPeerFunc func(peerKey string) (*server.Peer, error)
|
GetPeerFunc func(peerKey string) (*server.Peer, error)
|
||||||
GetPeersFunc func(accountID, userID string) ([]*server.Peer, error)
|
GetPeersFunc func(accountID, userID string) ([]*server.Peer, error)
|
||||||
MarkPeerConnectedFunc func(peerKey string, connected bool) error
|
MarkPeerConnectedFunc func(peerKey string, connected bool) error
|
||||||
RenamePeerFunc func(accountId string, peerKey string, newName string) (*server.Peer, error)
|
|
||||||
DeletePeerFunc func(accountId string, peerKey string) (*server.Peer, error)
|
DeletePeerFunc func(accountId string, peerKey string) (*server.Peer, error)
|
||||||
GetPeerByIPFunc func(accountId string, peerIP string) (*server.Peer, error)
|
GetPeerByIPFunc func(accountId string, peerIP string) (*server.Peer, error)
|
||||||
GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error)
|
GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error)
|
||||||
@ -72,6 +71,14 @@ func (am *MockAccountManager) GetUsersFromAccount(accountID string, userID strin
|
|||||||
return nil, status.Errorf(codes.Unimplemented, "method GetUsersFromAccount is not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method GetUsersFromAccount is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeletePeer mock implementation of DeletePeer from server.AccountManager interface
|
||||||
|
func (am *MockAccountManager) DeletePeer(accountId string, peerKey string) (*server.Peer, error) {
|
||||||
|
if am.DeletePeerFunc != nil {
|
||||||
|
return am.DeletePeerFunc(accountId, peerKey)
|
||||||
|
}
|
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method DeletePeer is not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
// GetOrCreateAccountByUser mock implementation of GetOrCreateAccountByUser from server.AccountManager interface
|
// GetOrCreateAccountByUser mock implementation of GetOrCreateAccountByUser from server.AccountManager interface
|
||||||
func (am *MockAccountManager) GetOrCreateAccountByUser(
|
func (am *MockAccountManager) GetOrCreateAccountByUser(
|
||||||
userId, domain string,
|
userId, domain string,
|
||||||
@ -152,26 +159,6 @@ func (am *MockAccountManager) MarkPeerConnected(peerKey string, connected bool)
|
|||||||
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// RenamePeer mock implementation of RenamePeer from server.AccountManager interface
|
|
||||||
func (am *MockAccountManager) RenamePeer(
|
|
||||||
accountId string,
|
|
||||||
peerKey string,
|
|
||||||
newName string,
|
|
||||||
) (*server.Peer, error) {
|
|
||||||
if am.RenamePeerFunc != nil {
|
|
||||||
return am.RenamePeerFunc(accountId, peerKey, newName)
|
|
||||||
}
|
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method RenamePeer is not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeletePeer mock implementation of DeletePeer from server.AccountManager interface
|
|
||||||
func (am *MockAccountManager) DeletePeer(accountId string, peerKey string) (*server.Peer, error) {
|
|
||||||
if am.DeletePeerFunc != nil {
|
|
||||||
return am.DeletePeerFunc(accountId, peerKey)
|
|
||||||
}
|
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method DeletePeer is not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPeerByIP mock implementation of GetPeerByIP from server.AccountManager interface
|
// GetPeerByIP mock implementation of GetPeerByIP from server.AccountManager interface
|
||||||
func (am *MockAccountManager) GetPeerByIP(accountId string, peerIP string) (*server.Peer, error) {
|
func (am *MockAccountManager) GetPeerByIP(accountId string, peerIP string) (*server.Peer, error) {
|
||||||
if am.GetPeerByIPFunc != nil {
|
if am.GetPeerByIPFunc != nil {
|
||||||
|
@ -635,6 +635,11 @@ func TestSaveNameServerGroup(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
account, err = am.GetAccountById(account.Id)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
savedNSGroup, saved := account.NameServerGroups[testCase.expectedNSGroup.ID]
|
savedNSGroup, saved := account.NameServerGroups[testCase.expectedNSGroup.ID]
|
||||||
require.True(t, saved)
|
require.True(t, saved)
|
||||||
|
|
||||||
|
@ -68,17 +68,17 @@ func (p *Peer) Copy() *Peer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPeer returns a peer from a Store
|
// GetPeer looks up peer by its public WireGuard key
|
||||||
func (am *DefaultAccountManager) GetPeer(peerKey string) (*Peer, error) {
|
func (am *DefaultAccountManager) GetPeer(peerPubKey string) (*Peer, error) {
|
||||||
am.mux.Lock()
|
am.mux.Lock()
|
||||||
defer am.mux.Unlock()
|
defer am.mux.Unlock()
|
||||||
|
|
||||||
peer, err := am.Store.GetPeer(peerKey)
|
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return peer, nil
|
return account.FindPeerByPubKey(peerPubKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if
|
// GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if
|
||||||
@ -109,24 +109,26 @@ func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*Peer, er
|
|||||||
}
|
}
|
||||||
|
|
||||||
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
|
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
|
||||||
func (am *DefaultAccountManager) MarkPeerConnected(peerKey string, connected bool) error {
|
func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected bool) error {
|
||||||
am.mux.Lock()
|
am.mux.Lock()
|
||||||
defer am.mux.Unlock()
|
defer am.mux.Unlock()
|
||||||
|
|
||||||
peer, err := am.Store.GetPeer(peerKey)
|
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := am.Store.GetPeerAccount(peerKey)
|
peer, err := account.FindPeerByPubKey(peerPubKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
peerCopy := peer.Copy()
|
peer.Status.LastSeen = time.Now()
|
||||||
peerCopy.Status.LastSeen = time.Now()
|
peer.Status.Connected = connected
|
||||||
peerCopy.Status.Connected = connected
|
|
||||||
err = am.Store.SavePeer(account.Id, peerCopy)
|
account.UpdatePeer(peer)
|
||||||
|
|
||||||
|
err = am.Store.SaveAccount(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -143,18 +145,20 @@ func (am *DefaultAccountManager) UpdatePeer(accountID string, update *Peer) (*Pe
|
|||||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
peer, err := am.Store.GetPeer(update.Key)
|
//TODO Peer.ID migration: we will need to replace search by ID here
|
||||||
|
peer, err := account.FindPeerByPubKey(update.Key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
peerCopy := peer.Copy()
|
|
||||||
if peer.Name != "" {
|
if peer.Name != "" {
|
||||||
peerCopy.Name = update.Name
|
peer.Name = update.Name
|
||||||
}
|
}
|
||||||
peerCopy.SSHEnabled = update.SSHEnabled
|
peer.SSHEnabled = update.SSHEnabled
|
||||||
|
|
||||||
err = am.Store.SavePeer(accountID, peerCopy)
|
account.UpdatePeer(peer)
|
||||||
|
|
||||||
|
err = am.Store.SaveAccount(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -164,66 +168,32 @@ func (am *DefaultAccountManager) UpdatePeer(accountID string, update *Peer) (*Pe
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return peerCopy, nil
|
return peer, nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RenamePeer changes peer's name
|
// DeletePeer removes peer from the account by its IP
|
||||||
func (am *DefaultAccountManager) RenamePeer(
|
func (am *DefaultAccountManager) DeletePeer(accountID string, peerPubKey string) (*Peer, error) {
|
||||||
accountId string,
|
|
||||||
peerKey string,
|
|
||||||
newName string,
|
|
||||||
) (*Peer, error) {
|
|
||||||
am.mux.Lock()
|
am.mux.Lock()
|
||||||
defer am.mux.Unlock()
|
defer am.mux.Unlock()
|
||||||
|
|
||||||
peer, err := am.Store.GetPeer(peerKey)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
peerCopy := peer.Copy()
|
|
||||||
peerCopy.Name = newName
|
|
||||||
err = am.Store.SavePeer(accountId, peerCopy)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return peerCopy, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeletePeer removes peer from the account by it's IP
|
|
||||||
func (am *DefaultAccountManager) DeletePeer(accountId string, peerKey string) (*Peer, error) {
|
|
||||||
am.mux.Lock()
|
|
||||||
defer am.mux.Unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountId)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
// delete peer from groups
|
peer, err := account.FindPeerByPubKey(peerPubKey)
|
||||||
for _, g := range account.Groups {
|
|
||||||
for i, pk := range g.Peers {
|
|
||||||
if pk == peerKey {
|
|
||||||
g.Peers = append(g.Peers[:i], g.Peers[i+1:]...)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
peer, err := am.Store.DeletePeer(accountId, peerKey)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
account.Network.IncSerial()
|
account.DeletePeer(peerPubKey)
|
||||||
|
|
||||||
err = am.Store.SaveAccount(account)
|
err = am.Store.SaveAccount(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = am.peersUpdateManager.SendUpdate(peerKey,
|
err = am.peersUpdateManager.SendUpdate(peerPubKey,
|
||||||
&UpdateMessage{
|
&UpdateMessage{
|
||||||
Update: &proto.SyncResponse{
|
Update: &proto.SyncResponse{
|
||||||
// fill those field for backward compatibility
|
// fill those field for backward compatibility
|
||||||
@ -241,20 +211,21 @@ func (am *DefaultAccountManager) DeletePeer(accountId string, peerKey string) (*
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO Peer.ID migration: we will need to replace search by Peer.ID here
|
||||||
if err := am.updateAccountPeers(account); err != nil {
|
if err := am.updateAccountPeers(account); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
am.peersUpdateManager.CloseChannel(peerKey)
|
am.peersUpdateManager.CloseChannel(peerPubKey)
|
||||||
return peer, nil
|
return peer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPeerByIP returns peer by it's IP
|
// GetPeerByIP returns peer by its IP
|
||||||
func (am *DefaultAccountManager) GetPeerByIP(accountId string, peerIP string) (*Peer, error) {
|
func (am *DefaultAccountManager) GetPeerByIP(accountID string, peerIP string) (*Peer, error) {
|
||||||
am.mux.Lock()
|
am.mux.Lock()
|
||||||
defer am.mux.Unlock()
|
defer am.mux.Unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(accountId)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.NotFound, "account not found")
|
return nil, status.Errorf(codes.NotFound, "account not found")
|
||||||
}
|
}
|
||||||
@ -269,17 +240,17 @@ func (am *DefaultAccountManager) GetPeerByIP(accountId string, peerIP string) (*
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)
|
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)
|
||||||
func (am *DefaultAccountManager) GetNetworkMap(peerKey string) (*NetworkMap, error) {
|
func (am *DefaultAccountManager) GetNetworkMap(peerPubKey string) (*NetworkMap, error) {
|
||||||
am.mux.Lock()
|
am.mux.Lock()
|
||||||
defer am.mux.Unlock()
|
defer am.mux.Unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetPeerAccount(peerKey)
|
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.Internal, "Invalid peer key %s", peerKey)
|
return nil, status.Errorf(codes.Internal, "Invalid peer key %s", peerPubKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
aclPeers := am.getPeersByACL(account, peerKey)
|
aclPeers := am.getPeersByACL(account, peerPubKey)
|
||||||
routesUpdate := am.getPeersRoutes(append(aclPeers, account.Peers[peerKey]))
|
routesUpdate := account.GetPeersRoutes(append(aclPeers, account.Peers[peerPubKey]))
|
||||||
|
|
||||||
return &NetworkMap{
|
return &NetworkMap{
|
||||||
Peers: aclPeers,
|
Peers: aclPeers,
|
||||||
@ -289,13 +260,13 @@ func (am *DefaultAccountManager) GetNetworkMap(peerKey string) (*NetworkMap, err
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetPeerNetwork returns the Network for a given peer
|
// GetPeerNetwork returns the Network for a given peer
|
||||||
func (am *DefaultAccountManager) GetPeerNetwork(peerKey string) (*Network, error) {
|
func (am *DefaultAccountManager) GetPeerNetwork(peerPubKey string) (*Network, error) {
|
||||||
am.mux.Lock()
|
am.mux.Lock()
|
||||||
defer am.mux.Unlock()
|
defer am.mux.Unlock()
|
||||||
|
|
||||||
account, err := am.Store.GetPeerAccount(peerKey)
|
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.Internal, "Invalid peer key %s", peerKey)
|
return nil, status.Errorf(codes.Internal, "invalid peer key %s", peerPubKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
return account.Network.Copy(), err
|
return account.Network.Copy(), err
|
||||||
@ -308,11 +279,7 @@ func (am *DefaultAccountManager) GetPeerNetwork(peerKey string) (*Network, error
|
|||||||
// to it. We also add the User ID to the peer metadata to identify registrant.
|
// to it. We also add the User ID to the peer metadata to identify registrant.
|
||||||
// Each new Peer will be assigned a new next net.IP from the Account.Network and Account.Network.LastIP will be updated (IP's are not reused).
|
// Each new Peer will be assigned a new next net.IP from the Account.Network and Account.Network.LastIP will be updated (IP's are not reused).
|
||||||
// The peer property is just a placeholder for the Peer properties to pass further
|
// The peer property is just a placeholder for the Peer properties to pass further
|
||||||
func (am *DefaultAccountManager) AddPeer(
|
func (am *DefaultAccountManager) AddPeer(setupKey string, userID string, peer *Peer) (*Peer, error) {
|
||||||
setupKey string,
|
|
||||||
userID string,
|
|
||||||
peer *Peer,
|
|
||||||
) (*Peer, error) {
|
|
||||||
am.mux.Lock()
|
am.mux.Lock()
|
||||||
defer am.mux.Unlock()
|
defer am.mux.Unlock()
|
||||||
|
|
||||||
@ -353,7 +320,7 @@ func (am *DefaultAccountManager) AddPeer(
|
|||||||
groupsToAdd = sk.AutoGroups
|
groupsToAdd = sk.AutoGroups
|
||||||
|
|
||||||
} else if len(userID) != 0 {
|
} else if len(userID) != 0 {
|
||||||
account, err = am.Store.GetUserAccount(userID)
|
account, err = am.Store.GetAccountByUser(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.NotFound, "unable to register peer, unknown user with ID: %s", userID)
|
return nil, status.Errorf(codes.NotFound, "unable to register peer, unknown user with ID: %s", userID)
|
||||||
}
|
}
|
||||||
@ -422,34 +389,34 @@ func (am *DefaultAccountManager) AddPeer(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UpdatePeerSSHKey updates peer's public SSH key
|
// UpdatePeerSSHKey updates peer's public SSH key
|
||||||
func (am *DefaultAccountManager) UpdatePeerSSHKey(peerKey string, sshKey string) error {
|
func (am *DefaultAccountManager) UpdatePeerSSHKey(peerPubKey string, sshKey string) error {
|
||||||
am.mux.Lock()
|
am.mux.Lock()
|
||||||
defer am.mux.Unlock()
|
defer am.mux.Unlock()
|
||||||
|
|
||||||
if sshKey == "" {
|
if sshKey == "" {
|
||||||
log.Debugf("empty SSH key provided for peer %s, skipping update", peerKey)
|
log.Debugf("empty SSH key provided for peer %s, skipping update", peerPubKey)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
peer, err := am.Store.GetPeer(peerKey)
|
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
peer, err := account.FindPeerByPubKey(peerPubKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if peer.SSHKey == sshKey {
|
if peer.SSHKey == sshKey {
|
||||||
log.Debugf("same SSH key provided for peer %s, skipping update", peerKey)
|
log.Debugf("same SSH key provided for peer %s, skipping update", peerPubKey)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := am.Store.GetPeerAccount(peerKey)
|
peer.SSHKey = sshKey
|
||||||
if err != nil {
|
account.UpdatePeer(peer)
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
peerCopy := peer.Copy()
|
err = am.Store.SaveAccount(account)
|
||||||
peerCopy.SSHKey = sshKey
|
|
||||||
|
|
||||||
err = am.Store.SavePeer(account.Id, peerCopy)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -459,29 +426,29 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(peerKey string, sshKey string)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UpdatePeerMeta updates peer's system metadata
|
// UpdatePeerMeta updates peer's system metadata
|
||||||
func (am *DefaultAccountManager) UpdatePeerMeta(peerKey string, meta PeerSystemMeta) error {
|
func (am *DefaultAccountManager) UpdatePeerMeta(peerPubKey string, meta PeerSystemMeta) error {
|
||||||
am.mux.Lock()
|
am.mux.Lock()
|
||||||
defer am.mux.Unlock()
|
defer am.mux.Unlock()
|
||||||
|
|
||||||
peer, err := am.Store.GetPeer(peerKey)
|
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := am.Store.GetPeerAccount(peerKey)
|
peer, err := account.FindPeerByPubKey(peerPubKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
peerCopy := peer.Copy()
|
|
||||||
// Avoid overwriting UIVersion if the update was triggered sole by the CLI client
|
// Avoid overwriting UIVersion if the update was triggered sole by the CLI client
|
||||||
if meta.UIVersion == "" {
|
if meta.UIVersion == "" {
|
||||||
meta.UIVersion = peerCopy.Meta.UIVersion
|
meta.UIVersion = peer.Meta.UIVersion
|
||||||
}
|
}
|
||||||
|
|
||||||
peerCopy.Meta = meta
|
peer.Meta = meta
|
||||||
|
account.UpdatePeer(peer)
|
||||||
|
|
||||||
err = am.Store.SavePeer(account.Id, peerCopy)
|
err = am.Store.SaveAccount(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -489,17 +456,9 @@ func (am *DefaultAccountManager) UpdatePeerMeta(peerKey string, meta PeerSystemM
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getPeersByACL returns all peers that given peer has access to.
|
// getPeersByACL returns all peers that given peer has access to.
|
||||||
func (am *DefaultAccountManager) getPeersByACL(account *Account, peerKey string) []*Peer {
|
func (am *DefaultAccountManager) getPeersByACL(account *Account, peerPubKey string) []*Peer {
|
||||||
var peers []*Peer
|
var peers []*Peer
|
||||||
srcRules, err := am.Store.GetPeerSrcRules(account.Id, peerKey)
|
srcRules, dstRules := account.GetPeerRules(peerPubKey)
|
||||||
if err != nil {
|
|
||||||
srcRules = []*Rule{}
|
|
||||||
}
|
|
||||||
|
|
||||||
dstRules, err := am.Store.GetPeerDstRules(account.Id, peerKey)
|
|
||||||
if err != nil {
|
|
||||||
dstRules = []*Rule{}
|
|
||||||
}
|
|
||||||
|
|
||||||
groups := map[string]*Group{}
|
groups := map[string]*Group{}
|
||||||
for _, r := range srcRules {
|
for _, r := range srcRules {
|
||||||
@ -542,7 +501,7 @@ func (am *DefaultAccountManager) getPeersByACL(account *Account, peerKey string)
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// exclude original peer
|
// exclude original peer
|
||||||
if _, ok := peersSet[peer.Key]; peer.Key != peerKey && !ok {
|
if _, ok := peersSet[peer.Key]; peer.Key != peerPubKey && !ok {
|
||||||
peersSet[peer.Key] = struct{}{}
|
peersSet[peer.Key] = struct{}{}
|
||||||
peers = append(peers, peer.Copy())
|
peers = append(peers, peer.Copy())
|
||||||
}
|
}
|
||||||
@ -556,18 +515,14 @@ func (am *DefaultAccountManager) getPeersByACL(account *Account, peerKey string)
|
|||||||
// Should be called when changes have to be synced to peers.
|
// Should be called when changes have to be synced to peers.
|
||||||
func (am *DefaultAccountManager) updateAccountPeers(account *Account) error {
|
func (am *DefaultAccountManager) updateAccountPeers(account *Account) error {
|
||||||
// notify other peers of the change
|
// notify other peers of the change
|
||||||
peers, err := am.Store.GetAccountPeers(account.Id)
|
peers := account.GetPeers()
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
network := account.Network.Copy()
|
network := account.Network.Copy()
|
||||||
|
|
||||||
for _, peer := range peers {
|
for _, peer := range peers {
|
||||||
aclPeers := am.getPeersByACL(account, peer.Key)
|
aclPeers := am.getPeersByACL(account, peer.Key)
|
||||||
peersUpdate := toRemotePeerConfig(aclPeers)
|
peersUpdate := toRemotePeerConfig(aclPeers)
|
||||||
routesUpdate := toProtocolRoutes(am.getPeersRoutes(append(aclPeers, peer)))
|
routesUpdate := toProtocolRoutes(account.GetPeersRoutes(append(aclPeers, peer)))
|
||||||
err = am.peersUpdateManager.SendUpdate(peer.Key,
|
err := am.peersUpdateManager.SendUpdate(peer.Key,
|
||||||
&UpdateMessage{
|
&UpdateMessage{
|
||||||
Update: &proto.SyncResponse{
|
Update: &proto.SyncResponse{
|
||||||
// fill deprecated fields for backward compatibility
|
// fill deprecated fields for backward compatibility
|
||||||
|
@ -93,7 +93,12 @@ func (am *DefaultAccountManager) checkPrefixPeerExists(accountID, peer string, p
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
routesWithPrefix, err := am.Store.GetRoutesByPrefix(accountID, prefix)
|
account, err := am.Store.GetAccount(accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
routesWithPrefix := account.GetRoutesByPrefix(prefix)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||||
@ -372,30 +377,6 @@ func toProtocolRoute(route *route.Route) *proto.Route {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) getPeersRoutes(peers []*Peer) []*route.Route {
|
|
||||||
routes := make([]*route.Route, 0)
|
|
||||||
for _, peer := range peers {
|
|
||||||
peerRoutes, err := am.Store.GetPeerRoutes(peer.Key)
|
|
||||||
if err != nil {
|
|
||||||
errorStatus, ok := status.FromError(err)
|
|
||||||
if !ok && errorStatus.Code() != codes.NotFound {
|
|
||||||
log.Error(err)
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
activeRoutes := make([]*route.Route, 0)
|
|
||||||
for _, pr := range peerRoutes {
|
|
||||||
if pr.Enabled {
|
|
||||||
activeRoutes = append(activeRoutes, pr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(activeRoutes) > 0 {
|
|
||||||
routes = append(routes, activeRoutes...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return routes
|
|
||||||
}
|
|
||||||
|
|
||||||
func toProtocolRoutes(routes []*route.Route) []*proto.Route {
|
func toProtocolRoutes(routes []*route.Route) []*proto.Route {
|
||||||
protoRoutes := make([]*proto.Route, 0)
|
protoRoutes := make([]*proto.Route, 0)
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
|
@ -380,6 +380,11 @@ func TestSaveRoute(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
account, err = am.GetAccountById(account.Id)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
savedRoute, saved := account.Routes[testCase.expectedRoute.ID]
|
savedRoute, saved := account.Routes[testCase.expectedRoute.ID]
|
||||||
require.True(t, saved)
|
require.True(t, saved)
|
||||||
|
|
||||||
@ -840,5 +845,5 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return account, nil
|
return am.GetAccountById(accountID)
|
||||||
}
|
}
|
||||||
|
@ -1,26 +1,13 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/netbirdio/netbird/route"
|
|
||||||
"net/netip"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Store interface {
|
type Store interface {
|
||||||
GetPeer(peerKey string) (*Peer, error)
|
|
||||||
DeletePeer(accountId string, peerKey string) (*Peer, error)
|
|
||||||
SavePeer(accountId string, peer *Peer) error
|
|
||||||
GetAllAccounts() []*Account
|
GetAllAccounts() []*Account
|
||||||
GetAccount(accountId string) (*Account, error)
|
GetAccount(accountID string) (*Account, error)
|
||||||
GetUserAccount(userId string) (*Account, error)
|
GetAccountByUser(userID string) (*Account, error)
|
||||||
GetAccountPeers(accountId string) ([]*Peer, error)
|
GetAccountByPeerPubKey(peerKey string) (*Account, error)
|
||||||
GetPeerAccount(peerKey string) (*Account, error)
|
GetAccountBySetupKey(setupKey string) (*Account, error) //todo use key hash later
|
||||||
GetPeerSrcRules(accountId, peerKey string) ([]*Rule, error)
|
|
||||||
GetPeerDstRules(accountId, peerKey string) ([]*Rule, error)
|
|
||||||
GetAccountBySetupKey(setupKey string) (*Account, error)
|
|
||||||
GetAccountByPrivateDomain(domain string) (*Account, error)
|
GetAccountByPrivateDomain(domain string) (*Account, error)
|
||||||
SaveAccount(account *Account) error
|
SaveAccount(account *Account) error
|
||||||
GetPeerRoutes(peerKey string) ([]*route.Route, error)
|
|
||||||
GetRoutesByPrefix(accountID string, prefix netip.Prefix) ([]*route.Route, error)
|
|
||||||
GetInstallationID() string
|
GetInstallationID() string
|
||||||
SaveInstallationID(id string) error
|
SaveInstallationID(id string) error
|
||||||
}
|
}
|
||||||
|
@ -240,7 +240,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userId, domain string)
|
|||||||
|
|
||||||
lowerDomain := strings.ToLower(domain)
|
lowerDomain := strings.ToLower(domain)
|
||||||
|
|
||||||
account, err := am.Store.GetUserAccount(userId)
|
account, err := am.Store.GetAccountByUser(userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||||
account, err = am.newAccount(userId, lowerDomain)
|
account, err = am.newAccount(userId, lowerDomain)
|
||||||
@ -275,7 +275,7 @@ func (am *DefaultAccountManager) GetAccountByUser(userId string) (*Account, erro
|
|||||||
am.mux.Lock()
|
am.mux.Lock()
|
||||||
defer am.mux.Unlock()
|
defer am.mux.Unlock()
|
||||||
|
|
||||||
return am.Store.GetUserAccount(userId)
|
return am.Store.GetAccountByUser(userId)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsUserAdmin flag for current user authenticated by JWT token
|
// IsUserAdmin flag for current user authenticated by JWT token
|
||||||
|
Loading…
Reference in New Issue
Block a user