netbird/management/server/account.go
Misha Bragin e914adb5cd
Move Login business logic from gRPC API to Accountmanager (#713)
The Management gRPC API has too much business logic 
happening while it has to be in the Account manager.
This also needs to make more requests to the store 
through the account manager.
2023-03-03 18:35:38 +01:00

1356 lines
43 KiB
Go

package server
import (
"context"
"fmt"
"math/rand"
"net"
"net/netip"
"reflect"
"regexp"
"strings"
"sync"
"time"
"github.com/eko/gocache/v3/cache"
cacheStore "github.com/eko/gocache/v3/store"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/route"
gocache "github.com/patrickmn/go-cache"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
)
const (
PublicCategory = "public"
PrivateCategory = "private"
UnknownCategory = "unknown"
CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days
CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days
DefaultPeerLoginExpiration = 24 * time.Hour
)
func cacheEntryExpiration() time.Duration {
r := rand.Intn(int(CacheExpirationMax.Milliseconds()-CacheExpirationMin.Milliseconds())) + int(CacheExpirationMin.Milliseconds())
return time.Duration(r) * time.Millisecond
}
type AccountManager interface {
GetOrCreateAccountByUser(userId, domain string) (*Account, error)
CreateSetupKey(accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration,
autoGroups []string, usageLimit int, userID string) (*SetupKey, error)
SaveSetupKey(accountID string, key *SetupKey, userID string) (*SetupKey, error)
CreateUser(accountID, userID string, key *UserInfo) (*UserInfo, error)
ListSetupKeys(accountID, userID string) ([]*SetupKey, error)
SaveUser(accountID, userID string, update *User) (*UserInfo, error)
GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error)
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error)
AccountExists(accountId string) (*bool, error)
GetPeerByKey(peerKey string) (*Peer, error)
GetPeers(accountID, userID string) ([]*Peer, error)
MarkPeerConnected(peerKey string, connected bool) error
DeletePeer(accountID, peerID, userID string) (*Peer, error)
GetPeerByIP(accountId string, peerIP string) (*Peer, error)
UpdatePeer(accountID, userID string, peer *Peer) (*Peer, error)
GetNetworkMap(peerID string) (*NetworkMap, error)
GetPeerNetwork(peerID string) (*Network, error)
AddPeer(setupKey, userID string, peer *Peer) (*Peer, *NetworkMap, error)
UpdatePeerSSHKey(peerID string, sshKey string) error
GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error)
GetGroup(accountId, groupID string) (*Group, error)
SaveGroup(accountID, userID string, group *Group) error
UpdateGroup(accountID string, groupID string, operations []GroupUpdateOperation) (*Group, error)
DeleteGroup(accountId, groupID string) error
ListGroups(accountId string) ([]*Group, error)
GroupAddPeer(accountId, groupID, peerID string) error
GroupDeletePeer(accountId, groupID, peerKey string) error
GroupListPeers(accountId, groupID string) ([]*Peer, error)
GetRule(accountID, ruleID, userID string) (*Rule, error)
SaveRule(accountID, userID string, rule *Rule) error
UpdateRule(accountID string, ruleID string, operations []RuleUpdateOperation) (*Rule, error)
DeleteRule(accountID, ruleID, userID string) error
ListRules(accountID, userID string) ([]*Rule, error)
GetRoute(accountID, routeID, userID string) (*route.Route, error)
CreateRoute(accountID string, prefix, peerID, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
SaveRoute(accountID, userID string, route *route.Route) error
UpdateRoute(accountID, routeID string, operations []RouteUpdateOperation) (*route.Route, error)
DeleteRoute(accountID, routeID, userID string) error
ListRoutes(accountID, userID string) ([]*route.Route, error)
GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error)
CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error)
SaveNameServerGroup(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
UpdateNameServerGroup(accountID, nsGroupID, userID string, operations []NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error)
DeleteNameServerGroup(accountID, nsGroupID, userID string) error
ListNameServerGroups(accountID string) ([]*nbdns.NameServerGroup, error)
GetDNSDomain() string
GetEvents(accountID, userID string) ([]*activity.Event, error)
GetDNSSettings(accountID string, userID string) (*DNSSettings, error)
SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error
GetPeer(accountID, peerID, userID string) (*Peer, error)
UpdateAccountSettings(accountID, userID string, newSettings *Settings) (*Account, error)
LoginPeer(login PeerLogin) (*Peer, *NetworkMap, error) //used by peer gRPC API
SyncPeer(sync PeerSync) (*Peer, *NetworkMap, error) //used by peer gRPC API
}
type DefaultAccountManager struct {
Store Store
// cacheMux and cacheLoading helps to make sure that only a single cache reload runs at a time per accountID
cacheMux sync.Mutex
// cacheLoading keeps the accountIDs that are currently reloading. The accountID has to be removed once cache has been reloaded
cacheLoading map[string]chan struct{}
peersUpdateManager *PeersUpdateManager
idpManager idp.Manager
cacheManager cache.CacheInterface[[]*idp.UserData]
ctx context.Context
eventStore activity.Store
// singleAccountMode indicates whether the instance has a single account.
// If true, then every new user will end up under the same account.
// This value will be set to false if management service has more than one account.
singleAccountMode bool
// singleAccountModeDomain is a domain to use in singleAccountMode setup
singleAccountModeDomain string
// dnsDomain is used for peer resolution. This is appended to the peer's name
dnsDomain string
peerLoginExpiry Scheduler
}
// Settings represents Account settings structure that can be modified via API and Dashboard
type Settings struct {
// PeerLoginExpirationEnabled globally enables or disables peer login expiration
PeerLoginExpirationEnabled bool
// PeerLoginExpiration is a setting that indicates when peer login expires.
// Applies to all peers that have Peer.LoginExpirationEnabled set to true.
PeerLoginExpiration time.Duration
}
// Copy copies the Settings struct
func (s *Settings) Copy() *Settings {
return &Settings{
PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled,
PeerLoginExpiration: s.PeerLoginExpiration,
}
}
// Account represents a unique account of the system
type Account struct {
Id string
// User.Id it was created by
CreatedBy string
Domain string
DomainCategory string
IsDomainPrimaryAccount bool
SetupKeys map[string]*SetupKey
Network *Network
Peers map[string]*Peer
Users map[string]*User
Groups map[string]*Group
Rules map[string]*Rule
Routes map[string]*route.Route
NameServerGroups map[string]*nbdns.NameServerGroup
DNSSettings *DNSSettings
// Settings is a dictionary of Account settings
Settings *Settings
}
type UserInfo struct {
ID string `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
Role string `json:"role"`
AutoGroups []string `json:"auto_groups"`
Status string `json:"-"`
}
// getRoutesToSync returns the enabled routes for the peer ID and the routes
// from the ACL peers that have distribution groups associated with the peer ID.
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
func (a *Account) getRoutesToSync(peerID string, aclPeers []*Peer) []*route.Route {
routes, peerDisabledRoutes := a.getEnabledAndDisabledRoutesByPeer(peerID)
peerRoutesMembership := make(lookupMap)
for _, r := range append(routes, peerDisabledRoutes...) {
peerRoutesMembership[route.GetHAUniqueID(r)] = struct{}{}
}
groupListMap := a.getPeerGroups(peerID)
for _, peer := range aclPeers {
activeRoutes, _ := a.getEnabledAndDisabledRoutesByPeer(peer.ID)
groupFilteredRoutes := a.filterRoutesByGroups(activeRoutes, groupListMap)
filteredRoutes := a.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership)
routes = append(routes, filteredRoutes...)
}
return routes
}
// filterRoutesByHAMembership filters and returns a list of routes that don't share the same HA route membership
func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships lookupMap) []*route.Route {
var filteredRoutes []*route.Route
for _, r := range routes {
_, found := peerMemberships[route.GetHAUniqueID(r)]
if !found {
filteredRoutes = append(filteredRoutes, r)
}
}
return filteredRoutes
}
// filterRoutesByGroups returns a list with routes that have distribution groups in the group's map
func (a *Account) filterRoutesByGroups(routes []*route.Route, groupListMap lookupMap) []*route.Route {
var filteredRoutes []*route.Route
for _, r := range routes {
for _, groupID := range r.Groups {
_, found := groupListMap[groupID]
if found {
filteredRoutes = append(filteredRoutes, r)
break
}
}
}
return filteredRoutes
}
// getEnabledAndDisabledRoutesByPeer returns the enabled and disabled lists of routes that belong to a peer.
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
func (a *Account) getEnabledAndDisabledRoutesByPeer(peerID string) ([]*route.Route, []*route.Route) {
var enabledRoutes []*route.Route
var disabledRoutes []*route.Route
for _, r := range a.Routes {
if r.Peer == peerID {
// We need to set Peer.Key instead of Peer.ID because this object will be sent to agents as part of a network map.
// Ideally we should have a separate field for that, but fine for now.
peer := a.GetPeer(peerID)
if peer == nil {
log.Errorf("route %s has peer %s that doesn't exist under account %s", r.ID, peerID, a.Id)
continue
}
raut := r.Copy()
raut.Peer = peer.Key
if r.Enabled {
enabledRoutes = append(enabledRoutes, raut)
continue
}
disabledRoutes = append(disabledRoutes, raut)
}
}
return enabledRoutes, disabledRoutes
}
// GetRoutesByPrefix return list of routes by account and route prefix
func (a *Account) GetRoutesByPrefix(prefix netip.Prefix) []*route.Route {
var routes []*route.Route
for _, r := range a.Routes {
if r.Network.String() == prefix.String() {
routes = append(routes, r)
}
}
return routes
}
// GetPeerByIP returns peer by it's IP if exists under account or nil otherwise
func (a *Account) GetPeerByIP(peerIP string) *Peer {
for _, peer := range a.Peers {
if peerIP == peer.IP.String() {
return peer
}
}
return nil
}
// GetPeerRules returns a list of source or destination rules of a given peer.
func (a *Account) GetPeerRules(peerID string) (srcRules []*Rule, dstRules []*Rule) {
// Rules are group based so there is no direct access to peers.
// First, find all groups that the given peer belongs to
peerGroups := make(map[string]struct{})
for s, group := range a.Groups {
for _, peer := range group.Peers {
if peerID == peer {
peerGroups[s] = struct{}{}
break
}
}
}
// Second, find all rules that have discovered source and destination groups
srcRulesMap := make(map[string]*Rule)
dstRulesMap := make(map[string]*Rule)
for _, rule := range a.Rules {
for _, g := range rule.Source {
if _, ok := peerGroups[g]; ok && srcRulesMap[rule.ID] == nil {
srcRules = append(srcRules, rule)
srcRulesMap[rule.ID] = rule
}
}
for _, g := range rule.Destination {
if _, ok := peerGroups[g]; ok && dstRulesMap[rule.ID] == nil {
dstRules = append(dstRules, rule)
dstRulesMap[rule.ID] = rule
}
}
}
return srcRules, dstRules
}
// GetGroup returns a group by ID if exists, nil otherwise
func (a *Account) GetGroup(groupID string) *Group {
return a.Groups[groupID]
}
// GetPeerNetworkMap returns a group by ID if exists, nil otherwise
func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string) *NetworkMap {
aclPeers := a.getPeersByACL(peerID)
// exclude expired peers
var peersToConnect []*Peer
for _, p := range aclPeers {
expired, _ := p.LoginExpired(a.Settings.PeerLoginExpiration)
if expired {
continue
}
peersToConnect = append(peersToConnect, p)
}
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
routesUpdate := a.getRoutesToSync(peerID, peersToConnect)
dnsManagementStatus := a.getPeerDNSManagementStatus(peerID)
dnsUpdate := nbdns.Config{
ServiceEnable: dnsManagementStatus,
}
if dnsManagementStatus {
var zones []nbdns.CustomZone
peersCustomZone := getPeersCustomZone(a, dnsDomain)
if peersCustomZone.Domain != "" {
zones = append(zones, peersCustomZone)
}
dnsUpdate.CustomZones = zones
dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
}
return &NetworkMap{
Peers: peersToConnect,
Network: a.Network.Copy(),
Routes: routesUpdate,
DNSConfig: dnsUpdate,
}
}
// GetExpiredPeers returns peers that have been expired
func (a *Account) GetExpiredPeers() []*Peer {
var peers []*Peer
for _, peer := range a.GetPeersWithExpiration() {
expired, _ := peer.LoginExpired(a.Settings.PeerLoginExpiration)
if expired {
peers = append(peers, peer)
}
}
return peers
}
// GetNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found.
// If there is no peer that expires this function returns false and a duration of 0.
// This function only considers peers that haven't been expired yet and that are connected.
func (a *Account) GetNextPeerExpiration() (time.Duration, bool) {
peersWithExpiry := a.GetPeersWithExpiration()
if len(peersWithExpiry) == 0 {
return 0, false
}
var nextExpiry *time.Duration
for _, peer := range peersWithExpiry {
// consider only connected peers because others will require login on connecting to the management server
if peer.Status.LoginExpired || !peer.Status.Connected {
continue
}
_, duration := peer.LoginExpired(a.Settings.PeerLoginExpiration)
if nextExpiry == nil || duration < *nextExpiry {
nextExpiry = &duration
}
}
if nextExpiry == nil {
return 0, false
}
return *nextExpiry, true
}
// GetPeersWithExpiration returns a list of peers that have Peer.LoginExpirationEnabled set to true
func (a *Account) GetPeersWithExpiration() []*Peer {
peers := make([]*Peer, 0)
for _, peer := range a.Peers {
if peer.LoginExpirationEnabled {
peers = append(peers, peer)
}
}
return peers
}
// GetPeers returns a list of all Account peers
func (a *Account) GetPeers() []*Peer {
var peers []*Peer
for _, peer := range a.Peers {
peers = append(peers, peer)
}
return peers
}
// UpdateSettings saves new account settings
func (a *Account) UpdateSettings(update *Settings) *Account {
a.Settings = update.Copy()
return a
}
// UpdatePeer saves new or replaces existing peer
func (a *Account) UpdatePeer(update *Peer) {
a.Peers[update.ID] = update
}
// DeletePeer deletes peer from the account cleaning up all the references
func (a *Account) DeletePeer(peerID string) {
// delete peer from groups
for _, g := range a.Groups {
for i, pk := range g.Peers {
if pk == peerID {
g.Peers = append(g.Peers[:i], g.Peers[i+1:]...)
break
}
}
}
for _, r := range a.Routes {
if r.Peer == peerID {
r.Enabled = false
r.Peer = ""
}
}
delete(a.Peers, peerID)
a.Network.IncSerial()
}
// FindPeerByPubKey looks for a Peer by provided WireGuard public key in the Account or returns error if it wasn't found.
// It will return an object copy of the peer.
func (a *Account) FindPeerByPubKey(peerPubKey string) (*Peer, error) {
for _, peer := range a.Peers {
if peer.Key == peerPubKey {
return peer.Copy(), nil
}
}
return nil, status.Errorf(status.NotFound, "peer with the public key %s not found", peerPubKey)
}
// FindUserPeers returns a list of peers that user owns (created)
func (a *Account) FindUserPeers(userID string) ([]*Peer, error) {
peers := make([]*Peer, 0)
for _, peer := range a.Peers {
if peer.UserID == userID {
peers = append(peers, peer)
}
}
return peers, nil
}
// FindUser looks for a given user in the Account or returns error if user wasn't found.
func (a *Account) FindUser(userID string) (*User, error) {
user := a.Users[userID]
if user == nil {
return nil, status.Errorf(status.NotFound, "user %s not found", userID)
}
return user, nil
}
// FindSetupKey looks for a given SetupKey in the Account or returns error if it wasn't found.
func (a *Account) FindSetupKey(setupKey string) (*SetupKey, error) {
key := a.SetupKeys[setupKey]
if key == nil {
return nil, status.Errorf(status.NotFound, "setup key not found")
}
return key, nil
}
func (a *Account) getUserGroups(userID string) ([]string, error) {
user, err := a.FindUser(userID)
if err != nil {
return nil, err
}
return user.AutoGroups, nil
}
func (a *Account) getPeerDNSManagementStatus(peerID string) bool {
peerGroups := a.getPeerGroups(peerID)
enabled := true
if a.DNSSettings != nil {
for _, groupID := range a.DNSSettings.DisabledManagementGroups {
_, found := peerGroups[groupID]
if found {
enabled = false
break
}
}
}
return enabled
}
func (a *Account) getPeerGroups(peerID string) lookupMap {
groupList := make(lookupMap)
for groupID, group := range a.Groups {
for _, id := range group.Peers {
if id == peerID {
groupList[groupID] = struct{}{}
break
}
}
}
return groupList
}
func (a *Account) getSetupKeyGroups(setupKey string) ([]string, error) {
key, err := a.FindSetupKey(setupKey)
if err != nil {
return nil, err
}
return key.AutoGroups, nil
}
func (a *Account) getTakenIPs() []net.IP {
var takenIps []net.IP
for _, existingPeer := range a.Peers {
takenIps = append(takenIps, existingPeer.IP)
}
return takenIps
}
func (a *Account) getPeerDNSLabels() lookupMap {
existingLabels := make(lookupMap)
for _, peer := range a.Peers {
if peer.DNSLabel != "" {
existingLabels[peer.DNSLabel] = struct{}{}
}
}
return existingLabels
}
func (a *Account) Copy() *Account {
peers := map[string]*Peer{}
for id, peer := range a.Peers {
peers[id] = peer.Copy()
}
users := map[string]*User{}
for id, user := range a.Users {
users[id] = user.Copy()
}
setupKeys := map[string]*SetupKey{}
for id, key := range a.SetupKeys {
setupKeys[id] = key.Copy()
}
groups := map[string]*Group{}
for id, group := range a.Groups {
groups[id] = group.Copy()
}
rules := map[string]*Rule{}
for id, rule := range a.Rules {
rules[id] = rule.Copy()
}
routes := map[string]*route.Route{}
for id, route := range a.Routes {
routes[id] = route.Copy()
}
nsGroups := map[string]*nbdns.NameServerGroup{}
for id, nsGroup := range a.NameServerGroups {
nsGroups[id] = nsGroup.Copy()
}
var dnsSettings *DNSSettings
if a.DNSSettings != nil {
dnsSettings = a.DNSSettings.Copy()
}
var settings *Settings
if a.Settings != nil {
settings = a.Settings.Copy()
}
return &Account{
Id: a.Id,
CreatedBy: a.CreatedBy,
Domain: a.Domain,
DomainCategory: a.DomainCategory,
IsDomainPrimaryAccount: a.IsDomainPrimaryAccount,
SetupKeys: setupKeys,
Network: a.Network.Copy(),
Peers: peers,
Users: users,
Groups: groups,
Rules: rules,
Routes: routes,
NameServerGroups: nsGroups,
DNSSettings: dnsSettings,
Settings: settings,
}
}
func (a *Account) GetGroupAll() (*Group, error) {
for _, g := range a.Groups {
if g.Name == "All" {
return g, nil
}
}
return nil, fmt.Errorf("no group ALL found")
}
// GetPeer looks up a Peer by ID
func (a *Account) GetPeer(peerID string) *Peer {
return a.Peers[peerID]
}
// BuildManager creates a new DefaultAccountManager with a provided Store
func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager,
singleAccountModeDomain string, dnsDomain string, eventStore activity.Store,
) (*DefaultAccountManager, error) {
am := &DefaultAccountManager{
Store: store,
peersUpdateManager: peersUpdateManager,
idpManager: idpManager,
ctx: context.Background(),
cacheMux: sync.Mutex{},
cacheLoading: map[string]chan struct{}{},
dnsDomain: dnsDomain,
eventStore: eventStore,
peerLoginExpiry: NewDefaultScheduler(),
}
allAccounts := store.GetAllAccounts()
// enable single account mode only if configured by user and number of existing accounts is not grater than 1
am.singleAccountMode = singleAccountModeDomain != "" && len(allAccounts) <= 1
if am.singleAccountMode {
if !isDomainValid(singleAccountModeDomain) {
return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for a single account mode. Please review your input for --single-account-mode-domain", singleAccountModeDomain)
}
am.singleAccountModeDomain = singleAccountModeDomain
log.Infof("single account mode enabled, accounts number %d", len(allAccounts))
} else {
log.Infof("single account mode disabled, accounts number %d", len(allAccounts))
}
// if account doesn't have a default group
// we create 'all' group and add all peers into it
// also we create default rule with source as destination
for _, account := range allAccounts {
shouldSave := false
_, err := account.GetGroupAll()
if err != nil {
addAllGroup(account)
shouldSave = true
}
if shouldSave {
err = store.SaveAccount(account)
if err != nil {
return nil, err
}
}
}
goCacheClient := gocache.New(CacheExpirationMax, 30*time.Minute)
goCacheStore := cacheStore.NewGoCache(goCacheClient)
am.cacheManager = cache.NewLoadable[[]*idp.UserData](am.loadAccount, cache.New[[]*idp.UserData](goCacheStore))
if !isNil(am.idpManager) {
go func() {
err := am.warmupIDPCache()
if err != nil {
log.Warnf("failed warming up cache due to error: %v", err)
// todo retry?
return
}
}()
}
return am, nil
}
// UpdateAccountSettings updates Account settings.
// Only users with role UserRoleAdmin can update the account.
// User that performs the update has to belong to the account.
// Returns an updated Account
func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string, newSettings *Settings) (*Account, error) {
halfYearLimit := 180 * 24 * time.Hour
if newSettings.PeerLoginExpiration > halfYearLimit {
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
}
if newSettings.PeerLoginExpiration < time.Hour {
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
}
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccountByUser(userID)
if err != nil {
return nil, err
}
user, err := account.FindUser(userID)
if err != nil {
return nil, err
}
if !user.IsAdmin() {
return nil, status.Errorf(status.PermissionDenied, "user is not allowed to update account")
}
oldSettings := account.Settings
if oldSettings.PeerLoginExpirationEnabled != newSettings.PeerLoginExpirationEnabled {
event := activity.AccountPeerLoginExpirationEnabled
if !newSettings.PeerLoginExpirationEnabled {
event = activity.AccountPeerLoginExpirationDisabled
am.peerLoginExpiry.Cancel([]string{accountID})
} else {
am.checkAndSchedulePeerLoginExpiration(account)
}
am.storeEvent(userID, accountID, accountID, event, nil)
}
if oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration {
am.storeEvent(userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil)
am.checkAndSchedulePeerLoginExpiration(account)
}
updatedAccount := account.UpdateSettings(newSettings)
err = am.Store.SaveAccount(account)
if err != nil {
return nil, err
}
return updatedAccount, nil
}
func (am *DefaultAccountManager) peerLoginExpirationJob(accountID string) func() (time.Duration, bool) {
return func() (time.Duration, bool) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
log.Errorf("failed getting account %s expiring peers", account.Id)
return account.GetNextPeerExpiration()
}
var peerIDs []string
for _, peer := range account.GetExpiredPeers() {
if peer.Status.LoginExpired {
continue
}
peerIDs = append(peerIDs, peer.ID)
peer.MarkLoginExpired(true)
account.UpdatePeer(peer)
err = am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status)
if err != nil {
log.Errorf("failed saving peer status while expiring peer %s", peer.ID)
return account.GetNextPeerExpiration()
}
}
log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id)
if len(peerIDs) != 0 {
// this will trigger peer disconnect from the management service
am.peersUpdateManager.CloseChannels(peerIDs)
err := am.updateAccountPeers(account)
if err != nil {
log.Errorf("failed updating account peers while expiring peers for account %s", accountID)
return account.GetNextPeerExpiration()
}
}
return account.GetNextPeerExpiration()
}
}
func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(account *Account) {
am.peerLoginExpiry.Cancel([]string{account.Id})
if nextRun, ok := account.GetNextPeerExpiration(); ok {
go am.peerLoginExpiry.Schedule(nextRun, account.Id, am.peerLoginExpirationJob(account.Id))
}
}
// newAccount creates a new Account with a generated ID and generated default setup keys.
// If ID is already in use (due to collision) we try one more time before returning error
func (am *DefaultAccountManager) newAccount(userID, domain string) (*Account, error) {
for i := 0; i < 2; i++ {
accountId := xid.New().String()
_, err := am.Store.GetAccount(accountId)
statusErr, _ := status.FromError(err)
if err == nil {
log.Warnf("an account with ID already exists, retrying...")
continue
} else if statusErr.Type() == status.NotFound {
newAccount := newAccountWithId(accountId, userID, domain)
am.storeEvent(userID, newAccount.Id, accountId, activity.AccountCreated, nil)
return newAccount, nil
} else {
return nil, err
}
}
return nil, status.Errorf(status.Internal, "error while creating new account")
}
func (am *DefaultAccountManager) warmupIDPCache() error {
userData, err := am.idpManager.GetAllAccounts()
if err != nil {
return err
}
for accountID, users := range userData {
err = am.cacheManager.Set(am.ctx, accountID, users, cacheStore.WithExpiration(cacheEntryExpiration()))
if err != nil {
return err
}
}
log.Infof("warmed up IDP cache with %d entries", len(userData))
return nil
}
// GetAccountByUserOrAccountID looks for an account by user or accountID, if no account is provided and
// userID doesn't have an account associated with it, one account is created
func (am *DefaultAccountManager) GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error) {
if accountID != "" {
return am.Store.GetAccount(accountID)
} else if userID != "" {
account, err := am.GetOrCreateAccountByUser(userID, domain)
if err != nil {
return nil, status.Errorf(status.NotFound, "account not found using user id: %s", userID)
}
err = am.addAccountIDToIDPAppMeta(userID, account)
if err != nil {
return nil, err
}
return account, nil
}
return nil, status.Errorf(status.NotFound, "no valid user or account Id provided")
}
func isNil(i idp.Manager) bool {
return i == nil || reflect.ValueOf(i).IsNil()
}
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(userID string, account *Account) error {
if !isNil(am.idpManager) {
// user can be nil if it wasn't found (e.g., just created)
user, err := am.lookupUserInCache(userID, account)
if err != nil {
return err
}
if user != nil && user.AppMetadata.WTAccountID == account.Id {
// it was already set, so we skip the unnecessary update
log.Debugf("skipping IDP App Meta update because accountID %s has been already set for user %s",
account.Id, userID)
return nil
}
err = am.idpManager.UpdateUserAppMetadata(userID, idp.AppMetadata{WTAccountID: account.Id})
if err != nil {
return err
}
if err != nil {
return status.Errorf(status.Internal, "updating user's app metadata failed with: %v", err)
}
// refresh cache to reflect the update
_, err = am.refreshCache(account.Id)
if err != nil {
return err
}
}
return nil
}
func (am *DefaultAccountManager) loadAccount(_ context.Context, accountID interface{}) ([]*idp.UserData, error) {
log.Debugf("account %s not found in cache, reloading", accountID)
return am.idpManager.GetAccount(fmt.Sprintf("%v", accountID))
}
func (am *DefaultAccountManager) lookupUserInCacheByEmail(email string, accountID string) (*idp.UserData, error) {
data, err := am.getAccountFromCache(accountID, false)
if err != nil {
return nil, err
}
for _, datum := range data {
if datum.Email == email {
return datum, nil
}
}
return nil, nil
}
// lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil
func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Account) (*idp.UserData, error) {
users := make(map[string]struct{}, len(account.Users))
for _, user := range account.Users {
users[user.Id] = struct{}{}
}
log.Debugf("looking up user %s of account %s in cache", userID, account.Id)
userData, err := am.lookupCache(users, account.Id)
if err != nil {
return nil, err
}
for _, datum := range userData {
if datum.ID == userID {
return datum, nil
}
}
return nil, nil
}
func (am *DefaultAccountManager) refreshCache(accountID string) ([]*idp.UserData, error) {
return am.getAccountFromCache(accountID, true)
}
// getAccountFromCache returns user data for a given account ensuring that cache load happens only once
func (am *DefaultAccountManager) getAccountFromCache(accountID string, forceReload bool) ([]*idp.UserData, error) {
am.cacheMux.Lock()
loadingChan := am.cacheLoading[accountID]
if loadingChan == nil {
loadingChan = make(chan struct{})
am.cacheLoading[accountID] = loadingChan
am.cacheMux.Unlock()
defer func() {
am.cacheMux.Lock()
delete(am.cacheLoading, accountID)
close(loadingChan)
am.cacheMux.Unlock()
}()
if forceReload {
err := am.cacheManager.Delete(am.ctx, accountID)
if err != nil {
return nil, err
}
}
return am.cacheManager.Get(am.ctx, accountID)
}
am.cacheMux.Unlock()
log.Debugf("one request to get account %s is already running", accountID)
select {
case <-loadingChan:
// channel has been closed meaning cache was loaded => simply return from cache
return am.cacheManager.Get(am.ctx, accountID)
case <-time.After(5 * time.Second):
return nil, fmt.Errorf("timeout while waiting for account %s cache to reload", accountID)
}
}
func (am *DefaultAccountManager) lookupCache(accountUsers map[string]struct{}, accountID string) ([]*idp.UserData, error) {
data, err := am.getAccountFromCache(accountID, false)
if err != nil {
return nil, err
}
userDataMap := make(map[string]struct{})
for _, datum := range data {
userDataMap[datum.ID] = struct{}{}
}
// check whether we need to reload the cache
// the accountUsers ID list is the source of truth and all the users should be in the cache
reload := len(accountUsers) != len(data)
for user := range accountUsers {
if _, ok := userDataMap[user]; !ok {
reload = true
}
}
if reload {
// reload cache once avoiding loops
data, err = am.refreshCache(accountID)
if err != nil {
return nil, err
}
}
return data, err
}
// updateAccountDomainAttributes updates the account domain attributes and then, saves the account
func (am *DefaultAccountManager) updateAccountDomainAttributes(account *Account, claims jwtclaims.AuthorizationClaims,
primaryDomain bool,
) error {
account.IsDomainPrimaryAccount = primaryDomain
lowerDomain := strings.ToLower(claims.Domain)
userObj := account.Users[claims.UserId]
if account.Domain != lowerDomain && userObj.Role == UserRoleAdmin {
account.Domain = lowerDomain
}
// prevent updating category for different domain until admin logs in
if account.Domain == lowerDomain {
account.DomainCategory = claims.DomainCategory
}
err := am.Store.SaveAccount(account)
if err != nil {
return err
}
return nil
}
// handleExistingUserAccount handles existing User accounts and update its domain attributes.
//
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
// was previously unclassified or classified as public so N users that logged int that time, has they own account
// and peers that shouldn't be lost.
func (am *DefaultAccountManager) handleExistingUserAccount(
existingAcc *Account,
domainAcc *Account,
claims jwtclaims.AuthorizationClaims,
) error {
var err error
if domainAcc != nil && existingAcc.Id != domainAcc.Id {
err = am.updateAccountDomainAttributes(existingAcc, claims, false)
if err != nil {
return err
}
} else {
err = am.updateAccountDomainAttributes(existingAcc, claims, true)
if err != nil {
return err
}
}
// we should register the account ID to this user's metadata in our IDP manager
err = am.addAccountIDToIDPAppMeta(claims.UserId, existingAcc)
if err != nil {
return err
}
return nil
}
// handleNewUserAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account,
// otherwise it will create a new account and make it primary account for the domain.
func (am *DefaultAccountManager) handleNewUserAccount(domainAcc *Account, claims jwtclaims.AuthorizationClaims) (*Account, error) {
if claims.UserId == "" {
return nil, fmt.Errorf("user ID is empty")
}
var (
account *Account
err error
)
lowerDomain := strings.ToLower(claims.Domain)
// if domain already has a primary account, add regular user
if domainAcc != nil {
account = domainAcc
account.Users[claims.UserId] = NewRegularUser(claims.UserId)
err = am.Store.SaveAccount(account)
if err != nil {
return nil, err
}
} else {
account, err = am.newAccount(claims.UserId, lowerDomain)
if err != nil {
return nil, err
}
err = am.updateAccountDomainAttributes(account, claims, true)
if err != nil {
return nil, err
}
}
err = am.addAccountIDToIDPAppMeta(claims.UserId, account)
if err != nil {
return nil, err
}
am.storeEvent(claims.UserId, claims.UserId, account.Id, activity.UserJoined, nil)
return account, nil
}
// redeemInvite checks whether user has been invited and redeems the invite
func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) error {
// only possible with the enabled IdP manager
if am.idpManager == nil {
log.Warnf("invites only work with enabled IdP manager")
return nil
}
user, err := am.lookupUserInCache(userID, account)
if err != nil {
return err
}
if user == nil {
return status.Errorf(status.NotFound, "user %s not found in the IdP", userID)
}
if user.AppMetadata.WTPendingInvite != nil && *user.AppMetadata.WTPendingInvite {
log.Infof("redeeming invite for user %s account %s", userID, account.Id)
// User has already logged in, meaning that IdP should have set wt_pending_invite to false.
// Our job is to just reload cache.
go func() {
_, err = am.refreshCache(account.Id)
if err != nil {
log.Warnf("failed reloading cache when redeeming user %s under account %s", userID, account.Id)
return
}
log.Debugf("user %s of account %s redeemed invite", user.ID, account.Id)
am.storeEvent(userID, userID, account.Id, activity.UserJoined, nil)
}()
}
return nil
}
// GetAccountFromToken returns an account associated with this token
func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error) {
if claims.UserId == "" {
return nil, nil, fmt.Errorf("user ID is empty")
}
if am.singleAccountMode && am.singleAccountModeDomain != "" {
// This section is mostly related to self-hosted installations.
// We override incoming domain claims to group users under a single account.
claims.Domain = am.singleAccountModeDomain
claims.DomainCategory = PrivateCategory
log.Infof("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
}
account, err := am.getAccountWithAuthorizationClaims(claims)
if err != nil {
return nil, nil, err
}
user := account.Users[claims.UserId]
if user == nil {
// this is not really possible because we got an account by user ID
return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId)
}
err = am.redeemInvite(account, claims.UserId)
if err != nil {
return nil, nil, err
}
return account, user, nil
}
// getAccountWithAuthorizationClaims retrievs an account using JWT Claims.
// if domain is of the PrivateCategory category, it will evaluate
// if account is new, existing or if there is another account with the same domain
//
// Use cases:
//
// New user + New account + New domain -> create account, user role = admin (if private domain, index domain)
//
// New user + New account + Existing Private Domain -> add user to the existing account, user role = regular (not admin)
//
// New user + New account + Existing Public Domain -> create account, user role = admin
//
// Existing user + Existing account + Existing Domain -> Nothing changes (if private, index domain)
//
// Existing user + Existing account + Existing Indexed Domain -> Nothing changes
//
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtclaims.AuthorizationClaims) (*Account, error) {
if claims.UserId == "" {
return nil, fmt.Errorf("user ID is empty")
}
// if Account ID is part of the claims
// it means that we've already classified the domain and user has an account
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
return am.GetAccountByUserOrAccountID(claims.UserId, claims.AccountId, claims.Domain)
} else if claims.AccountId != "" {
accountFromID, err := am.Store.GetAccount(claims.AccountId)
if err != nil {
return nil, err
}
if _, ok := accountFromID.Users[claims.UserId]; !ok {
return nil, fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
}
if accountFromID.DomainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory {
return accountFromID, nil
}
}
unlock := am.Store.AcquireGlobalLock()
defer unlock()
// We checked if the domain has a primary account already
domainAccount, err := am.Store.GetAccountByPrivateDomain(claims.Domain)
if err != nil {
// if NotFound we are good to continue, otherwise return error
e, ok := status.FromError(err)
if !ok || e.Type() != status.NotFound {
return nil, err
}
}
account, err := am.Store.GetAccountByUser(claims.UserId)
if err == nil {
err = am.handleExistingUserAccount(account, domainAccount, claims)
if err != nil {
return nil, err
}
return account, nil
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
return am.handleNewUserAccount(domainAccount, claims)
} else {
// other error
return nil, err
}
}
func isDomainValid(domain string) bool {
re := regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`)
return re.Match([]byte(domain))
}
// AccountExists checks whether account exists (returns true) or not (returns false)
func (am *DefaultAccountManager) AccountExists(accountID string) (*bool, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
var res bool
_, err := am.Store.GetAccount(accountID)
if err != nil {
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
res = false
return &res, nil
} else {
return nil, err
}
}
res = true
return &res, nil
}
// GetDNSDomain returns the configured dnsDomain
func (am *DefaultAccountManager) GetDNSDomain() string {
return am.dnsDomain
}
// addAllGroup to account object if it doesn't exists
func addAllGroup(account *Account) {
if len(account.Groups) == 0 {
allGroup := &Group{
ID: xid.New().String(),
Name: "All",
}
for _, peer := range account.Peers {
allGroup.Peers = append(allGroup.Peers, peer.ID)
}
account.Groups = map[string]*Group{allGroup.ID: allGroup}
defaultRule := &Rule{
ID: xid.New().String(),
Name: DefaultRuleName,
Description: DefaultRuleDescription,
Disabled: false,
Source: []string{allGroup.ID},
Destination: []string{allGroup.ID},
}
account.Rules = map[string]*Rule{defaultRule.ID: defaultRule}
}
}
// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id
func newAccountWithId(accountId, userId, domain string) *Account {
log.Debugf("creating new account")
setupKeys := make(map[string]*SetupKey)
defaultKey := GenerateDefaultSetupKey()
oneOffKey := GenerateSetupKey("One-off key", SetupKeyOneOff, DefaultSetupKeyDuration, []string{},
SetupKeyUnlimitedUsage)
setupKeys[defaultKey.Key] = defaultKey
setupKeys[oneOffKey.Key] = oneOffKey
network := NewNetwork()
peers := make(map[string]*Peer)
users := make(map[string]*User)
routes := make(map[string]*route.Route)
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
users[userId] = NewAdminUser(userId)
dnsSettings := &DNSSettings{
DisabledManagementGroups: make([]string, 0),
}
log.Debugf("created new account %s with setup key %s", accountId, defaultKey.Key)
acc := &Account{
Id: accountId,
SetupKeys: setupKeys,
Network: network,
Peers: peers,
Users: users,
CreatedBy: userId,
Domain: domain,
Routes: routes,
NameServerGroups: nameServersGroups,
DNSSettings: dnsSettings,
Settings: &Settings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: DefaultPeerLoginExpiration,
},
}
addAllGroup(acc)
return acc
}
func removeFromList(inputList []string, toRemove []string) []string {
toRemoveMap := make(map[string]struct{})
for _, item := range toRemove {
toRemoveMap[item] = struct{}{}
}
var resultList []string
for _, item := range inputList {
_, ok := toRemoveMap[item]
if !ok {
resultList = append(resultList, item)
}
}
return resultList
}