netbird/management/server/account.go
Yury Gargay 32880c56a4
Implement SQLite Store using gorm and relational approach (#1065)
Restructure data handling for improved performance and flexibility. 
Introduce 'G'-suffixed fields (e.g. PeersG, GroupsG) to represent Gorm relations, simplifying resource management.
Eliminate complexity in lookup tables for enhanced query and write speed. 
Enable independent operations on data structures, requiring adjustments in the Store interface and Account Manager.
2023-10-12 15:42:36 +02:00

1652 lines
53 KiB
Go

package server
import (
"context"
"crypto/sha256"
b64 "encoding/base64"
"fmt"
"hash/crc32"
"math/rand"
"net"
"net/netip"
"reflect"
"regexp"
"strings"
"sync"
"time"
"github.com/eko/gocache/v3/cache"
cacheStore "github.com/eko/gocache/v3/store"
gocache "github.com/patrickmn/go-cache"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/base62"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/route"
)
const (
	// Domain categories (presumably the values stored in Account.DomainCategory — confirm with callers).
	PublicCategory  = "public"
	PrivateCategory = "private"
	UnknownCategory = "unknown"

	// Values of Group.Issued recording how a group was created;
	// GroupIssuedJWT marks groups created from JWT claims (see SetJWTGroups).
	GroupIssuedAPI = "api"
	GroupIssuedJWT = "jwt"

	// Bounds of the randomized lifetime of IdP cache entries (see cacheEntryExpiration).
	CacheExpirationMax = 7 * 24 * time.Hour // 7 days
	CacheExpirationMin = 3 * 24 * time.Hour // 3 days

	// DefaultPeerLoginExpiration is a 24h login lifetime
	// (presumably the default for Settings.PeerLoginExpiration — confirm).
	DefaultPeerLoginExpiration = 24 * time.Hour
)

// cacheEntryExpiration picks a random lifetime for an IdP cache entry, with millisecond
// granularity, uniformly from the half-open range [CacheExpirationMin, CacheExpirationMax).
// The jitter presumably keeps entries that were cached together from expiring all at once.
func cacheEntryExpiration() time.Duration {
	lowerMs := CacheExpirationMin.Milliseconds()
	spreadMs := CacheExpirationMax.Milliseconds() - lowerMs
	offsetMs := rand.Intn(int(spreadMs))
	return time.Duration(int64(offsetMs)+lowerMs) * time.Millisecond
}
// AccountManager is the account-level management API. Its methods are grouped below by the
// resource they operate on: accounts and settings, setup keys, users, personal access tokens,
// peers, groups, policies, routes, DNS and activity events. Most of them take the account ID
// and the ID of the user performing the operation.
type AccountManager interface {
	// Accounts and account settings
	GetOrCreateAccountByUser(userId, domain string) (*Account, error)
	GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error)
	GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
	GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error)
	UpdateAccountSettings(accountID, userID string, newSettings *Settings) (*Account, error)

	// Setup keys
	CreateSetupKey(accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration,
		autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error)
	SaveSetupKey(accountID string, key *SetupKey, userID string) (*SetupKey, error)
	GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
	ListSetupKeys(accountID, userID string) ([]*SetupKey, error)

	// Users
	CreateUser(accountID, initiatorUserID string, key *UserInfo) (*UserInfo, error)
	SaveUser(accountID, initiatorUserID string, update *User) (*UserInfo, error)
	DeleteUser(accountID, initiatorUserID string, targetUserID string) error
	InviteUser(accountID string, initiatorUserID string, targetUserID string) error
	GetUser(claims jwtclaims.AuthorizationClaims) (*User, error)
	GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error)

	// Personal access tokens
	CreatePAT(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error)
	DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error
	GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error)
	GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error)
	MarkPATUsed(tokenID string) error

	// Peers
	AddPeer(setupKey, userID string, peer *Peer) (*Peer, *NetworkMap, error)
	GetPeer(accountID, peerID, userID string) (*Peer, error)
	GetPeers(accountID, userID string) ([]*Peer, error)
	UpdatePeer(accountID, userID string, peer *Peer) (*Peer, error)
	DeletePeer(accountID, peerID, userID string) error
	MarkPeerConnected(peerKey string, connected bool) error
	UpdatePeerSSHKey(peerID string, sshKey string) error
	GetNetworkMap(peerID string) (*NetworkMap, error)
	GetPeerNetwork(peerID string) (*Network, error)
	GetAllConnectedPeers() (map[string]struct{}, error)
	LoginPeer(login PeerLogin) (*Peer, *NetworkMap, error) // used by peer gRPC API
	SyncPeer(sync PeerSync) (*Peer, *NetworkMap, error)    // used by peer gRPC API

	// Groups
	GetGroup(accountId, groupID string) (*Group, error)
	SaveGroup(accountID, userID string, group *Group) error
	DeleteGroup(accountId, userId, groupID string) error
	ListGroups(accountId string) ([]*Group, error)
	GroupAddPeer(accountId, groupID, peerID string) error
	GroupDeletePeer(accountId, groupID, peerID string) error

	// Policies
	GetPolicy(accountID, policyID, userID string) (*Policy, error)
	SavePolicy(accountID, userID string, policy *Policy) error
	DeletePolicy(accountID, policyID, userID string) error
	ListPolicies(accountID, userID string) ([]*Policy, error)

	// Routes
	GetRoute(accountID, routeID, userID string) (*route.Route, error)
	CreateRoute(accountID, prefix, peerID string, peerGroupIDs []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
	SaveRoute(accountID, userID string, route *route.Route) error
	DeleteRoute(accountID, routeID, userID string) error
	ListRoutes(accountID, userID string) ([]*route.Route, error)

	// DNS: name server groups, DNS domain and DNS settings
	GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error)
	CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error)
	SaveNameServerGroup(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
	DeleteNameServerGroup(accountID, nsGroupID, userID string) error
	ListNameServerGroups(accountID string) ([]*nbdns.NameServerGroup, error)
	GetDNSDomain() string
	GetDNSSettings(accountID string, userID string) (*DNSSettings, error)
	SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error

	// Activity events
	GetEvents(accountID, userID string) ([]*activity.Event, error)
}
// DefaultAccountManager manages accounts persisted in a Store together with the IdP user
// cache, peer login expiration scheduling and activity events.
type DefaultAccountManager struct {
	// Store persists accounts.
	Store Store
	// eventStore receives activity events.
	eventStore activity.Store
	// ctx is the context used for IdP cache operations.
	ctx context.Context

	// idpManager is the identity provider integration; it may be nil (checked with isNil).
	idpManager idp.Manager
	// userDeleteFromIDPEnabled allows to delete user from IDP when user is deleted from account
	userDeleteFromIDPEnabled bool
	// cacheManager caches IdP user data per account ID; misses are loaded through loadAccount.
	cacheManager cache.CacheInterface[[]*idp.UserData]
	// cacheMux and cacheLoading together make sure that only a single cache reload runs at a time
	// per accountID.
	cacheMux sync.Mutex
	// cacheLoading keeps the accountIDs that are currently reloading. An accountID has to be
	// removed once its cache has been reloaded.
	cacheLoading map[string]chan struct{}

	peersUpdateManager *PeersUpdateManager
	// peerLoginExpiry schedules the per-account peer login expiration jobs.
	peerLoginExpiry Scheduler

	// singleAccountMode indicates whether the instance has a single account. If true, every new
	// user ends up under the same account. It is false when the management service has more than
	// one account.
	singleAccountMode bool
	// singleAccountModeDomain is the domain used in the single account mode setup.
	singleAccountModeDomain string
	// dnsDomain is used for peer resolution; it is appended to the peer's name.
	dnsDomain string
}
// Settings holds the account-wide options that can be changed through the API and the Dashboard.
// Account embeds it in storage (gorm column prefix "settings_"), so the field order is kept as is.
type Settings struct {
	// PeerLoginExpirationEnabled switches peer login expiration on or off for the whole account.
	PeerLoginExpirationEnabled bool

	// PeerLoginExpiration is how long a peer login stays valid. It only applies to peers that
	// have Peer.LoginExpirationEnabled set to true.
	PeerLoginExpiration time.Duration

	// GroupsPropagationEnabled enables propagating a user's auto groups to the user's peers.
	GroupsPropagationEnabled bool

	// JWTGroupsEnabled enables extracting group names from the JWT claim named by
	// JWTGroupsClaimName and adding them to the account groups.
	JWTGroupsEnabled bool

	// JWTGroupsClaimName is the JWT claim that group names are extracted from.
	JWTGroupsClaimName string
}

// Copy returns a new, independent Settings with the same values as s. Every field is a plain
// value, so a struct copy duplicates everything.
func (s *Settings) Copy() *Settings {
	clone := *s
	return &clone
}
// Account represents a unique account of the system
//
// The map fields (SetupKeys, Peers, Users, Groups, Rules, Routes, NameServerGroups) are excluded
// from gorm (gorm:"-"). Each has a "G"-suffixed slice counterpart (SetupKeysG, PeersG, ...) that
// gorm maps through the AccountID foreign key and that is hidden from JSON (json:"-"). Policies
// has no map form and is persisted directly.
// NOTE(review): keeping the maps and their "G" slices in sync is presumably done by the Store —
// not visible here, confirm.
// Field order and tags define the persisted layout; change them only together with the store.
type Account struct {
// NOTE(review): the comment below mentions naming the column "aid" to avoid a collision with
// Network.Id, but Id carries no column tag here — confirm whether it is stale.
// we have to name column to aid as it collides with Network.Id when work with associations
Id string `gorm:"primaryKey"`
// User.Id it was created by
CreatedBy string
Domain string `gorm:"index"`
DomainCategory string
IsDomainPrimaryAccount bool
SetupKeys map[string]*SetupKey `gorm:"-"`
SetupKeysG []SetupKey `json:"-" gorm:"foreignKey:AccountID;references:id"`
// Network is stored embedded in the account row with the "network_" column prefix.
Network *Network `gorm:"embedded;embeddedPrefix:network_"`
Peers map[string]*Peer `gorm:"-"`
PeersG []Peer `json:"-" gorm:"foreignKey:AccountID;references:id"`
Users map[string]*User `gorm:"-"`
UsersG []User `json:"-" gorm:"foreignKey:AccountID;references:id"`
Groups map[string]*Group `gorm:"-"`
GroupsG []Group `json:"-" gorm:"foreignKey:AccountID;references:id"`
Rules map[string]*Rule `gorm:"-"`
RulesG []Rule `json:"-" gorm:"foreignKey:AccountID;references:id"`
// Policies is used directly by gorm (no map/"G" pair).
Policies []*Policy `gorm:"foreignKey:AccountID;references:id"`
Routes map[string]*route.Route `gorm:"-"`
RoutesG []route.Route `json:"-" gorm:"foreignKey:AccountID;references:id"`
NameServerGroups map[string]*nbdns.NameServerGroup `gorm:"-"`
NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"`
// DNSSettings is stored embedded with the "dns_settings_" column prefix.
DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"`
// Settings holds the account settings, stored embedded with the "settings_" column prefix.
// It may be nil (Account.Copy checks for that).
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
}
// UserInfo is the JSON view of an account user, as returned by the user-related AccountManager
// methods (CreateUser, SaveUser, GetUsersFromAccount).
// Field order defines the JSON output order. Status is not serialized (json:"-").
type UserInfo struct {
ID string `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
Role string `json:"role"`
AutoGroups []string `json:"auto_groups"`
Status string `json:"-"`
IsServiceUser bool `json:"is_service_user"`
IsBlocked bool `json:"is_blocked"`
LastLogin time.Time `json:"last_login"`
}
// getRoutesToSync collects the routes to send to the given peer:
//   - the enabled routes the peer itself serves;
//   - the enabled routes of every ACL peer whose distribution groups include one of this peer's
//     groups.
//
// A route from another peer is skipped when the peer itself serves a route (enabled or disabled)
// of the same HA group.
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
func (a *Account) getRoutesToSync(peerID string, aclPeers []*Peer) []*route.Route {
	ownEnabled, ownDisabled := a.getRoutingPeerRoutes(peerID)

	servedHAGroups := make(lookupMap)
	for _, list := range [][]*route.Route{ownEnabled, ownDisabled} {
		for _, r := range list {
			servedHAGroups[route.GetHAUniqueID(r)] = struct{}{}
		}
	}

	peerGroups := a.getPeerGroups(peerID)
	result := ownEnabled
	for _, aclPeer := range aclPeers {
		candidates, _ := a.getRoutingPeerRoutes(aclPeer.ID)
		inPeerGroups := a.filterRoutesByGroups(candidates, peerGroups)
		result = append(result, a.filterRoutesFromPeersOfSameHAGroup(inPeerGroups, servedHAGroups)...)
	}
	return result
}
// filterRoutesFromPeersOfSameHAGroup returns the routes whose HA unique ID (route.GetHAUniqueID)
// is not in peerMemberships. It drops routes of HA groups the peer already serves itself.
// Order is preserved. The result is nil when no route passes the filter.
func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships lookupMap) []*route.Route {
	var filteredRoutes []*route.Route
	for _, r := range routes {
		if _, found := peerMemberships[route.GetHAUniqueID(r)]; !found {
			filteredRoutes = append(filteredRoutes, r)
		}
	}
	return filteredRoutes
}
// filterRoutesByGroups keeps the routes that have at least one distribution group (route.Groups)
// present in groupListMap, preserving their order. The result is nil when nothing matches.
func (a *Account) filterRoutesByGroups(routes []*route.Route, groupListMap lookupMap) []*route.Route {
	var kept []*route.Route
nextRoute:
	for _, candidate := range routes {
		for _, gid := range candidate.Groups {
			if _, member := groupListMap[gid]; member {
				kept = append(kept, candidate)
				continue nextRoute
			}
		}
	}
	return kept
}
// getRoutingPeerRoutes returns the enabled and disabled routes served by the given routing peer.
//
// Routes come from two sources:
//   - routes whose PeerGroups contain a group the peer belongs to. These are copied with the ID
//     "<routeID>:<peerID>" so every distributing peer gets a unique route ID.
//   - routes assigned to the peer directly through route.Peer.
//
// Each route ID is returned at most once, and all returned routes are copies.
// Please mind, that returned enabled routes contain Peer.Key (the peer's WireGuard public key)
// in route.Peer, while disabled routes keep the peer ID there.
// If the peer doesn't exist or is not a routing peer (only linux peers are supported), both lists
// are empty.
func (a *Account) getRoutingPeerRoutes(peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) {
	peer := a.GetPeer(peerID)
	if peer == nil {
		log.Errorf("peer %s that doesn't exist under account %s", peerID, a.Id)
		return enabledRoutes, disabledRoutes
	}

	// currently we support only linux routing peers
	if peer.Meta.GoOS != "linux" {
		return enabledRoutes, disabledRoutes
	}

	seenRoute := make(map[string]struct{})
	// takeRoute files r under the enabled or disabled list, at most once per route ID.
	takeRoute := func(r *route.Route) {
		if _, ok := seenRoute[r.ID]; ok {
			return
		}
		seenRoute[r.ID] = struct{}{}

		if r.Enabled {
			// enabled routes are distributed with the peer's key instead of its ID
			r.Peer = peer.Key
			enabledRoutes = append(enabledRoutes, r)
			return
		}
		disabledRoutes = append(disabledRoutes, r)
	}

	for _, r := range a.Routes {
		// routes distributed through peer groups: one copy per member peer
		for _, groupID := range r.PeerGroups {
			group := a.GetGroup(groupID)
			if group == nil {
				log.Errorf("route %s has peers group %s that doesn't exist under account %s", r.ID, groupID, a.Id)
				continue
			}
			for _, id := range group.Peers {
				if id != peerID {
					continue
				}
				newPeerRoute := r.Copy()
				newPeerRoute.Peer = id
				newPeerRoute.PeerGroups = nil
				newPeerRoute.ID = r.ID + ":" + id // we have to provide unique route id when distribute network map
				takeRoute(newPeerRoute)
				break
			}
		}

		// route assigned to this peer directly
		if r.Peer == peerID {
			takeRoute(r.Copy())
		}
	}

	return enabledRoutes, disabledRoutes
}
// GetRoutesByPrefix returns the account routes whose network equals prefix. Networks are compared
// by their string form. The result is nil when there is no match.
func (a *Account) GetRoutesByPrefix(prefix netip.Prefix) []*route.Route {
	want := prefix.String()
	var matches []*route.Route
	for _, candidate := range a.Routes {
		if candidate.Network.String() != want {
			continue
		}
		matches = append(matches, candidate)
	}
	return matches
}
// GetGroup looks up a group by its ID and returns nil when the account has no such group.
func (a *Account) GetGroup(groupID string) *Group {
	group, found := a.Groups[groupID]
	if !found {
		return nil
	}
	return group
}
// GetPeerNetworkMap builds the NetworkMap for the given peer. It contains:
//   - the peers it may connect to and the firewall rules (from getPeerConnectionResources);
//   - the routes to sync;
//   - a copy of the account network;
//   - the DNS configuration.
//
// DNS management is disabled when the peer belongs to one of
// DNSSettings.DisabledManagementGroups. Otherwise the DNS config carries the peers' custom zone
// under dnsDomain and the name server groups for the peer.
// When PeerLoginExpirationEnabled is set, peers whose login has expired are returned in
// OfflinePeers instead of Peers.
func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string) *NetworkMap {
	aclPeers, firewallRules := a.getPeerConnectionResources(peerID)
	// exclude expired peers
	var peersToConnect []*Peer
	var expiredPeers []*Peer
	for _, p := range aclPeers {
		expired, _ := p.LoginExpired(a.Settings.PeerLoginExpiration)
		if a.Settings.PeerLoginExpirationEnabled && expired {
			expiredPeers = append(expiredPeers, p)
			continue
		}
		peersToConnect = append(peersToConnect, p)
	}

	// routes are computed from the reachable (non-expired) peers only
	routesUpdate := a.getRoutesToSync(peerID, peersToConnect)

	dnsManagementStatus := a.getPeerDNSManagementStatus(peerID)
	dnsUpdate := nbdns.Config{
		ServiceEnable: dnsManagementStatus,
	}

	if dnsManagementStatus {
		var zones []nbdns.CustomZone
		peersCustomZone := getPeersCustomZone(a, dnsDomain)
		if peersCustomZone.Domain != "" {
			zones = append(zones, peersCustomZone)
		}
		dnsUpdate.CustomZones = zones
		dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
	}

	return &NetworkMap{
		Peers:         peersToConnect,
		Network:       a.Network.Copy(),
		Routes:        routesUpdate,
		DNSConfig:     dnsUpdate,
		OfflinePeers:  expiredPeers,
		FirewallRules: firewallRules,
	}
}
// GetExpiredPeers returns the peers reported by GetPeersWithExpiration whose login has already
// expired according to the account's PeerLoginExpiration. The result is nil if none expired.
func (a *Account) GetExpiredPeers() []*Peer {
	var expired []*Peer
	for _, candidate := range a.GetPeersWithExpiration() {
		if isExpired, _ := candidate.LoginExpired(a.Settings.PeerLoginExpiration); !isExpired {
			continue
		}
		expired = append(expired, candidate)
	}
	return expired
}
// GetNextPeerExpiration reports the smallest remaining login duration among the account's
// expirable peers, and true, if such a peer exists. Otherwise it returns 0 and false.
// Only connected peers whose login hasn't expired yet are taken into account.
func (a *Account) GetNextPeerExpiration() (time.Duration, bool) {
	var (
		soonest time.Duration
		found   bool
	)
	for _, p := range a.GetPeersWithExpiration() {
		// consider only connected peers because others will require login on connecting to the management server
		if p.Status.LoginExpired || !p.Status.Connected {
			continue
		}
		_, remaining := p.LoginExpired(a.Settings.PeerLoginExpiration)
		if !found || remaining < soonest {
			soonest, found = remaining, true
		}
	}
	return soonest, found
}
// GetPeersWithExpiration lists the peers that have Peer.LoginExpirationEnabled set and were added
// with an SSO login (Peer.AddedWithSSOLogin). The returned slice is never nil.
func (a *Account) GetPeersWithExpiration() []*Peer {
	result := []*Peer{}
	for _, p := range a.Peers {
		if !p.LoginExpirationEnabled || !p.AddedWithSSOLogin() {
			continue
		}
		result = append(result, p)
	}
	return result
}
// GetPeers returns all peers of the account in map iteration order, or nil if there are none.
func (a *Account) GetPeers() []*Peer {
	if len(a.Peers) == 0 {
		return nil
	}
	all := make([]*Peer, 0, len(a.Peers))
	for _, p := range a.Peers {
		all = append(all, p)
	}
	return all
}
// UpdateSettings replaces the account settings with a copy of update, so later changes to update
// don't leak into the account. It returns the account itself.
func (a *Account) UpdateSettings(update *Settings) *Account {
	settingsCopy := update.Copy()
	a.Settings = settingsCopy
	return a
}
// UpdatePeer stores update under update.ID, adding a new peer or replacing the one with that ID.
func (a *Account) UpdatePeer(update *Peer) {
	id := update.ID
	a.Peers[id] = update
}
// DeletePeer removes the peer from the account and cleans up references to it:
//   - its first occurrence in each group is removed;
//   - routes assigned directly to it are disabled and unassigned.
//
// The network serial is incremented afterwards.
func (a *Account) DeletePeer(peerID string) {
	for _, group := range a.Groups {
		idx := -1
		for i, member := range group.Peers {
			if member == peerID {
				idx = i
				break
			}
		}
		if idx >= 0 {
			group.Peers = append(group.Peers[:idx], group.Peers[idx+1:]...)
		}
	}

	for _, rt := range a.Routes {
		if rt.Peer != peerID {
			continue
		}
		rt.Enabled = false
		rt.Peer = ""
	}

	delete(a.Peers, peerID)
	a.Network.IncSerial()
}
// FindPeerByPubKey returns a copy of the peer whose WireGuard public key equals peerPubKey.
// If there is no such peer it returns a status.NotFound error.
func (a *Account) FindPeerByPubKey(peerPubKey string) (*Peer, error) {
	var match *Peer
	for _, p := range a.Peers {
		if p.Key == peerPubKey {
			match = p
			break
		}
	}
	if match == nil {
		return nil, status.Errorf(status.NotFound, "peer with the public key %s not found", peerPubKey)
	}
	return match.Copy(), nil
}
// FindUserPeers returns the peers whose UserID equals userID, i.e. the peers the user created.
// The slice is never nil and the error is always nil.
func (a *Account) FindUserPeers(userID string) ([]*Peer, error) {
	owned := make([]*Peer, 0)
	for _, p := range a.Peers {
		if p.UserID != userID {
			continue
		}
		owned = append(owned, p)
	}
	return owned, nil
}
// FindUser returns the account user with the given ID, or a status.NotFound error when the
// account has no such user.
func (a *Account) FindUser(userID string) (*User, error) {
	if user := a.Users[userID]; user != nil {
		return user, nil
	}
	return nil, status.Errorf(status.NotFound, "user %s not found", userID)
}
// FindSetupKey returns the setup key stored under setupKey, or a status.NotFound error when the
// account has no such key.
func (a *Account) FindSetupKey(setupKey string) (*SetupKey, error) {
	if key := a.SetupKeys[setupKey]; key != nil {
		return key, nil
	}
	return nil, status.Errorf(status.NotFound, "setup key not found")
}
// getUserGroups returns the AutoGroups of the given user, or FindUser's error when the user
// doesn't exist.
func (a *Account) getUserGroups(userID string) ([]string, error) {
	var groups []string
	user, err := a.FindUser(userID)
	if err == nil {
		groups = user.AutoGroups
	}
	return groups, err
}
// getPeerDNSManagementStatus reports whether DNS management is enabled for the peer. It is false
// when the peer belongs to any group listed in DNSSettings.DisabledManagementGroups.
func (a *Account) getPeerDNSManagementStatus(peerID string) bool {
	memberOf := a.getPeerGroups(peerID)
	for _, disabledGroupID := range a.DNSSettings.DisabledManagementGroups {
		if _, member := memberOf[disabledGroupID]; member {
			return false
		}
	}
	return true
}
// getPeerGroups returns the set of IDs of the groups that list peerID among their peers.
func (a *Account) getPeerGroups(peerID string) lookupMap {
	memberOf := make(lookupMap)
	for id, g := range a.Groups {
		for _, member := range g.Peers {
			if member != peerID {
				continue
			}
			memberOf[id] = struct{}{}
			break
		}
	}
	return memberOf
}
// getSetupKeyGroups returns the AutoGroups of the given setup key, or FindSetupKey's error when
// the key doesn't exist.
func (a *Account) getSetupKeyGroups(setupKey string) ([]string, error) {
	var groups []string
	key, err := a.FindSetupKey(setupKey)
	if err == nil {
		groups = key.AutoGroups
	}
	return groups, err
}
// getTakenIPs returns the IP of every peer in the account, or nil if the account has no peers.
func (a *Account) getTakenIPs() []net.IP {
	if len(a.Peers) == 0 {
		return nil
	}
	ips := make([]net.IP, 0, len(a.Peers))
	for _, p := range a.Peers {
		ips = append(ips, p.IP)
	}
	return ips
}
// getPeerDNSLabels returns the set of non-empty DNS labels currently used by the account's peers.
func (a *Account) getPeerDNSLabels() lookupMap {
	labels := make(lookupMap)
	for _, p := range a.Peers {
		if label := p.DNSLabel; label != "" {
			labels[label] = struct{}{}
		}
	}
	return labels
}
// Copy returns a new Account in which every peer, user, setup key, group, rule, policy, route
// and name server group is duplicated through its own Copy method. Network, DNSSettings and,
// when non-nil, Settings are copied the same way. The gorm-only "G" slice fields are not copied.
func (a *Account) Copy() *Account {
	peers := make(map[string]*Peer, len(a.Peers))
	for id, p := range a.Peers {
		peers[id] = p.Copy()
	}

	users := make(map[string]*User, len(a.Users))
	for id, u := range a.Users {
		users[id] = u.Copy()
	}

	setupKeys := make(map[string]*SetupKey, len(a.SetupKeys))
	for id, k := range a.SetupKeys {
		setupKeys[id] = k.Copy()
	}

	groups := make(map[string]*Group, len(a.Groups))
	for id, g := range a.Groups {
		groups[id] = g.Copy()
	}

	rules := make(map[string]*Rule, len(a.Rules))
	for id, r := range a.Rules {
		rules[id] = r.Copy()
	}

	policies := make([]*Policy, 0, len(a.Policies))
	for _, p := range a.Policies {
		policies = append(policies, p.Copy())
	}

	routes := make(map[string]*route.Route, len(a.Routes))
	for id, r := range a.Routes {
		routes[id] = r.Copy()
	}

	nsGroups := make(map[string]*nbdns.NameServerGroup, len(a.NameServerGroups))
	for id, ns := range a.NameServerGroups {
		nsGroups[id] = ns.Copy()
	}

	dnsSettings := a.DNSSettings.Copy()

	var settings *Settings
	if a.Settings != nil {
		settings = a.Settings.Copy()
	}

	return &Account{
		Id:                     a.Id,
		CreatedBy:              a.CreatedBy,
		Domain:                 a.Domain,
		DomainCategory:         a.DomainCategory,
		IsDomainPrimaryAccount: a.IsDomainPrimaryAccount,
		SetupKeys:              setupKeys,
		Network:                a.Network.Copy(),
		Peers:                  peers,
		Users:                  users,
		Groups:                 groups,
		Rules:                  rules,
		Policies:               policies,
		Routes:                 routes,
		NameServerGroups:       nsGroups,
		DNSSettings:            dnsSettings,
		Settings:               settings,
	}
}
// GetGroupAll returns the account group named "All", or an error if the account has none.
func (a *Account) GetGroupAll() (*Group, error) {
	var all *Group
	for _, g := range a.Groups {
		if g.Name == "All" {
			all = g
			break
		}
	}
	if all == nil {
		return nil, fmt.Errorf("no group ALL found")
	}
	return all, nil
}
// GetPeer returns the peer with the given ID, or nil if the account has no such peer.
func (a *Account) GetPeer(peerID string) *Peer {
	if peer, found := a.Peers[peerID]; found {
		return peer
	}
	return nil
}
// SetJWTGroups synchronizes the JWT-issued auto groups of the given user with groupsNames.
//
// All groups with Issued == GroupIssuedJWT are first removed from the user's AutoGroups. Then,
// for every distinct name in groupsNames, the group with that name is looked up. If none exists,
// a new JWT-issued group is created and added to the account. The group is added back to the
// user's AutoGroups only if it is JWT-issued; same-named groups issued otherwise are left alone.
//
// It returns true when the user's set of JWT groups changed (a group was added or removed). It
// returns false when nothing changed or the user doesn't exist.
func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool {
	user, ok := a.Users[userID]
	if !ok {
		return false
	}

	existedGroupsByName := make(map[string]*Group)
	for _, group := range a.Groups {
		existedGroupsByName[group.Name] = group
	}

	// Remove JWT groups from the auto groups so they can be synced again. The kept IDs go into a
	// new slice: deleting from user.AutoGroups in place while ranging over it shifts elements
	// under the loop, which skips entries and can remove the wrong (non-JWT) group IDs.
	jwtAutoGroups := make(map[string]struct{})
	var keptAutoGroups []string
	if user.AutoGroups != nil {
		keptAutoGroups = make([]string, 0, len(user.AutoGroups))
	}
	for _, id := range user.AutoGroups {
		if group, ok := a.Groups[id]; ok && group.Issued == GroupIssuedJWT {
			jwtAutoGroups[group.Name] = struct{}{}
			continue
		}
		keptAutoGroups = append(keptAutoGroups, id)
	}
	user.AutoGroups = keptAutoGroups

	// create JWT groups if they don't exist and add all of them to the auto groups
	var modified bool
	seenNames := make(map[string]struct{}, len(groupsNames))
	for _, name := range groupsNames {
		// A repeated name would otherwise create a second group with the same name (or add the
		// same group ID twice) and be misreported as a modification.
		if _, dup := seenNames[name]; dup {
			continue
		}
		seenNames[name] = struct{}{}

		group, ok := existedGroupsByName[name]
		if !ok {
			group = &Group{
				ID:     xid.New().String(),
				Name:   name,
				Issued: GroupIssuedJWT,
			}
			a.Groups[group.ID] = group
		}

		// only JWT groups will be synced
		if group.Issued == GroupIssuedJWT {
			user.AutoGroups = append(user.AutoGroups, group.ID)
			if _, ok := jwtAutoGroups[name]; !ok {
				modified = true
			}
			delete(jwtAutoGroups, name)
		}
	}

	// names still left in jwtAutoGroups were removed from the user
	if len(jwtAutoGroups) > 0 {
		modified = true
	}

	return modified
}
// UserGroupsAddToPeers adds every peer owned by the user (Peer.UserID == userID) to each of the
// given groups. Unknown group IDs are skipped.
// Existing members keep their order (duplicate IDs are collapsed) and missing user peers are
// appended, so repeated calls don't reshuffle a group's peer list. If the user owns no peers,
// the groups are left untouched.
func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) {
	var userPeers []string
	for pid, peer := range a.Peers {
		if peer.UserID == userID {
			userPeers = append(userPeers, pid)
		}
	}
	if len(userPeers) == 0 {
		return
	}

	for _, gid := range groups {
		group, ok := a.Groups[gid]
		if !ok {
			continue
		}

		members := make(map[string]struct{}, len(group.Peers)+len(userPeers))
		// Build a fresh slice instead of overwriting group.Peers' backing array, which other
		// slices may share.
		updated := make([]string, 0, len(group.Peers)+len(userPeers))
		for _, source := range [][]string{group.Peers, userPeers} {
			for _, pid := range source {
				if _, dup := members[pid]; dup {
					continue
				}
				members[pid] = struct{}{}
				updated = append(updated, pid)
			}
		}
		group.Peers = updated
	}
}
// UserGroupsRemoveFromPeers removes every peer owned by the user from each of the given groups.
// Unknown group IDs are skipped. Group entries that reference peers not present in the account
// are dropped as well.
func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) {
	for _, gid := range groups {
		group, exists := a.Groups[gid]
		if !exists {
			continue
		}

		remaining := make([]string, 0, len(group.Peers))
		for _, pid := range group.Peers {
			if p, known := a.Peers[pid]; known && p.UserID != userID {
				remaining = append(remaining, pid)
			}
		}
		group.Peers = remaining
	}
}
// BuildManager creates a new DefaultAccountManager backed by the provided Store.
//
// Single account mode is turned on only when singleAccountModeDomain is set and the store holds
// at most one account; the domain must then be valid. Every stored account without an "All"
// group gets one through addAllGroup and is saved back. When an IdP manager is configured, the
// IdP user cache is warmed up in a background goroutine, and a failure there is only logged.
func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager,
	singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, userDeleteFromIDPEnabled bool,
) (*DefaultAccountManager, error) {
	am := &DefaultAccountManager{
		Store:                    store,
		peersUpdateManager:       peersUpdateManager,
		idpManager:               idpManager,
		ctx:                      context.Background(),
		cacheMux:                 sync.Mutex{},
		cacheLoading:             map[string]chan struct{}{},
		dnsDomain:                dnsDomain,
		eventStore:               eventStore,
		peerLoginExpiry:          NewDefaultScheduler(),
		userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
	}

	allAccounts := store.GetAllAccounts()
	accountsCount := len(allAccounts)

	// single account mode needs a configured domain and no more than one existing account
	am.singleAccountMode = singleAccountModeDomain != "" && accountsCount <= 1
	if !am.singleAccountMode {
		log.Infof("single account mode disabled, accounts number %d", accountsCount)
	} else {
		if !isDomainValid(singleAccountModeDomain) {
			return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for a single account mode. Please review your input for --single-account-mode-domain", singleAccountModeDomain)
		}
		am.singleAccountModeDomain = singleAccountModeDomain
		log.Infof("single account mode enabled, accounts number %d", accountsCount)
	}

	// make sure every account has the default "All" group
	for _, account := range allAccounts {
		if _, err := account.GetGroupAll(); err == nil {
			continue
		}
		if err := addAllGroup(account); err != nil {
			return nil, err
		}
		if err := store.SaveAccount(account); err != nil {
			return nil, err
		}
	}

	goCacheStore := cacheStore.NewGoCache(gocache.New(CacheExpirationMax, 30*time.Minute))
	am.cacheManager = cache.NewLoadable[[]*idp.UserData](am.loadAccount, cache.New[[]*idp.UserData](goCacheStore))

	if !isNil(am.idpManager) {
		go func() {
			if err := am.warmupIDPCache(); err != nil {
				log.Warnf("failed warming up cache due to error: %v", err)
				// todo retry?
			}
		}()
	}

	return am, nil
}
// UpdateAccountSettings updates Account settings.
// Only users with role UserRoleAdmin can update the account.
// User that performs the update has to belong to the account.
// newSettings.PeerLoginExpiration must lie between one hour and 180 days.
// Activity events are stored for changes of the login-expiration toggle and of its duration, and
// the peer login expiration job is cancelled or (re)scheduled with the NEW settings.
// Returns an updated Account
func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string, newSettings *Settings) (*Account, error) {
	if newSettings == nil {
		return nil, status.Errorf(status.InvalidArgument, "account settings can't be empty")
	}

	halfYearLimit := 180 * 24 * time.Hour
	if newSettings.PeerLoginExpiration > halfYearLimit {
		return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
	}

	if newSettings.PeerLoginExpiration < time.Hour {
		return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
	}

	unlock := am.Store.AcquireAccountLock(accountID)
	defer unlock()

	account, err := am.Store.GetAccountByUser(userID)
	if err != nil {
		return nil, err
	}

	user, err := account.FindUser(userID)
	if err != nil {
		return nil, err
	}

	if !user.IsAdmin() {
		return nil, status.Errorf(status.PermissionDenied, "user is not allowed to update account")
	}

	oldSettings := account.Settings
	if oldSettings == nil {
		// treat missing settings as all defaults so the comparisons below can't panic
		oldSettings = &Settings{}
	}

	// Apply the new settings before (re)scheduling: GetNextPeerExpiration reads
	// account.Settings.PeerLoginExpiration, so scheduling first would use the old duration.
	updatedAccount := account.UpdateSettings(newSettings)

	if oldSettings.PeerLoginExpirationEnabled != newSettings.PeerLoginExpirationEnabled {
		event := activity.AccountPeerLoginExpirationEnabled
		if !newSettings.PeerLoginExpirationEnabled {
			event = activity.AccountPeerLoginExpirationDisabled
			am.peerLoginExpiry.Cancel([]string{accountID})
		} else {
			am.checkAndSchedulePeerLoginExpiration(account)
		}
		am.storeEvent(userID, accountID, accountID, event, nil)
	}

	if oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration {
		am.storeEvent(userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil)
		// Reschedule only while expiration is enabled. A job scheduled while it is disabled
		// would still expire peers, because GetExpiredPeers doesn't check the global flag.
		if newSettings.PeerLoginExpirationEnabled {
			am.checkAndSchedulePeerLoginExpiration(account)
		}
	}

	err = am.Store.SaveAccount(account)
	if err != nil {
		return nil, err
	}

	return updatedAccount, nil
}
// peerLoginExpirationJob returns the job that the peer login expiry scheduler runs for accountID.
//
// Each run loads the account under its lock and expires the peers whose login has expired
// (GetExpiredPeers). It returns the delay until the next peer expiration and whether one exists,
// as reported by Account.GetNextPeerExpiration.
// It returns (0, false) without touching any peer when the account can't be loaded or when peer
// login expiration is disabled (or unset) for the account.
func (am *DefaultAccountManager) peerLoginExpirationJob(accountID string) func() (time.Duration, bool) {
	return func() (time.Duration, bool) {
		unlock := am.Store.AcquireAccountLock(accountID)
		defer unlock()

		account, err := am.Store.GetAccount(accountID)
		if err != nil {
			// account is not usable on error (typically nil), so it must not be dereferenced here
			log.Errorf("failed getting account %s expiring peers: %v", accountID, err)
			return 0, false
		}

		if account.Settings == nil || !account.Settings.PeerLoginExpirationEnabled {
			return 0, false
		}

		expiredPeers := account.GetExpiredPeers()
		log.Debugf("discovered %d peers to expire for account %s", len(expiredPeers), account.Id)

		if err := am.expireAndUpdatePeers(account, expiredPeers); err != nil {
			log.Errorf("failed updating account peers while expiring peers for account %s: %v", account.Id, err)
			return account.GetNextPeerExpiration()
		}

		return account.GetNextPeerExpiration()
	}
}
// checkAndSchedulePeerLoginExpiration cancels any pending login expiration job for the account.
// If the account has an upcoming peer expiration, it schedules a new job for that moment; the
// Schedule call runs in its own goroutine.
func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(account *Account) {
	accountID := account.Id
	am.peerLoginExpiry.Cancel([]string{accountID})

	nextRun, ok := account.GetNextPeerExpiration()
	if !ok {
		return
	}
	go am.peerLoginExpiry.Schedule(nextRun, accountID, am.peerLoginExpirationJob(accountID))
}
// newAccount creates a new Account with a generated ID and generated default setup keys.
// If the generated ID is already in use (collision), it tries one more time before returning an
// error. Any store error other than status.NotFound is returned as is.
// The account is only built here (this function doesn't save it); an AccountCreated activity
// event is stored.
func (am *DefaultAccountManager) newAccount(userID, domain string) (*Account, error) {
	for i := 0; i < 2; i++ {
		accountID := xid.New().String()

		_, err := am.Store.GetAccount(accountID)
		if err == nil {
			log.Warnf("an account with ID already exists, retrying...")
			continue
		}

		// Only NotFound means the ID is free. The ok flag of status.FromError must be checked:
		// for errors that are not status errors the returned value is presumably nil, and
		// calling Type() on it would panic.
		if statusErr, ok := status.FromError(err); !ok || statusErr.Type() != status.NotFound {
			return nil, err
		}

		newAccount := newAccountWithId(accountID, userID, domain)
		am.storeEvent(userID, newAccount.Id, accountID, activity.AccountCreated, nil)
		return newAccount, nil
	}

	return nil, status.Errorf(status.Internal, "error while creating new account")
}
// warmupIDPCache loads all users from the IdP manager and caches them per account ID, each entry
// with a randomized lifetime.
//
// Some identity providers can't write AppMetadata and return such users under
// idp.UnsetAccountID. Each of those users is resolved to its account through the store, tagged
// with that account ID and appended to that account's entry. Users the store can't resolve are
// skipped. The first cache write error aborts the warm-up.
func (am *DefaultAccountManager) warmupIDPCache() error {
	userData, err := am.idpManager.GetAllAccounts()
	if err != nil {
		return err
	}
	log.Infof("%d entries received from IdP management", len(userData))

	for _, user := range userData[idp.UnsetAccountID] {
		account, lookupErr := am.Store.GetAccountByUser(user.ID)
		if lookupErr != nil {
			// best effort: users without a known account are left out
			continue
		}
		user.AppMetadata.WTAccountID = account.Id
		userData[account.Id] = append(userData[account.Id], user)
	}
	delete(userData, idp.UnsetAccountID)

	for accountID, users := range userData {
		if err := am.cacheManager.Set(am.ctx, accountID, users, cacheStore.WithExpiration(cacheEntryExpiration())); err != nil {
			return err
		}
	}

	log.Infof("warmed up IDP cache with %d entries", len(userData))
	return nil
}
// GetAccountByUserOrAccountID returns the account identified by accountID when it is set.
// Otherwise it gets or creates the account of userID and records the account ID in the user's
// IdP app metadata.
// When both IDs are empty, or no account can be obtained for the user, it returns a
// status.NotFound error.
func (am *DefaultAccountManager) GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error) {
	switch {
	case accountID != "":
		return am.Store.GetAccount(accountID)
	case userID == "":
		return nil, status.Errorf(status.NotFound, "no valid user or account Id provided")
	}

	account, err := am.GetOrCreateAccountByUser(userID, domain)
	if err != nil {
		return nil, status.Errorf(status.NotFound, "account not found using user id: %s", userID)
	}

	if err := am.addAccountIDToIDPAppMeta(userID, account); err != nil {
		return nil, err
	}
	return account, nil
}
// isNil reports whether the idp.Manager is nil. That covers a nil interface and an interface
// wrapping a nil pointer, map, slice, func or channel.
// reflect.Value.IsNil panics for non-nillable kinds (e.g. a manager passed as a struct value),
// so the kind is checked first; such values are never nil.
func isNil(i idp.Manager) bool {
	if i == nil {
		return true
	}
	switch v := reflect.ValueOf(i); v.Kind() {
	case reflect.Ptr, reflect.Map, reflect.Slice, reflect.Func, reflect.Chan, reflect.Interface:
		return v.IsNil()
	default:
		return false
	}
}
// addAccountIDToIDPAppMeta writes the account ID into the user's IdP app metadata and then
// refreshes the account's IdP cache. It does nothing when no IdP manager is configured or when
// the cached user already carries this account ID.
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(userID string, account *Account) error {
	if isNil(am.idpManager) {
		return nil
	}

	// cachedUser can be nil if it wasn't found (e.g., just created)
	cachedUser, err := am.lookupUserInCache(userID, account)
	if err != nil {
		return err
	}

	alreadySet := cachedUser != nil && cachedUser.AppMetadata.WTAccountID == account.Id
	if alreadySet {
		log.Debugf("skipping IDP App Meta update because accountID %s has been already set for user %s",
			account.Id, userID)
		return nil
	}

	if err := am.idpManager.UpdateUserAppMetadata(userID, idp.AppMetadata{WTAccountID: account.Id}); err != nil {
		return status.Errorf(status.Internal, "updating user's app metadata failed with: %v", err)
	}

	// refresh cache to reflect the update
	_, err = am.refreshCache(account.Id)
	return err
}
// loadAccount loads the IdP user data of an account (used when the account is not in the cache).
// It reads the account from the store and the account's users from the IdP manager, and returns the IdP
// entries matching the account's non-service users. Users the IdP did not return are logged and skipped.
func (am *DefaultAccountManager) loadAccount(_ context.Context, accountID interface{}) ([]*idp.UserData, error) {
	log.Debugf("account %s not found in cache, reloading", accountID)
	id := fmt.Sprintf("%v", accountID)

	account, err := am.Store.GetAccount(id)
	if err != nil {
		return nil, err
	}

	idpUsers, err := am.idpManager.GetAccount(id)
	if err != nil {
		return nil, err
	}
	log.Debugf("%d entries received from IdP management", len(idpUsers))

	byID := make(map[string]*idp.UserData, len(idpUsers))
	for _, entry := range idpUsers {
		byID[entry.ID] = entry
	}

	matched := make([]*idp.UserData, 0)
	for _, accountUser := range account.Users {
		if accountUser.IsServiceUser {
			// service users are not backed by the IdP
			continue
		}
		if entry, found := byID[accountUser.Id]; found {
			matched = append(matched, entry)
		} else {
			log.Warnf("user %s not found in IDP", accountUser.Id)
		}
	}
	return matched, nil
}
// lookupUserInCacheByEmail returns the cached IdP entry of accountID whose Email equals email.
// It returns nil with a nil error when no cached entry matches.
func (am *DefaultAccountManager) lookupUserInCacheByEmail(email string, accountID string) (*idp.UserData, error) {
	cached, err := am.getAccountFromCache(accountID, false)
	if err != nil {
		return nil, err
	}
	for i := range cached {
		if cached[i].Email == email {
			return cached[i], nil
		}
	}
	return nil, nil //nolint:nilnil
}
// lookupUserInCache returns the IdP cache entry of userID within the given account, or nil (with a nil
// error) when the user is not in the cache. The account's non-service users are passed to lookupCache as
// the expected user set, which may trigger a cache reload.
func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Account) (*idp.UserData, error) {
	expected := make(map[string]struct{}, len(account.Users))
	for _, accountUser := range account.Users {
		if accountUser.IsServiceUser {
			continue
		}
		expected[accountUser.Id] = struct{}{}
	}

	log.Debugf("looking up user %s of account %s in cache", userID, account.Id)
	cached, err := am.lookupCache(expected, account.Id)
	if err != nil {
		return nil, err
	}

	for _, entry := range cached {
		if entry.ID == userID {
			return entry, nil
		}
	}
	return nil, nil //nolint:nilnil
}
// refreshCache drops the cached IdP user data of accountID, reads it again and returns the fresh data.
func (am *DefaultAccountManager) refreshCache(accountID string) ([]*idp.UserData, error) {
	const forceReload = true
	return am.getAccountFromCache(accountID, forceReload)
}
// getAccountFromCache returns user data for a given account ensuring that cache load happens only once
//
// The first caller for accountID registers a channel in am.cacheLoading (guarded by am.cacheMux). If
// forceReload is set it deletes the cached entry, then reads it with am.cacheManager.Get (presumably the
// cache manager repopulates missing entries through its loader — confirm against the cache setup). On
// return, the channel is removed from am.cacheLoading and closed.
//
// Concurrent callers for the same accountID do not load on their own: they wait for that channel to be
// closed and then read from the cache, or return an error after 5 seconds. Their forceReload flag is not
// applied; they read whatever the in-flight call left in the cache.
func (am *DefaultAccountManager) getAccountFromCache(accountID string, forceReload bool) ([]*idp.UserData, error) {
	am.cacheMux.Lock()
	loadingChan := am.cacheLoading[accountID]
	if loadingChan == nil {
		// no load in flight for this account: this caller performs it
		loadingChan = make(chan struct{})
		am.cacheLoading[accountID] = loadingChan
		am.cacheMux.Unlock()
		defer func() {
			// unregister and close under the lock so waiters are released exactly once
			am.cacheMux.Lock()
			delete(am.cacheLoading, accountID)
			close(loadingChan)
			am.cacheMux.Unlock()
		}()
		if forceReload {
			err := am.cacheManager.Delete(am.ctx, accountID)
			if err != nil {
				return nil, err
			}
		}
		return am.cacheManager.Get(am.ctx, accountID)
	}
	am.cacheMux.Unlock()
	// another call is already loading this account: wait for it instead of loading again
	log.Debugf("one request to get account %s is already running", accountID)
	select {
	case <-loadingChan:
		// channel has been closed meaning cache was loaded => simply return from cache
		return am.cacheManager.Get(am.ctx, accountID)
	case <-time.After(5 * time.Second):
		return nil, fmt.Errorf("timeout while waiting for account %s cache to reload", accountID)
	}
}
// lookupCache returns the cached IdP user data of accountID. accountUsers is the source of truth. The
// cache is reloaded once if any user ID in accountUsers is missing from it, or if the number of cached
// entries differs from len(accountUsers). The reloaded data is returned without re-checking it, which
// avoids reload loops.
func (am *DefaultAccountManager) lookupCache(accountUsers map[string]struct{}, accountID string) ([]*idp.UserData, error) {
	cached, err := am.getAccountFromCache(accountID, false)
	if err != nil {
		return nil, err
	}

	cachedIDs := make(map[string]struct{}, len(cached))
	for _, entry := range cached {
		cachedIDs[entry.ID] = struct{}{}
	}

	stale := len(accountUsers) != len(cached)
	for id := range accountUsers {
		if stale {
			break
		}
		_, present := cachedIDs[id]
		stale = !present
	}

	if !stale {
		return cached, nil
	}

	// reload cache once avoiding loops
	fresh, err := am.refreshCache(accountID)
	if err != nil {
		return nil, err
	}
	return fresh, nil
}
// updateAccountDomainAttributes updates the account domain attributes from the JWT claims and then saves
// the account.
//
// primaryDomain is stored as the account's IsDomainPrimaryAccount flag. The account domain is switched to
// the lower-cased claims domain only when the requesting user is an admin of the account. The domain
// category is only taken from the claims when the account domain equals the claims domain, so the category
// of a different domain is not applied until an admin logs in.
func (am *DefaultAccountManager) updateAccountDomainAttributes(account *Account, claims jwtclaims.AuthorizationClaims,
	primaryDomain bool,
) error {
	account.IsDomainPrimaryAccount = primaryDomain

	lowerDomain := strings.ToLower(claims.Domain)
	// the user may be missing from the account: treat that as a non-admin instead of dereferencing nil
	userObj := account.Users[claims.UserId]
	if userObj != nil && userObj.Role == UserRoleAdmin && account.Domain != lowerDomain {
		account.Domain = lowerDomain
	}

	// prevent updating category for different domain until admin logs in
	if account.Domain == lowerDomain {
		account.DomainCategory = claims.DomainCategory
	}

	return am.Store.SaveAccount(account)
}
// handleExistingUserAccount updates the domain attributes of an existing user's account and records the
// account ID in the user's IdP app metadata.
//
// The account is marked as the primary account for the domain unless another account already is (a
// non-nil domainAcc with a different ID), in which case it is marked non-primary. Accounts are
// deliberately not merged here. While a domain was unclassified or classified as public, each user who
// logged in during that time got their own account, and those accounts' peers must not be lost.
func (am *DefaultAccountManager) handleExistingUserAccount(
	existingAcc *Account,
	domainAcc *Account,
	claims jwtclaims.AuthorizationClaims,
) error {
	isPrimary := domainAcc == nil || domainAcc.Id == existingAcc.Id
	if err := am.updateAccountDomainAttributes(existingAcc, claims, isPrimary); err != nil {
		return err
	}

	// register the account ID in this user's metadata in the IdP manager
	return am.addAccountIDToIDPAppMeta(claims.UserId, existingAcc)
}
// handleNewUserAccount places a user that has no account yet.
//
// If the domain already has a primary account (domainAcc non-nil), the user joins it as a regular user.
// Otherwise a new account is created for the user and made the primary account for the lower-cased claims
// domain. In both cases the account ID is written to the user's IdP app metadata and a UserJoined event
// is stored.
func (am *DefaultAccountManager) handleNewUserAccount(domainAcc *Account, claims jwtclaims.AuthorizationClaims) (*Account, error) {
	if claims.UserId == "" {
		return nil, fmt.Errorf("user ID is empty")
	}

	account := domainAcc
	if account != nil {
		// the domain already has a primary account: join it as a regular user
		account.Users[claims.UserId] = NewRegularUser(claims.UserId)
		if err := am.Store.SaveAccount(account); err != nil {
			return nil, err
		}
	} else {
		// first user of this domain: create an account and make it the domain's primary account
		created, err := am.newAccount(claims.UserId, strings.ToLower(claims.Domain))
		if err != nil {
			return nil, err
		}
		if err = am.updateAccountDomainAttributes(created, claims, true); err != nil {
			return nil, err
		}
		account = created
	}

	if err := am.addAccountIDToIDPAppMeta(claims.UserId, account); err != nil {
		return nil, err
	}

	am.storeEvent(claims.UserId, claims.UserId, account.Id, activity.UserJoined, nil)
	return account, nil
}
// redeemInvite checks whether the user has a pending invite in the IdP and, if so, redeems it.
// Redeeming means reloading the account's IdP cache in the background and storing a UserJoined event once
// the reload succeeds.
// Invites need an IdP manager; without one this only logs a warning and returns nil. It returns a
// NotFound status error when the user is not in the IdP cache.
func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) error {
	// only possible with the enabled IdP manager; isNil also catches an interface wrapping a nil pointer
	if isNil(am.idpManager) {
		log.Warnf("invites only work with enabled IdP manager")
		return nil
	}

	user, err := am.lookupUserInCache(userID, account)
	if err != nil {
		return err
	}
	if user == nil {
		return status.Errorf(status.NotFound, "user %s not found in the IdP", userID)
	}

	if user.AppMetadata.WTPendingInvite == nil || !*user.AppMetadata.WTPendingInvite {
		return nil
	}

	log.Infof("redeeming invite for user %s account %s", userID, account.Id)
	// User has already logged in, meaning that IdP should have set wt_pending_invite to false.
	// Our job is to just reload cache. The goroutine keeps its own error variable so it never writes to
	// this function's err after the function has returned.
	accountID := account.Id
	go func() {
		if _, err := am.refreshCache(accountID); err != nil {
			log.Warnf("failed reloading cache when redeeming user %s under account %s: %v", userID, accountID, err)
			return
		}
		log.Debugf("user %s of account %s redeemed invite", user.ID, accountID)
		am.storeEvent(userID, userID, accountID, activity.UserJoined, nil)
	}()
	return nil
}
// MarkPATUsed marks a personal access token as used: it sets the token's LastUsed to the current UTC time
// and saves the owning account. The account is re-read under the account lock before it is modified.
// It returns a NotFound status error if the token's user is no longer part of the account, and an error
// if the token is not among that user's tokens.
func (am *DefaultAccountManager) MarkPATUsed(tokenID string) error {
	user, err := am.Store.GetUserByTokenID(tokenID)
	if err != nil {
		return err
	}

	account, err := am.Store.GetAccountByUser(user.Id)
	if err != nil {
		return err
	}

	unlock := am.Store.AcquireAccountLock(account.Id)
	defer unlock()

	// reload under the lock so we modify and save the latest state of the account
	account, err = am.Store.GetAccountByUser(user.Id)
	if err != nil {
		return err
	}

	// the user may have been removed before the lock was acquired; avoid a nil dereference
	accountUser := account.Users[user.Id]
	if accountUser == nil {
		return status.Errorf(status.NotFound, "user %s not found", user.Id)
	}

	pat, ok := accountUser.PATs[tokenID]
	if !ok || pat == nil {
		return fmt.Errorf("token not found")
	}
	pat.LastUsed = time.Now().UTC()

	return am.Store.SaveAccount(account)
}
// GetAccountFromPAT validates a personal access token and returns the owning account, the user and the
// token record.
//
// The token must have length PATLength and start with PATPrefix. It is followed by a secret of
// PATSecretLength characters and a base62-encoded CRC32 (IEEE) of that secret, PATChecksumLength
// characters long. A valid token is resolved through the base64-encoded SHA-256 hash of the whole token.
func (am *DefaultAccountManager) GetAccountFromPAT(token string) (*Account, *User, *PersonalAccessToken, error) {
	if len(token) != PATLength {
		return nil, nil, nil, fmt.Errorf("token has wrong length")
	}
	if !strings.HasPrefix(token, PATPrefix) {
		return nil, nil, nil, fmt.Errorf("token has wrong prefix")
	}

	secretStart := len(PATPrefix)
	secretEnd := secretStart + PATSecretLength
	secret := token[secretStart:secretEnd]

	expectedChecksum, err := base62.Decode(token[secretEnd : secretEnd+PATChecksumLength])
	if err != nil {
		return nil, nil, nil, fmt.Errorf("token checksum decoding failed: %w", err)
	}
	if crc32.ChecksumIEEE([]byte(secret)) != expectedChecksum {
		return nil, nil, nil, fmt.Errorf("token checksum does not match")
	}

	digest := sha256.Sum256([]byte(token))
	tokenID, err := am.Store.GetTokenIDByHashedToken(b64.StdEncoding.EncodeToString(digest[:]))
	if err != nil {
		return nil, nil, nil, err
	}

	user, err := am.Store.GetUserByTokenID(tokenID)
	if err != nil {
		return nil, nil, nil, err
	}

	account, err := am.Store.GetAccountByUser(user.Id)
	if err != nil {
		return nil, nil, nil, err
	}

	if pat := user.PATs[tokenID]; pat != nil {
		return account, user, pat, nil
	}
	return nil, nil, nil, fmt.Errorf("personal access token not found")
}
// GetAccountFromToken returns the account and the user associated with the given JWT claims.
//
// In single account mode with a configured domain, the claims' domain and domain category are overridden
// so all users are grouped under one account. For non-service users any pending IdP invite is redeemed.
// When JWT groups are enabled for the account, the user's groups are synchronized from the configured
// groups claim. Failures in that step are logged and do not fail the call.
func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error) {
	if claims.UserId == "" {
		return nil, nil, fmt.Errorf("user ID is empty")
	}
	if am.singleAccountMode && am.singleAccountModeDomain != "" {
		// This section is mostly related to self-hosted installations.
		// We override incoming domain claims to group users under a single account.
		claims.Domain = am.singleAccountModeDomain
		claims.DomainCategory = PrivateCategory
		log.Infof("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
	}

	account, err := am.getAccountWithAuthorizationClaims(claims)
	if err != nil {
		return nil, nil, err
	}

	user := account.Users[claims.UserId]
	if user == nil {
		// this is not really possible because we got an account by user ID
		return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId)
	}

	if !user.IsServiceUser {
		if err = am.redeemInvite(account, claims.UserId); err != nil {
			return nil, nil, err
		}
	}

	if account.Settings.JWTGroupsEnabled {
		am.updateUserGroupsFromJWT(account, user, claims)
	}
	return account, user, nil
}

// updateUserGroupsFromJWT applies the groups found in the account's JWT groups claim to the user and
// saves the account when SetJWTGroups reports a change. With groups propagation enabled, the added and
// removed groups are also applied to the user's peers, the network serial is increased, peers are updated
// and group activity events are stored. Errors are logged rather than returned.
func (am *DefaultAccountManager) updateUserGroupsFromJWT(account *Account, user *User, claims jwtclaims.AuthorizationClaims) {
	claimName := account.Settings.JWTGroupsClaimName
	if claimName == "" {
		log.Errorf("JWT groups are enabled but no claim name is set")
		return
	}

	groupsNames, ok := jwtGroupNamesFromClaim(claims, claimName)
	if !ok {
		return
	}

	oldGroups := make([]string, len(user.AutoGroups))
	copy(oldGroups, user.AutoGroups)

	// only save the account if groups were added or modified
	if !account.SetJWTGroups(claims.UserId, groupsNames) {
		return
	}

	if !account.Settings.GroupsPropagationEnabled {
		if err := am.Store.SaveAccount(account); err != nil {
			log.Errorf("failed to save account: %v", err)
		}
		return
	}

	updatedUser, err := account.FindUser(claims.UserId)
	if err != nil {
		// log instead of silently dropping the group change that SetJWTGroups applied in memory
		log.Errorf("failed to find user %s while propagating JWT groups: %v", claims.UserId, err)
		return
	}

	addNewGroups := difference(updatedUser.AutoGroups, oldGroups)
	removeOldGroups := difference(oldGroups, updatedUser.AutoGroups)
	account.UserGroupsAddToPeers(claims.UserId, addNewGroups...)
	account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...)
	account.Network.IncSerial()
	if err := am.Store.SaveAccount(account); err != nil {
		log.Errorf("failed to save account: %v", err)
		return
	}
	am.updateAccountPeers(account)

	for _, g := range addNewGroups {
		if group := account.GetGroup(g); group != nil {
			am.storeEvent(updatedUser.Id, updatedUser.Id, account.Id, activity.GroupAddedToUser,
				map[string]any{
					"group":           group.Name,
					"group_id":        group.ID,
					"is_service_user": updatedUser.IsServiceUser,
					"user_name":       updatedUser.ServiceUserName})
		}
	}
	for _, g := range removeOldGroups {
		if group := account.GetGroup(g); group != nil {
			am.storeEvent(updatedUser.Id, updatedUser.Id, account.Id, activity.GroupRemovedFromUser,
				map[string]any{
					"group":           group.Name,
					"group_id":        group.ID,
					"is_service_user": updatedUser.IsServiceUser,
					"user_name":       updatedUser.ServiceUserName})
		}
	}
}

// jwtGroupNamesFromClaim extracts group names from the raw JWT claim named claimName. It returns false
// when the claim is missing or is not an array. Array items that are not strings are logged and skipped.
func jwtGroupNamesFromClaim(claims jwtclaims.AuthorizationClaims, claimName string) ([]string, bool) {
	claim, ok := claims.Raw[claimName]
	if !ok {
		log.Debugf("JWT claim %q not found", claimName)
		return nil, false
	}

	items, ok := claim.([]interface{})
	if !ok {
		log.Debugf("JWT claim %q is not a string array", claimName)
		return nil, false
	}

	var groupsNames []string
	for _, item := range items {
		if g, isString := item.(string); isString {
			groupsNames = append(groupsNames, g)
		} else {
			log.Errorf("JWT claim %q is not a string: %v", claimName, item)
		}
	}
	return groupsNames, true
}
// getAccountWithAuthorizationClaims retrieves an account using JWT claims.
// If the domain is of the PrivateCategory category, it evaluates whether the account is new, existing,
// or whether there is another account with the same domain.
//
// Use cases:
//
// New user + New account + New domain -> create account, user role = admin (if private domain, index domain)
//
// New user + New account + Existing Private Domain -> add user to the existing account, user role = regular (not admin)
//
// New user + New account + Existing Public Domain -> create account, user role = admin
//
// Existing user + Existing account + Existing Domain -> Nothing changes (if private, index domain)
//
// Existing user + Existing account + Existing Indexed Domain -> Nothing changes
//
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtclaims.AuthorizationClaims) (*Account, error) {
	if claims.UserId == "" {
		return nil, fmt.Errorf("user ID is empty")
	}

	// non-private or invalid domains are not indexed: resolve the account by user or account ID only
	if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
		return am.GetAccountByUserOrAccountID(claims.UserId, claims.AccountId, claims.Domain)
	}

	// if Account ID is part of the claims, the domain has already been classified and the user has an account
	if claims.AccountId != "" {
		accountFromID, err := am.Store.GetAccount(claims.AccountId)
		if err != nil {
			return nil, err
		}
		if _, ok := accountFromID.Users[claims.UserId]; !ok {
			return nil, fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
		}
		// claims.DomainCategory is always private here. A non-private account whose domain matches the
		// claims falls through so that its domain attributes are updated below.
		if accountFromID.DomainCategory == PrivateCategory || accountFromID.Domain != claims.Domain {
			return accountFromID, nil
		}
	}

	unlock := am.Store.AcquireGlobalLock()
	defer unlock()

	// check whether the domain has a primary account already; NotFound just means it has none
	domainAccount, err := am.Store.GetAccountByPrivateDomain(claims.Domain)
	if err != nil {
		if e, ok := status.FromError(err); !ok || e.Type() != status.NotFound {
			return nil, err
		}
	}

	account, err := am.Store.GetAccountByUser(claims.UserId)
	if err == nil {
		if err = am.handleExistingUserAccount(account, domainAccount, claims); err != nil {
			return nil, err
		}
		return account, nil
	}
	if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
		return am.handleNewUserAccount(domainAccount, claims)
	}
	// other error
	return nil, err
}
// GetAllConnectedPeers returns the set of connected peers as reported by the peers update manager.
// The returned error is always nil.
func (am *DefaultAccountManager) GetAllConnectedPeers() (map[string]struct{}, error) {
	connected := am.peersUpdateManager.GetAllConnectedPeers()
	return connected, nil
}
// validDomainRegexp matches lower-case domain names made of one or more labels, each followed by a dot,
// and ending in an alphabetic TLD of at least two letters. Labels are alphanumeric and may contain single
// hyphens between alphanumeric runs, but cannot start or end with a hyphen.
var validDomainRegexp = regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`)

// isDomainValid reports whether domain is a lower-case domain name accepted by validDomainRegexp.
// The pattern is compiled once at package initialization rather than on every call.
func isDomainValid(domain string) bool {
	return validDomainRegexp.MatchString(domain)
}
// GetDNSDomain returns the DNS domain the account manager was configured with.
func (am *DefaultAccountManager) GetDNSDomain() string {
	domain := am.dnsDomain
	return domain
}
// addAllGroup initializes an account that has no groups yet. It adds an "All" group containing the IDs of
// all the account's peers, a default rule whose source and destination are that group, and the policy
// converted from that rule. Accounts that already have at least one group are left untouched. It returns
// an error if the rule cannot be converted to a policy.
func addAllGroup(account *Account) error {
	if len(account.Groups) != 0 {
		return nil
	}

	groupID := xid.New().String()
	var peerIDs []string
	for _, p := range account.Peers {
		peerIDs = append(peerIDs, p.ID)
	}
	allGroup := &Group{ID: groupID, Name: "All", Issued: GroupIssuedAPI, Peers: peerIDs}
	account.Groups = map[string]*Group{groupID: allGroup}

	ruleID := xid.New().String()
	defaultRule := &Rule{
		ID:          ruleID,
		Name:        DefaultRuleName,
		Description: DefaultRuleDescription,
		Source:      []string{groupID},
		Destination: []string{groupID},
	}
	account.Rules = map[string]*Rule{ruleID: defaultRule}

	// TODO: after migration we need to drop rule and create policy directly
	policy, err := RuleToPolicy(defaultRule)
	if err != nil {
		return fmt.Errorf("convert rule to policy: %w", err)
	}
	account.Policies = []*Policy{policy}
	return nil
}
// newAccountWithId builds a new Account with the given ID but does not store it. userID becomes the
// account's admin and creator. The account gets a fresh network, empty peer, setup key, route and
// nameserver group maps, peer login expiration enabled with DefaultPeerLoginExpiration, and the "All"
// group with its default rule and policy. A failure to add that group is logged, not returned.
func newAccountWithId(accountID, userID, domain string) *Account {
	log.Debugf("creating new account")

	network := NewNetwork()
	users := map[string]*User{userID: NewAdminUser(userID)}
	log.Debugf("created new account %s", accountID)

	acc := &Account{
		Id:               accountID,
		CreatedBy:        userID,
		Domain:           domain,
		Network:          network,
		Users:            users,
		Peers:            map[string]*Peer{},
		SetupKeys:        map[string]*SetupKey{},
		Routes:           map[string]*route.Route{},
		NameServerGroups: map[string]*nbdns.NameServerGroup{},
		DNSSettings: DNSSettings{
			DisabledManagementGroups: []string{},
		},
		Settings: &Settings{
			PeerLoginExpirationEnabled: true,
			PeerLoginExpiration:        DefaultPeerLoginExpiration,
		},
	}

	if err := addAllGroup(acc); err != nil {
		log.Errorf("error adding all group to account %s: %v", acc.Id, err)
	}
	return acc
}