package server import ( "context" "crypto/sha256" b64 "encoding/base64" "errors" "fmt" "hash/crc32" "math/rand" "net" "net/netip" "reflect" "regexp" "slices" "strings" "sync" "time" "github.com/eko/gocache/v3/cache" cacheStore "github.com/eko/gocache/v3/store" "github.com/hashicorp/go-multierror" "github.com/miekg/dns" gocache "github.com/patrickmn/go-cache" "github.com/rs/xid" log "github.com/sirupsen/logrus" "golang.org/x/exp/maps" "github.com/netbirdio/netbird/base62" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/geolocation" nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/integrated_validator" "github.com/netbirdio/netbird/management/server/integration_reference" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/route" ) const ( PublicCategory = "public" PrivateCategory = "private" UnknownCategory = "unknown" CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days DefaultPeerLoginExpiration = 24 * time.Hour DefaultPeerInactivityExpiration = 10 * time.Minute emptyUserID = "empty user ID in claims" errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v" ) type userLoggedInOnce bool type ExternalCacheManager cache.CacheInterface[*idp.UserData] func cacheEntryExpiration() time.Duration { r := rand.Intn(int(CacheExpirationMax.Milliseconds()-CacheExpirationMin.Milliseconds())) + int(CacheExpirationMin.Milliseconds()) return time.Duration(r) * time.Millisecond } type AccountManager interface { GetOrCreateAccountByUser(ctx context.Context, userId, domain string) (*Account, error) GetAccount(ctx context.Context, accountID string) (*Account, error) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) SaveSetupKey(ctx context.Context, accountID string, key *SetupKey, userID string) (*SetupKey, error) CreateUser(ctx context.Context, accountID, initiatorUserID string, key *UserInfo) (*UserInfo, error) DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error) SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error) SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error) AccountExists(ctx context.Context, accountID string) (bool, error) GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error GetAccountInfoFromPAT(ctx context.Context, token string) (*User, *PersonalAccessToken, string, string, error) DeleteAccount(ctx context.Context, accountID, userID string) error MarkPATUsed(ctx context.Context, tokenID string) error GetUserByID(ctx context.Context, id string) (*User, error) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) ListUsers(ctx context.Context, accountID string) ([]*User, error) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error DeletePeer(ctx context.Context, accountID, peerID, userID string) error UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) GetNetworkMap(ctx context.Context, peerID string) (*NetworkMap, error) GetPeerNetwork(ctx context.Context, peerID string) (*Network, error) AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error) GetGroup(ctx context.Context, accountId, groupID, userID string) (*nbgroup.Group, error) GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) SaveGroup(ctx context.Context, accountID, userID string, group *nbgroup.Group) error SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error DeleteGroup(ctx context.Context, accountId, userId, groupID string) error DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) SaveRoute(ctx context.Context, accountID, userID string, route *route.Route) error DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) GetDNSDomain() string StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Account, error) LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API SyncPeer(ctx context.Context, sync PeerSync, accountID string) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API GetAllConnectedPeers() (map[string]struct{}, error) HasConnectedChannel(peerID string) bool GetExternalCacheManager() ExternalCacheManager GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) GetIdpManager() idp.Manager UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error) GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error } type DefaultAccountManager struct { Store Store // cacheMux and cacheLoading helps to make sure that only a single cache reload runs at a time per accountID cacheMux sync.Mutex // cacheLoading keeps the accountIDs that are currently reloading. The accountID has to be removed once cache has been reloaded cacheLoading map[string]chan struct{} peersUpdateManager *PeersUpdateManager idpManager idp.Manager cacheManager cache.CacheInterface[[]*idp.UserData] externalCacheManager ExternalCacheManager ctx context.Context eventStore activity.Store geo *geolocation.Geolocation requestBuffer *AccountRequestBuffer // singleAccountMode indicates whether the instance has a single account. // If true, then every new user will end up under the same account. // This value will be set to false if management service has more than one account. singleAccountMode bool // singleAccountModeDomain is a domain to use in singleAccountMode setup singleAccountModeDomain string // dnsDomain is used for peer resolution. This is appended to the peer's name dnsDomain string peerLoginExpiry Scheduler peerInactivityExpiry Scheduler // userDeleteFromIDPEnabled allows to delete user from IDP when user is deleted from account userDeleteFromIDPEnabled bool integratedPeerValidator integrated_validator.IntegratedValidator metrics telemetry.AppMetrics } // Settings represents Account settings structure that can be modified via API and Dashboard type Settings struct { // PeerLoginExpirationEnabled globally enables or disables peer login expiration PeerLoginExpirationEnabled bool // PeerLoginExpiration is a setting that indicates when peer login expires. // Applies to all peers that have Peer.LoginExpirationEnabled set to true. PeerLoginExpiration time.Duration // PeerInactivityExpirationEnabled globally enables or disables peer inactivity expiration PeerInactivityExpirationEnabled bool // PeerInactivityExpiration is a setting that indicates when peer inactivity expires. // Applies to all peers that have Peer.PeerInactivityExpirationEnabled set to true. PeerInactivityExpiration time.Duration // RegularUsersViewBlocked allows to block regular users from viewing even their own peers and some UI elements RegularUsersViewBlocked bool // GroupsPropagationEnabled allows to propagate auto groups from the user to the peer GroupsPropagationEnabled bool // JWTGroupsEnabled allows extract groups from JWT claim, which name defined in the JWTGroupsClaimName // and add it to account groups. JWTGroupsEnabled bool // JWTGroupsClaimName from which we extract groups name to add it to account groups JWTGroupsClaimName string // JWTAllowGroups list of groups to which users are allowed access JWTAllowGroups []string `gorm:"serializer:json"` // Extra is a dictionary of Account settings Extra *account.ExtraSettings `gorm:"embedded;embeddedPrefix:extra_"` } // Copy copies the Settings struct func (s *Settings) Copy() *Settings { settings := &Settings{ PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled, PeerLoginExpiration: s.PeerLoginExpiration, JWTGroupsEnabled: s.JWTGroupsEnabled, JWTGroupsClaimName: s.JWTGroupsClaimName, GroupsPropagationEnabled: s.GroupsPropagationEnabled, JWTAllowGroups: s.JWTAllowGroups, RegularUsersViewBlocked: s.RegularUsersViewBlocked, PeerInactivityExpirationEnabled: s.PeerInactivityExpirationEnabled, PeerInactivityExpiration: s.PeerInactivityExpiration, } if s.Extra != nil { settings.Extra = s.Extra.Copy() } return settings } // Account represents a unique account of the system type Account struct { // we have to name column to aid as it collides with Network.Id when work with associations Id string `gorm:"primaryKey"` // User.Id it was created by CreatedBy string CreatedAt time.Time Domain string `gorm:"index"` DomainCategory string IsDomainPrimaryAccount bool SetupKeys map[string]*SetupKey `gorm:"-"` SetupKeysG []SetupKey `json:"-" gorm:"foreignKey:AccountID;references:id"` Network *Network `gorm:"embedded;embeddedPrefix:network_"` Peers map[string]*nbpeer.Peer `gorm:"-"` PeersG []nbpeer.Peer `json:"-" gorm:"foreignKey:AccountID;references:id"` Users map[string]*User `gorm:"-"` UsersG []User `json:"-" gorm:"foreignKey:AccountID;references:id"` Groups map[string]*nbgroup.Group `gorm:"-"` GroupsG []nbgroup.Group `json:"-" gorm:"foreignKey:AccountID;references:id"` Policies []*Policy `gorm:"foreignKey:AccountID;references:id"` Routes map[route.ID]*route.Route `gorm:"-"` RoutesG []route.Route `json:"-" gorm:"foreignKey:AccountID;references:id"` NameServerGroups map[string]*nbdns.NameServerGroup `gorm:"-"` NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"` DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"` PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"` // Settings is a dictionary of Account settings Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` } // Subclass used in gorm to only load settings and not whole account type AccountSettings struct { Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` } // Subclass used in gorm to only load network and not whole account type AccountNetwork struct { Network *Network `gorm:"embedded;embeddedPrefix:network_"` } // AccountDNSSettings used in gorm to only load dns settings and not whole account type AccountDNSSettings struct { DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"` } type UserPermissions struct { DashboardView string `json:"dashboard_view"` } type UserInfo struct { ID string `json:"id"` Email string `json:"email"` Name string `json:"name"` Role string `json:"role"` AutoGroups []string `json:"auto_groups"` Status string `json:"-"` IsServiceUser bool `json:"is_service_user"` IsBlocked bool `json:"is_blocked"` NonDeletable bool `json:"non_deletable"` LastLogin time.Time `json:"last_login"` Issued string `json:"issued"` IntegrationReference integration_reference.IntegrationReference `json:"-"` Permissions UserPermissions `json:"permissions"` } // getRoutesToSync returns the enabled routes for the peer ID and the routes // from the ACL peers that have distribution groups associated with the peer ID. // Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. func (a *Account) getRoutesToSync(ctx context.Context, peerID string, aclPeers []*nbpeer.Peer) []*route.Route { routes, peerDisabledRoutes := a.getRoutingPeerRoutes(ctx, peerID) peerRoutesMembership := make(lookupMap) for _, r := range append(routes, peerDisabledRoutes...) { peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{} } groupListMap := a.getPeerGroups(peerID) for _, peer := range aclPeers { activeRoutes, _ := a.getRoutingPeerRoutes(ctx, peer.ID) groupFilteredRoutes := a.filterRoutesByGroups(activeRoutes, groupListMap) filteredRoutes := a.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership) routes = append(routes, filteredRoutes...) } return routes } // filterRoutesFromPeersOfSameHAGroup filters and returns a list of routes that don't share the same HA route membership func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships lookupMap) []*route.Route { var filteredRoutes []*route.Route for _, r := range routes { _, found := peerMemberships[string(r.GetHAUniqueID())] if !found { filteredRoutes = append(filteredRoutes, r) } } return filteredRoutes } // filterRoutesByGroups returns a list with routes that have distribution groups in the group's map func (a *Account) filterRoutesByGroups(routes []*route.Route, groupListMap lookupMap) []*route.Route { var filteredRoutes []*route.Route for _, r := range routes { for _, groupID := range r.Groups { _, found := groupListMap[groupID] if found { filteredRoutes = append(filteredRoutes, r) break } } } return filteredRoutes } // getRoutingPeerRoutes returns the enabled and disabled lists of routes that the given routing peer serves // Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. // If the given is not a routing peer, then the lists are empty. func (a *Account) getRoutingPeerRoutes(ctx context.Context, peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) { peer := a.GetPeer(peerID) if peer == nil { log.WithContext(ctx).Errorf("peer %s that doesn't exist under account %s", peerID, a.Id) return enabledRoutes, disabledRoutes } // currently we support only linux routing peers if peer.Meta.GoOS != "linux" { return enabledRoutes, disabledRoutes } seenRoute := make(map[route.ID]struct{}) takeRoute := func(r *route.Route, id string) { if _, ok := seenRoute[r.ID]; ok { return } seenRoute[r.ID] = struct{}{} if r.Enabled { r.Peer = peer.Key enabledRoutes = append(enabledRoutes, r) return } disabledRoutes = append(disabledRoutes, r) } for _, r := range a.Routes { for _, groupID := range r.PeerGroups { group := a.GetGroup(groupID) if group == nil { log.WithContext(ctx).Errorf("route %s has peers group %s that doesn't exist under account %s", r.ID, groupID, a.Id) continue } for _, id := range group.Peers { if id != peerID { continue } newPeerRoute := r.Copy() newPeerRoute.Peer = id newPeerRoute.PeerGroups = nil newPeerRoute.ID = route.ID(string(r.ID) + ":" + id) // we have to provide unique route id when distribute network map takeRoute(newPeerRoute, id) break } } if r.Peer == peerID { takeRoute(r.Copy(), peerID) } } return enabledRoutes, disabledRoutes } // GetRoutesByPrefixOrDomains return list of routes by account and route prefix func (a *Account) GetRoutesByPrefixOrDomains(prefix netip.Prefix, domains domain.List) []*route.Route { var routes []*route.Route for _, r := range a.Routes { dynamic := r.IsDynamic() if dynamic && r.Domains.PunycodeString() == domains.PunycodeString() || !dynamic && r.Network.String() == prefix.String() { routes = append(routes, r) } } return routes } // GetGroup returns a group by ID if exists, nil otherwise func (a *Account) GetGroup(groupID string) *nbgroup.Group { return a.Groups[groupID] } // GetPeerNetworkMap returns the networkmap for the given peer ID. func (a *Account) GetPeerNetworkMap( ctx context.Context, peerID string, peersCustomZone nbdns.CustomZone, validatedPeersMap map[string]struct{}, metrics *telemetry.AccountManagerMetrics, ) *NetworkMap { start := time.Now() peer := a.Peers[peerID] if peer == nil { return &NetworkMap{ Network: a.Network.Copy(), } } if _, ok := validatedPeersMap[peerID]; !ok { return &NetworkMap{ Network: a.Network.Copy(), } } aclPeers, firewallRules := a.getPeerConnectionResources(ctx, peerID, validatedPeersMap) // exclude expired peers var peersToConnect []*nbpeer.Peer var expiredPeers []*nbpeer.Peer for _, p := range aclPeers { expired, _ := p.LoginExpired(a.Settings.PeerLoginExpiration) if a.Settings.PeerLoginExpirationEnabled && expired { expiredPeers = append(expiredPeers, p) continue } peersToConnect = append(peersToConnect, p) } routesUpdate := a.getRoutesToSync(ctx, peerID, peersToConnect) routesFirewallRules := a.getPeerRoutesFirewallRules(ctx, peerID, validatedPeersMap) dnsManagementStatus := a.getPeerDNSManagementStatus(peerID) dnsUpdate := nbdns.Config{ ServiceEnable: dnsManagementStatus, } if dnsManagementStatus { var zones []nbdns.CustomZone if peersCustomZone.Domain != "" { zones = append(zones, peersCustomZone) } dnsUpdate.CustomZones = zones dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID) } nm := &NetworkMap{ Peers: peersToConnect, Network: a.Network.Copy(), Routes: routesUpdate, DNSConfig: dnsUpdate, OfflinePeers: expiredPeers, FirewallRules: firewallRules, RoutesFirewallRules: routesFirewallRules, } if metrics != nil { objectCount := int64(len(peersToConnect) + len(expiredPeers) + len(routesUpdate) + len(firewallRules)) metrics.CountNetworkMapObjects(objectCount) metrics.CountGetPeerNetworkMapDuration(time.Since(start)) if objectCount > 5000 { log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects, "+ "peers to connect: %d, expired peers: %d, routes: %d, firewall rules: %d", a.Id, objectCount, len(peersToConnect), len(expiredPeers), len(routesUpdate), len(firewallRules)) } } return nm } func (a *Account) GetPeersCustomZone(ctx context.Context, dnsDomain string) nbdns.CustomZone { var merr *multierror.Error if dnsDomain == "" { log.WithContext(ctx).Error("no dns domain is set, returning empty zone") return nbdns.CustomZone{} } customZone := nbdns.CustomZone{ Domain: dns.Fqdn(dnsDomain), Records: make([]nbdns.SimpleRecord, 0, len(a.Peers)), } domainSuffix := "." + dnsDomain var sb strings.Builder for _, peer := range a.Peers { if peer.DNSLabel == "" { merr = multierror.Append(merr, fmt.Errorf("peer %s has an empty DNS label", peer.Name)) continue } sb.Grow(len(peer.DNSLabel) + len(domainSuffix)) sb.WriteString(peer.DNSLabel) sb.WriteString(domainSuffix) customZone.Records = append(customZone.Records, nbdns.SimpleRecord{ Name: sb.String(), Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: defaultTTL, RData: peer.IP.String(), }) sb.Reset() } go func() { if merr != nil { log.WithContext(ctx).Errorf("error generating custom zone for account %s: %v", a.Id, merr) } }() return customZone } // GetExpiredPeers returns peers that have been expired func (a *Account) GetExpiredPeers() []*nbpeer.Peer { var peers []*nbpeer.Peer for _, peer := range a.GetPeersWithExpiration() { expired, _ := peer.LoginExpired(a.Settings.PeerLoginExpiration) if expired { peers = append(peers, peer) } } return peers } // GetNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. // If there is no peer that expires this function returns false and a duration of 0. // This function only considers peers that haven't been expired yet and that are connected. func (a *Account) GetNextPeerExpiration() (time.Duration, bool) { peersWithExpiry := a.GetPeersWithExpiration() if len(peersWithExpiry) == 0 { return 0, false } var nextExpiry *time.Duration for _, peer := range peersWithExpiry { // consider only connected peers because others will require login on connecting to the management server if peer.Status.LoginExpired || !peer.Status.Connected { continue } _, duration := peer.LoginExpired(a.Settings.PeerLoginExpiration) if nextExpiry == nil || duration < *nextExpiry { // if expiration is below 1s return 1s duration // this avoids issues with ticker that can't be set to < 0 if duration < time.Second { return time.Second, true } nextExpiry = &duration } } if nextExpiry == nil { return 0, false } return *nextExpiry, true } // GetPeersWithExpiration returns a list of peers that have Peer.LoginExpirationEnabled set to true and that were added by a user func (a *Account) GetPeersWithExpiration() []*nbpeer.Peer { peers := make([]*nbpeer.Peer, 0) for _, peer := range a.Peers { if peer.LoginExpirationEnabled && peer.AddedWithSSOLogin() { peers = append(peers, peer) } } return peers } // GetInactivePeers returns peers that have been expired by inactivity func (a *Account) GetInactivePeers() []*nbpeer.Peer { var peers []*nbpeer.Peer for _, inactivePeer := range a.GetPeersWithInactivity() { inactive, _ := inactivePeer.SessionExpired(a.Settings.PeerInactivityExpiration) if inactive { peers = append(peers, inactivePeer) } } return peers } // GetNextInactivePeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. // If there is no peer that expires this function returns false and a duration of 0. // This function only considers peers that haven't been expired yet and that are not connected. func (a *Account) GetNextInactivePeerExpiration() (time.Duration, bool) { peersWithExpiry := a.GetPeersWithInactivity() if len(peersWithExpiry) == 0 { return 0, false } var nextExpiry *time.Duration for _, peer := range peersWithExpiry { if peer.Status.LoginExpired || peer.Status.Connected { continue } _, duration := peer.SessionExpired(a.Settings.PeerInactivityExpiration) if nextExpiry == nil || duration < *nextExpiry { // if expiration is below 1s return 1s duration // this avoids issues with ticker that can't be set to < 0 if duration < time.Second { return time.Second, true } nextExpiry = &duration } } if nextExpiry == nil { return 0, false } return *nextExpiry, true } // GetPeersWithInactivity eturns a list of peers that have Peer.InactivityExpirationEnabled set to true and that were added by a user func (a *Account) GetPeersWithInactivity() []*nbpeer.Peer { peers := make([]*nbpeer.Peer, 0) for _, peer := range a.Peers { if peer.InactivityExpirationEnabled && peer.AddedWithSSOLogin() { peers = append(peers, peer) } } return peers } // GetPeers returns a list of all Account peers func (a *Account) GetPeers() []*nbpeer.Peer { var peers []*nbpeer.Peer for _, peer := range a.Peers { peers = append(peers, peer) } return peers } // UpdateSettings saves new account settings func (a *Account) UpdateSettings(update *Settings) *Account { a.Settings = update.Copy() return a } // UpdatePeer saves new or replaces existing peer func (a *Account) UpdatePeer(update *nbpeer.Peer) { a.Peers[update.ID] = update } // DeletePeer deletes peer from the account cleaning up all the references func (a *Account) DeletePeer(peerID string) { // delete peer from groups for _, g := range a.Groups { for i, pk := range g.Peers { if pk == peerID { g.Peers = append(g.Peers[:i], g.Peers[i+1:]...) break } } } for _, r := range a.Routes { if r.Peer == peerID { r.Enabled = false r.Peer = "" } } delete(a.Peers, peerID) a.Network.IncSerial() } // FindPeerByPubKey looks for a Peer by provided WireGuard public key in the Account or returns error if it wasn't found. // It will return an object copy of the peer. func (a *Account) FindPeerByPubKey(peerPubKey string) (*nbpeer.Peer, error) { for _, peer := range a.Peers { if peer.Key == peerPubKey { return peer.Copy(), nil } } return nil, status.Errorf(status.NotFound, "peer with the public key %s not found", peerPubKey) } // FindUserPeers returns a list of peers that user owns (created) func (a *Account) FindUserPeers(userID string) ([]*nbpeer.Peer, error) { peers := make([]*nbpeer.Peer, 0) for _, peer := range a.Peers { if peer.UserID == userID { peers = append(peers, peer) } } return peers, nil } // FindUser looks for a given user in the Account or returns error if user wasn't found. func (a *Account) FindUser(userID string) (*User, error) { user := a.Users[userID] if user == nil { return nil, status.Errorf(status.NotFound, "user %s not found", userID) } return user, nil } // FindGroupByName looks for a given group in the Account by name or returns error if the group wasn't found. func (a *Account) FindGroupByName(groupName string) (*nbgroup.Group, error) { for _, group := range a.Groups { if group.Name == groupName { return group, nil } } return nil, status.Errorf(status.NotFound, "group %s not found", groupName) } // FindSetupKey looks for a given SetupKey in the Account or returns error if it wasn't found. func (a *Account) FindSetupKey(setupKey string) (*SetupKey, error) { key := a.SetupKeys[setupKey] if key == nil { return nil, status.Errorf(status.NotFound, "setup key not found") } return key, nil } // GetPeerGroupsList return with the list of groups ID. func (a *Account) GetPeerGroupsList(peerID string) []string { var grps []string for groupID, group := range a.Groups { for _, id := range group.Peers { if id == peerID { grps = append(grps, groupID) break } } } return grps } func (a *Account) getPeerDNSManagementStatus(peerID string) bool { peerGroups := a.getPeerGroups(peerID) enabled := true for _, groupID := range a.DNSSettings.DisabledManagementGroups { _, found := peerGroups[groupID] if found { enabled = false break } } return enabled } func (a *Account) getPeerGroups(peerID string) lookupMap { groupList := make(lookupMap) for groupID, group := range a.Groups { for _, id := range group.Peers { if id == peerID { groupList[groupID] = struct{}{} break } } } return groupList } func (a *Account) getTakenIPs() []net.IP { var takenIps []net.IP for _, existingPeer := range a.Peers { takenIps = append(takenIps, existingPeer.IP) } return takenIps } func (a *Account) getPeerDNSLabels() lookupMap { existingLabels := make(lookupMap) for _, peer := range a.Peers { if peer.DNSLabel != "" { existingLabels[peer.DNSLabel] = struct{}{} } } return existingLabels } func (a *Account) Copy() *Account { peers := map[string]*nbpeer.Peer{} for id, peer := range a.Peers { peers[id] = peer.Copy() } users := map[string]*User{} for id, user := range a.Users { users[id] = user.Copy() } setupKeys := map[string]*SetupKey{} for id, key := range a.SetupKeys { setupKeys[id] = key.Copy() } groups := map[string]*nbgroup.Group{} for id, group := range a.Groups { groups[id] = group.Copy() } policies := []*Policy{} for _, policy := range a.Policies { policies = append(policies, policy.Copy()) } routes := map[route.ID]*route.Route{} for id, r := range a.Routes { routes[id] = r.Copy() } nsGroups := map[string]*nbdns.NameServerGroup{} for id, nsGroup := range a.NameServerGroups { nsGroups[id] = nsGroup.Copy() } dnsSettings := a.DNSSettings.Copy() var settings *Settings if a.Settings != nil { settings = a.Settings.Copy() } postureChecks := []*posture.Checks{} for _, postureCheck := range a.PostureChecks { postureChecks = append(postureChecks, postureCheck.Copy()) } return &Account{ Id: a.Id, CreatedBy: a.CreatedBy, CreatedAt: a.CreatedAt, Domain: a.Domain, DomainCategory: a.DomainCategory, IsDomainPrimaryAccount: a.IsDomainPrimaryAccount, SetupKeys: setupKeys, Network: a.Network.Copy(), Peers: peers, Users: users, Groups: groups, Policies: policies, Routes: routes, NameServerGroups: nsGroups, DNSSettings: dnsSettings, PostureChecks: postureChecks, Settings: settings, } } func (a *Account) GetGroupAll() (*nbgroup.Group, error) { for _, g := range a.Groups { if g.Name == "All" { return g, nil } } return nil, fmt.Errorf("no group ALL found") } // GetPeer looks up a Peer by ID func (a *Account) GetPeer(peerID string) *nbpeer.Peer { return a.Peers[peerID] } // getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups. // Returns a bool indicating if there are changes in the JWT group membership, the updated user AutoGroups, // newly groups to create and an error if any occurred. func (am *DefaultAccountManager) getJWTGroupsChanges(user *User, groups []*nbgroup.Group, groupNames []string) (bool, []string, []*nbgroup.Group, error) { existedGroupsByName := make(map[string]*nbgroup.Group) for _, group := range groups { existedGroupsByName[group.Name] = group } newUserAutoGroups, jwtGroupsMap := separateGroups(user.AutoGroups, groups) groupsToAdd := difference(groupNames, maps.Keys(jwtGroupsMap)) groupsToRemove := difference(maps.Keys(jwtGroupsMap), groupNames) // If no groups are added or removed, we should not sync account if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 { return false, nil, nil, nil } newGroupsToCreate := make([]*nbgroup.Group, 0) var modified bool for _, name := range groupsToAdd { group, exists := existedGroupsByName[name] if !exists { group = &nbgroup.Group{ ID: xid.New().String(), AccountID: user.AccountID, Name: name, Issued: nbgroup.GroupIssuedJWT, } newGroupsToCreate = append(newGroupsToCreate, group) } if group.Issued == nbgroup.GroupIssuedJWT { newUserAutoGroups = append(newUserAutoGroups, group.ID) modified = true } } for name, id := range jwtGroupsMap { if !slices.Contains(groupsToRemove, name) { newUserAutoGroups = append(newUserAutoGroups, id) continue } modified = true } return modified, newUserAutoGroups, newGroupsToCreate, nil } // UserGroupsAddToPeers adds groups to all peers of user func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) { userPeers := make(map[string]struct{}) for pid, peer := range a.Peers { if peer.UserID == userID { userPeers[pid] = struct{}{} } } for _, gid := range groups { group, ok := a.Groups[gid] if !ok { continue } groupPeers := make(map[string]struct{}) for _, pid := range group.Peers { groupPeers[pid] = struct{}{} } for pid := range userPeers { groupPeers[pid] = struct{}{} } group.Peers = group.Peers[:0] for pid := range groupPeers { group.Peers = append(group.Peers, pid) } } } // UserGroupsRemoveFromPeers removes groups from all peers of user func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) { for _, gid := range groups { group, ok := a.Groups[gid] if !ok || group.Name == "All" { continue } update := make([]string, 0, len(group.Peers)) for _, pid := range group.Peers { peer, ok := a.Peers[pid] if !ok { continue } if peer.UserID != userID { update = append(update, pid) } } group.Peers = update } } // BuildManager creates a new DefaultAccountManager with a provided Store func BuildManager( ctx context.Context, store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager, singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, geo *geolocation.Geolocation, userDeleteFromIDPEnabled bool, integratedPeerValidator integrated_validator.IntegratedValidator, metrics telemetry.AppMetrics, ) (*DefaultAccountManager, error) { am := &DefaultAccountManager{ Store: store, geo: geo, peersUpdateManager: peersUpdateManager, idpManager: idpManager, ctx: context.Background(), cacheMux: sync.Mutex{}, cacheLoading: map[string]chan struct{}{}, dnsDomain: dnsDomain, eventStore: eventStore, peerLoginExpiry: NewDefaultScheduler(), peerInactivityExpiry: NewDefaultScheduler(), userDeleteFromIDPEnabled: userDeleteFromIDPEnabled, integratedPeerValidator: integratedPeerValidator, metrics: metrics, requestBuffer: NewAccountRequestBuffer(ctx, store), } allAccounts := store.GetAllAccounts(ctx) // enable single account mode only if configured by user and number of existing accounts is not grater than 1 am.singleAccountMode = singleAccountModeDomain != "" && len(allAccounts) <= 1 if am.singleAccountMode { if !isDomainValid(singleAccountModeDomain) { return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for a single account mode. Please review your input for --single-account-mode-domain", singleAccountModeDomain) } am.singleAccountModeDomain = singleAccountModeDomain log.WithContext(ctx).Infof("single account mode enabled, accounts number %d", len(allAccounts)) } else { log.WithContext(ctx).Infof("single account mode disabled, accounts number %d", len(allAccounts)) } // if account doesn't have a default group // we create 'all' group and add all peers into it // also we create default rule with source as destination for _, account := range allAccounts { shouldSave := false _, err := account.GetGroupAll() if err != nil { if err := addAllGroup(account); err != nil { return nil, err } shouldSave = true } if shouldSave { err = store.SaveAccount(ctx, account) if err != nil { return nil, err } } } goCacheClient := gocache.New(CacheExpirationMax, 30*time.Minute) goCacheStore := cacheStore.NewGoCache(goCacheClient) am.cacheManager = cache.NewLoadable[[]*idp.UserData](am.loadAccount, cache.New[[]*idp.UserData](goCacheStore)) // TODO: what is max expiration time? Should be quite long am.externalCacheManager = cache.New[*idp.UserData]( cacheStore.NewGoCache(goCacheClient), ) if !isNil(am.idpManager) { go func() { err := am.warmupIDPCache(ctx) if err != nil { log.WithContext(ctx).Warnf("failed warming up cache due to error: %v", err) // todo retry? return } }() } am.integratedPeerValidator.SetPeerInvalidationListener(func(accountID string) { am.onPeersInvalidated(ctx, accountID) }) return am, nil } func (am *DefaultAccountManager) GetExternalCacheManager() ExternalCacheManager { return am.externalCacheManager } func (am *DefaultAccountManager) GetIdpManager() idp.Manager { return am.idpManager } // UpdateAccountSettings updates Account settings. // Only users with role UserRoleAdmin can update the account. // User that performs the update has to belong to the account. // Returns an updated Account func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Account, error) { halfYearLimit := 180 * 24 * time.Hour if newSettings.PeerLoginExpiration > halfYearLimit { return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days") } if newSettings.PeerLoginExpiration < time.Hour { return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour") } unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } user, err := account.FindUser(userID) if err != nil { return nil, err } if !user.HasAdminPower() { return nil, status.Errorf(status.PermissionDenied, "user is not allowed to update account") } err = am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, account.Settings.Extra, account.Peers, userID, accountID) if err != nil { return nil, err } oldSettings := account.Settings if oldSettings.PeerLoginExpirationEnabled != newSettings.PeerLoginExpirationEnabled { event := activity.AccountPeerLoginExpirationEnabled if !newSettings.PeerLoginExpirationEnabled { event = activity.AccountPeerLoginExpirationDisabled am.peerLoginExpiry.Cancel(ctx, []string{accountID}) } else { am.checkAndSchedulePeerLoginExpiration(ctx, accountID) } am.StoreEvent(ctx, userID, accountID, accountID, event, nil) } if oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration { am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil) am.checkAndSchedulePeerLoginExpiration(ctx, accountID) } err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID) if err != nil { return nil, err } updatedAccount := account.UpdateSettings(newSettings) err = am.Store.SaveAccount(ctx, account) if err != nil { return nil, err } return updatedAccount, nil } func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *Settings, userID, accountID string) error { if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled { event := activity.AccountPeerInactivityExpirationEnabled if !newSettings.PeerInactivityExpirationEnabled { event = activity.AccountPeerInactivityExpirationDisabled am.peerInactivityExpiry.Cancel(ctx, []string{accountID}) } else { am.checkAndSchedulePeerInactivityExpiration(ctx, accountID) } am.StoreEvent(ctx, userID, accountID, accountID, event, nil) } if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration { am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil) am.checkAndSchedulePeerInactivityExpiration(ctx, accountID) } return nil } func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { return func() (time.Duration, bool) { expiredPeers, err := am.getExpiredPeers(ctx, accountID) if err != nil { return 0, false } var peerIDs []string for _, peer := range expiredPeers { peerIDs = append(peerIDs, peer.ID) } log.WithContext(ctx).Debugf("discovered %d peers to expire for account %s", len(peerIDs), accountID) if err := am.expireAndUpdatePeers(ctx, accountID, expiredPeers); err != nil { log.WithContext(ctx).Errorf("failed updating account peers while expiring peers for account %s", accountID) return 0, false } return am.getNextPeerExpiration(ctx, accountID) } } func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, accountID string) { am.peerLoginExpiry.Cancel(ctx, []string{accountID}) if nextRun, ok := am.getNextPeerExpiration(ctx, accountID); ok { go am.peerLoginExpiry.Schedule(ctx, nextRun, accountID, am.peerLoginExpirationJob(ctx, accountID)) } } // peerInactivityExpirationJob marks login expired for all inactive peers and returns the minimum duration in which the next peer of the account will expire by inactivity if found func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { return func() (time.Duration, bool) { inactivePeers, err := am.getInactivePeers(ctx, accountID) if err != nil { log.WithContext(ctx).Errorf("failed getting inactive peers for account %s", accountID) return 0, false } var peerIDs []string for _, peer := range inactivePeers { peerIDs = append(peerIDs, peer.ID) } log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), accountID) if err := am.expireAndUpdatePeers(ctx, accountID, inactivePeers); err != nil { log.Errorf("failed updating account peers while expiring peers for account %s", accountID) return 0, false } return am.getNextInactivePeerExpiration(ctx, accountID) } } // checkAndSchedulePeerInactivityExpiration periodically checks for inactive peers to end their sessions func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx context.Context, accountID string) { am.peerInactivityExpiry.Cancel(ctx, []string{accountID}) if nextRun, ok := am.getNextInactivePeerExpiration(ctx, accountID); ok { go am.peerInactivityExpiry.Schedule(ctx, nextRun, accountID, am.peerInactivityExpirationJob(ctx, accountID)) } } // newAccount creates a new Account with a generated ID and generated default setup keys. // If ID is already in use (due to collision) we try one more time before returning error func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain string) (*Account, error) { for i := 0; i < 2; i++ { accountId := xid.New().String() _, err := am.Store.GetAccount(ctx, accountId) statusErr, _ := status.FromError(err) switch { case err == nil: log.WithContext(ctx).Warnf("an account with ID already exists, retrying...") continue case statusErr.Type() == status.NotFound: newAccount := newAccountWithId(ctx, accountId, userID, domain) am.StoreEvent(ctx, userID, newAccount.Id, accountId, activity.AccountCreated, nil) return newAccount, nil default: return nil, err } } return nil, status.Errorf(status.Internal, "error while creating new account") } func (am *DefaultAccountManager) warmupIDPCache(ctx context.Context) error { userData, err := am.idpManager.GetAllAccounts(ctx) if err != nil { return err } log.WithContext(ctx).Infof("%d entries received from IdP management", len(userData)) // If the Identity Provider does not support writing AppMetadata, // in cases like this, we expect it to return all users in an "unset" field. // We iterate over the users in the "unset" field, look up their AccountID in our store, and // update their AppMetadata with the AccountID. if unsetData, ok := userData[idp.UnsetAccountID]; ok { for _, user := range unsetData { accountID, err := am.Store.GetAccountByUser(ctx, user.ID) if err == nil { data := userData[accountID.Id] if data == nil { data = make([]*idp.UserData, 0, 1) } user.AppMetadata.WTAccountID = accountID.Id userData[accountID.Id] = append(data, user) } } } delete(userData, idp.UnsetAccountID) rcvdUsers := 0 for accountID, users := range userData { rcvdUsers += len(users) err = am.cacheManager.Set(am.ctx, accountID, users, cacheStore.WithExpiration(cacheEntryExpiration())) if err != nil { return err } } log.WithContext(ctx).Infof("warmed up IDP cache with %d entries for %d accounts", rcvdUsers, len(userData)) return nil } // DeleteAccount deletes an account and all its users from local store and from the remote IDP if the requester is an admin and account owner func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, userID string) error { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } user, err := account.FindUser(userID) if err != nil { return err } if !user.HasAdminPower() { return status.Errorf(status.PermissionDenied, "user is not allowed to delete account") } if user.Role != UserRoleOwner { return status.Errorf(status.PermissionDenied, "user is not allowed to delete account. Only account owner can delete account") } for _, otherUser := range account.Users { if otherUser.IsServiceUser { continue } if otherUser.Id == userID { continue } _, deleteUserErr := am.deleteRegularUser(ctx, accountID, userID, otherUser.Id) if deleteUserErr != nil { return deleteUserErr } } _, err = am.deleteRegularUser(ctx, accountID, userID, userID) if err != nil { log.WithContext(ctx).Errorf("failed deleting user %s. error: %s", userID, err) return err } err = am.Store.DeleteAccount(ctx, account) if err != nil { log.WithContext(ctx).Errorf("failed deleting account %s. error: %s", accountID, err) return err } // cancel peer login expiry job am.peerLoginExpiry.Cancel(ctx, []string{account.Id}) log.WithContext(ctx).Debugf("account %s deleted", accountID) return nil } // AccountExists checks if an account exists. func (am *DefaultAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) { return am.Store.AccountExists(ctx, LockingStrengthShare, accountID) } // GetAccountIDByUserID retrieves the account ID based on the userID provided. // If user does have an account, it returns the user's account ID. // If the user doesn't have an account, it creates one using the provided domain. // Returns the account ID or an error if none is found or created. func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) { if userID == "" { return "", status.Errorf(status.NotFound, "no valid userID provided") } accountID, err := am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, userID) if err != nil { if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) if err != nil { return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID) } if err = am.addAccountIDToIDPAppMeta(ctx, userID, account.Id); err != nil { return "", err } return account.Id, nil } return "", err } return accountID, nil } func isNil(i idp.Manager) bool { return i == nil || reflect.ValueOf(i).IsNil() } // addAccountIDToIDPAppMeta update user's app metadata in idp manager func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error { if !isNil(am.idpManager) { // user can be nil if it wasn't found (e.g., just created) user, err := am.lookupUserInCache(ctx, userID, accountID) if err != nil { return err } if user != nil && user.AppMetadata.WTAccountID == accountID { // it was already set, so we skip the unnecessary update log.WithContext(ctx).Debugf("skipping IDP App Meta update because accountID %s has been already set for user %s", accountID, userID) return nil } err = am.idpManager.UpdateUserAppMetadata(ctx, userID, idp.AppMetadata{WTAccountID: accountID}) if err != nil { return status.Errorf(status.Internal, "updating user's app metadata failed with: %v", err) } // refresh cache to reflect the update _, err = am.refreshCache(ctx, accountID) if err != nil { return err } } return nil } func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID interface{}) ([]*idp.UserData, error) { log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID) accountIDString := fmt.Sprintf("%v", accountID) account, err := am.Store.GetAccount(ctx, accountIDString) if err != nil { return nil, err } userData, err := am.idpManager.GetAccount(ctx, accountIDString) if err != nil { return nil, err } log.WithContext(ctx).Debugf("%d entries received from IdP management", len(userData)) dataMap := make(map[string]*idp.UserData, len(userData)) for _, datum := range userData { dataMap[datum.ID] = datum } matchedUserData := make([]*idp.UserData, 0) for _, user := range account.Users { if user.IsServiceUser { continue } datum, ok := dataMap[user.Id] if !ok { log.WithContext(ctx).Warnf("user %s not found in IDP", user.Id) continue } matchedUserData = append(matchedUserData, datum) } return matchedUserData, nil } func (am *DefaultAccountManager) lookupUserInCacheByEmail(ctx context.Context, email string, accountID string) (*idp.UserData, error) { data, err := am.getAccountFromCache(ctx, accountID, false) if err != nil { return nil, err } for _, datum := range data { if datum.Email == email { return datum, nil } } return nil, nil //nolint:nilnil } // lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, accountID string) (*idp.UserData, error) { accountUsers, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID) if err != nil { return nil, err } users := make(map[string]userLoggedInOnce, len(accountUsers)) // ignore service users and users provisioned by integrations than are never logged in for _, user := range accountUsers { if user.IsServiceUser { continue } if user.Issued == UserIssuedIntegration { continue } users[user.Id] = userLoggedInOnce(!user.LastLogin.IsZero()) } log.WithContext(ctx).Debugf("looking up user %s of account %s in cache", userID, accountID) userData, err := am.lookupCache(ctx, users, accountID) if err != nil { return nil, err } for _, datum := range userData { if datum.ID == userID { return datum, nil } } // add extra check on external cache manager. We may get to this point when the user is not yet findable in IDP, // or it didn't have its metadata updated with am.addAccountIDToIDPAppMeta user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, accountID) return nil, err } key := user.IntegrationReference.CacheKey(accountID, userID) ud, err := am.externalCacheManager.Get(am.ctx, key) if err != nil { log.WithContext(ctx).Debugf("failed to get externalCache for key: %s, error: %s", key, err) } return ud, nil } func (am *DefaultAccountManager) refreshCache(ctx context.Context, accountID string) ([]*idp.UserData, error) { return am.getAccountFromCache(ctx, accountID, true) } // getAccountFromCache returns user data for a given account ensuring that cache load happens only once func (am *DefaultAccountManager) getAccountFromCache(ctx context.Context, accountID string, forceReload bool) ([]*idp.UserData, error) { am.cacheMux.Lock() loadingChan := am.cacheLoading[accountID] if loadingChan == nil { loadingChan = make(chan struct{}) am.cacheLoading[accountID] = loadingChan am.cacheMux.Unlock() defer func() { am.cacheMux.Lock() delete(am.cacheLoading, accountID) close(loadingChan) am.cacheMux.Unlock() }() if forceReload { err := am.cacheManager.Delete(am.ctx, accountID) if err != nil { return nil, err } } return am.cacheManager.Get(am.ctx, accountID) } am.cacheMux.Unlock() log.WithContext(ctx).Debugf("one request to get account %s is already running", accountID) select { case <-loadingChan: // channel has been closed meaning cache was loaded => simply return from cache return am.cacheManager.Get(am.ctx, accountID) case <-time.After(5 * time.Second): return nil, fmt.Errorf("timeout while waiting for account %s cache to reload", accountID) } } func (am *DefaultAccountManager) lookupCache(ctx context.Context, accountUsers map[string]userLoggedInOnce, accountID string) ([]*idp.UserData, error) { var data []*idp.UserData var err error maxAttempts := 2 data, err = am.getAccountFromCache(ctx, accountID, false) if err != nil { return nil, err } for attempt := 1; attempt <= maxAttempts; attempt++ { if am.isCacheFresh(ctx, accountUsers, data) { return data, nil } if attempt > 1 { time.Sleep(200 * time.Millisecond) } log.WithContext(ctx).Infof("refreshing cache for account %s", accountID) data, err = am.refreshCache(ctx, accountID) if err != nil { return nil, err } if attempt == maxAttempts { log.WithContext(ctx).Warnf("cache for account %s reached maximum refresh attempts (%d)", accountID, maxAttempts) } } return data, nil } // isCacheFresh checks if the cache is refreshed already by comparing the accountUsers with the cache data by user count and user invite status func (am *DefaultAccountManager) isCacheFresh(ctx context.Context, accountUsers map[string]userLoggedInOnce, data []*idp.UserData) bool { userDataMap := make(map[string]*idp.UserData, len(data)) for _, datum := range data { userDataMap[datum.ID] = datum } // the accountUsers ID list of non integration users from store, we check if cache has all of them // as result of for loop knownUsersCount will have number of users are not presented in the cashed knownUsersCount := len(accountUsers) for user, loggedInOnce := range accountUsers { if datum, ok := userDataMap[user]; ok { // check if the matching user data has a pending invite and if the user has logged in once, forcing the cache to be refreshed if datum.AppMetadata.WTPendingInvite != nil && *datum.AppMetadata.WTPendingInvite && loggedInOnce == true { //nolint:gosimple log.WithContext(ctx).Infof("user %s has a pending invite and has logged in once, cache invalid", user) return false } knownUsersCount-- continue } log.WithContext(ctx).Debugf("cache doesn't know about %s user", user) } // if we know users that are not yet in cache more likely cache is outdated if knownUsersCount > 0 { log.WithContext(ctx).Infof("cache invalid. Users unknown to the cache: %d", knownUsersCount) return false } return true } func (am *DefaultAccountManager) removeUserFromCache(ctx context.Context, accountID, userID string) error { data, err := am.getAccountFromCache(ctx, accountID, false) if err != nil { return err } for i, datum := range data { if datum.ID == userID { data = append(data[:i], data[i+1:]...) break } } return am.cacheManager.Set(am.ctx, accountID, data, cacheStore.WithExpiration(cacheEntryExpiration())) } // updateAccountDomainAttributesIfNotUpToDate updates the account domain attributes if they are not up to date and then, saves the account changes func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims, primaryDomain bool, ) error { if claims.Domain == "" { log.WithContext(ctx).Errorf("claims don't contain a valid domain, skipping domain attributes update. Received claims: %v", claims) return nil } unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlockAccount() accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error getting account domain and category: %v", err) return err } if domainIsUpToDate(accountDomain, domainCategory, claims) { return nil } user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId) if err != nil { log.WithContext(ctx).Errorf("error getting user: %v", err) return err } newDomain := accountDomain newCategoty := domainCategory lowerDomain := strings.ToLower(claims.Domain) if accountDomain != lowerDomain && user.HasAdminPower() { newDomain = lowerDomain } if accountDomain == lowerDomain { newCategoty = claims.DomainCategory } return am.Store.UpdateAccountDomainAttributes(ctx, accountID, newDomain, newCategoty, primaryDomain) } // handleExistingUserAccount handles existing User accounts and update its domain attributes. // If there is no primary domain account yet, we set the account as primary for the domain. Otherwise, // we compare the account's ID with the domain account ID, and if they don't match, we set the account as // non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain // was previously unclassified or classified as public so N users that logged int that time, has they own account // and peers that shouldn't be lost. func (am *DefaultAccountManager) handleExistingUserAccount( ctx context.Context, userAccountID string, domainAccountID string, claims jwtclaims.AuthorizationClaims, ) error { primaryDomain := domainAccountID == "" || userAccountID == domainAccountID err := am.updateAccountDomainAttributesIfNotUpToDate(ctx, userAccountID, claims, primaryDomain) if err != nil { return err } // we should register the account ID to this user's metadata in our IDP manager err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, userAccountID) if err != nil { return err } return nil } // addNewPrivateAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account, // otherwise it will create a new account and make it primary account for the domain. func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) { if claims.UserId == "" { return "", fmt.Errorf("user ID is empty") } lowerDomain := strings.ToLower(claims.Domain) newAccount, err := am.newAccount(ctx, claims.UserId, lowerDomain) if err != nil { return "", err } newAccount.Domain = lowerDomain newAccount.DomainCategory = claims.DomainCategory newAccount.IsDomainPrimaryAccount = true err = am.Store.SaveAccount(ctx, newAccount) if err != nil { return "", err } err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, newAccount.Id) if err != nil { return "", err } am.StoreEvent(ctx, claims.UserId, claims.UserId, newAccount.Id, activity.UserJoined, nil) return newAccount.Id, nil } func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) { unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID) defer unlockAccount() newUser := NewRegularUser(claims.UserId) newUser.AccountID = domainAccountID err := am.Store.SaveUser(ctx, LockingStrengthUpdate, newUser) if err != nil { return "", err } err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, domainAccountID) if err != nil { return "", err } am.StoreEvent(ctx, claims.UserId, claims.UserId, domainAccountID, activity.UserJoined, nil) return domainAccountID, nil } // redeemInvite checks whether user has been invited and redeems the invite func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID string, userID string) error { // only possible with the enabled IdP manager if am.idpManager == nil { log.WithContext(ctx).Warnf("invites only work with enabled IdP manager") return nil } user, err := am.lookupUserInCache(ctx, userID, accountID) if err != nil { return err } if user == nil { return status.Errorf(status.NotFound, "user %s not found in the IdP", userID) } if user.AppMetadata.WTPendingInvite != nil && *user.AppMetadata.WTPendingInvite { log.WithContext(ctx).Infof("redeeming invite for user %s account %s", userID, accountID) // User has already logged in, meaning that IdP should have set wt_pending_invite to false. // Our job is to just reload cache. go func() { _, err = am.refreshCache(ctx, accountID) if err != nil { log.WithContext(ctx).Warnf("failed reloading cache when redeeming user %s under account %s", userID, accountID) return } log.WithContext(ctx).Debugf("user %s of account %s redeemed invite", user.ID, accountID) am.StoreEvent(ctx, userID, userID, accountID, activity.UserJoined, nil) }() } return nil } // MarkPATUsed marks a personal access token as used func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string) error { return am.Store.MarkPATUsed(ctx, LockingStrengthUpdate, tokenID) } // GetAccount returns an account associated with this account ID. func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID string) (*Account, error) { return am.Store.GetAccount(ctx, accountID) } // GetAccountInfoFromPAT retrieves user, personal access token, domain, and category details from a personal access token. func (am *DefaultAccountManager) GetAccountInfoFromPAT(ctx context.Context, token string) (user *User, pat *PersonalAccessToken, domain string, category string, err error) { user, pat, err = am.extractPATFromToken(ctx, token) if err != nil { return nil, nil, "", "", err } domain, category, err = am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, user.AccountID) if err != nil { return nil, nil, "", "", err } return user, pat, domain, category, nil } // extractPATFromToken validates the token structure and retrieves associated User and PAT. func (am *DefaultAccountManager) extractPATFromToken(ctx context.Context, token string) (*User, *PersonalAccessToken, error) { if len(token) != PATLength { return nil, nil, fmt.Errorf("token has incorrect length") } prefix := token[:len(PATPrefix)] if prefix != PATPrefix { return nil, nil, fmt.Errorf("token has incorrect prefix") } secret := token[len(PATPrefix) : len(PATPrefix)+PATSecretLength] encodedChecksum := token[len(PATPrefix)+PATSecretLength : len(PATPrefix)+PATSecretLength+PATChecksumLength] verificationChecksum, err := base62.Decode(encodedChecksum) if err != nil { return nil, nil, fmt.Errorf("token checksum decoding failed: %w", err) } secretChecksum := crc32.ChecksumIEEE([]byte(secret)) if secretChecksum != verificationChecksum { return nil, nil, fmt.Errorf("token checksum does not match") } hashedToken := sha256.Sum256([]byte(token)) encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) var user *User var pat *PersonalAccessToken err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { pat, err = transaction.GetPATByHashedToken(ctx, LockingStrengthShare, encodedHashedToken) if err != nil { return err } user, err = transaction.GetUserByPATID(ctx, LockingStrengthShare, pat.ID) return err }) if err != nil { return nil, nil, err } return user, pat, nil } // GetAccountByID returns an account associated with this account ID. func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error) { user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } if user.AccountID != accountID { return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data") } return am.Store.GetAccount(ctx, accountID) } // GetAccountIDFromToken returns an account ID associated with this token. func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { if claims.UserId == "" { return "", "", errors.New(emptyUserID) } if am.singleAccountMode && am.singleAccountModeDomain != "" { // This section is mostly related to self-hosted installations. // We override incoming domain claims to group users under a single account. claims.Domain = am.singleAccountModeDomain claims.DomainCategory = PrivateCategory log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled") } accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, claims) if err != nil { return "", "", err } user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId) if err != nil { // this is not really possible because we got an account by user ID return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId) } if user.AccountID != accountID { return "", "", status.Errorf(status.PermissionDenied, "user %s is not part of the account %s", claims.UserId, accountID) } if !user.IsServiceUser && claims.Invited { err = am.redeemInvite(ctx, accountID, user.Id) if err != nil { return "", "", err } } if err = am.syncJWTGroups(ctx, accountID, claims); err != nil { return "", "", err } return accountID, user.Id, nil } // syncJWTGroups processes the JWT groups for a user, updates the account based on the groups, // and propagates changes to peers if group propagation is enabled. func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims) error { settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) if err != nil { return err } if settings == nil || !settings.JWTGroupsEnabled { return nil } if settings.JWTGroupsClaimName == "" { log.WithContext(ctx).Debugf("JWT groups are enabled but no claim name is set") return nil } jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims) unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID) defer func() { if unlockAccount != nil { unlockAccount() } }() var addNewGroups []string var removeOldGroups []string var hasChanges bool var user *User err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { user, err = transaction.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId) if err != nil { return fmt.Errorf("error getting user: %w", err) } groups, err := transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID) if err != nil { return fmt.Errorf("error getting account groups: %w", err) } changed, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(user, groups, jwtGroupsNames) if err != nil { return fmt.Errorf("error getting JWT groups changes: %w", err) } hasChanges = changed // skip update if no changes if !changed { return nil } if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil { return fmt.Errorf("error saving groups: %w", err) } addNewGroups = difference(updatedAutoGroups, user.AutoGroups) removeOldGroups = difference(user.AutoGroups, updatedAutoGroups) user.AutoGroups = updatedAutoGroups if err = transaction.SaveUser(ctx, LockingStrengthUpdate, user); err != nil { return fmt.Errorf("error saving user: %w", err) } // Propagate changes to peers if group propagation is enabled if settings.GroupsPropagationEnabled { groups, err = transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID) if err != nil { return fmt.Errorf("error getting account groups: %w", err) } groupsMap := make(map[string]*nbgroup.Group, len(groups)) for _, group := range groups { groupsMap[group.ID] = group } peers, err := transaction.GetUserPeers(ctx, LockingStrengthShare, accountID, claims.UserId) if err != nil { return fmt.Errorf("error getting user peers: %w", err) } updatedGroups, err := am.updateUserPeersInGroups(groupsMap, peers, addNewGroups, removeOldGroups) if err != nil { return fmt.Errorf("error modifying user peers in groups: %w", err) } if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, updatedGroups); err != nil { return fmt.Errorf("error saving groups: %w", err) } if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { return fmt.Errorf("error incrementing network serial: %w", err) } } unlockAccount() unlockAccount = nil return nil }) if err != nil { return err } if !hasChanges { return nil } for _, g := range addNewGroups { group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g) if err != nil { log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) } else { meta := map[string]any{ "group": group.Name, "group_id": group.ID, "is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName, } am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupAddedToUser, meta) } } for _, g := range removeOldGroups { group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g) if err != nil { log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) } else { meta := map[string]any{ "group": group.Name, "group_id": group.ID, "is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName, } am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupRemovedFromUser, meta) } } if settings.GroupsPropagationEnabled { removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, removeOldGroups) if err != nil { return err } newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, addNewGroups) if err != nil { return err } if removedGroupAffectsPeers || newGroupsAffectsPeers { log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) am.updateAccountPeers(ctx, accountID) } } return nil } // getAccountIDWithAuthorizationClaims retrieves an account ID using JWT Claims. // if domain is not private or domain is invalid, it will return the account ID by user ID. // if domain is of the PrivateCategory category, it will evaluate // if account is new, existing or if there is another account with the same domain // // Use cases: // // New user + New account + New domain -> create account, user role = owner (if private domain, index domain) // // New user + New account + Existing Private Domain -> add user to the existing account, user role = user (not admin) // // New user + New account + Existing Public Domain -> create account, user role = owner // // Existing user + Existing account + Existing Domain -> Nothing changes (if private, index domain) // // Existing user + Existing account + Existing Indexed Domain -> Nothing changes // // Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain) func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) { log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"", claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory) if claims.UserId == "" { return "", errors.New(emptyUserID) } if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) { return am.GetAccountIDByUserID(ctx, claims.UserId, claims.Domain) } if claims.AccountId != "" { return am.handlePrivateAccountWithIDFromClaim(ctx, claims) } // We checked if the domain has a primary account already domainAccountID, cancel, err := am.getPrivateDomainWithGlobalLock(ctx, claims.Domain) if cancel != nil { defer cancel() } if err != nil { return "", err } userAccountID, err := am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, claims.UserId) if handleNotFound(err) != nil { log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) return "", err } if userAccountID != "" { if err = am.handleExistingUserAccount(ctx, userAccountID, domainAccountID, claims); err != nil { return "", err } return userAccountID, nil } if domainAccountID != "" { return am.addNewUserToDomainAccount(ctx, domainAccountID, claims) } return am.addNewPrivateAccount(ctx, domainAccountID, claims) } func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) { domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) if handleNotFound(err) != nil { log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) return "", nil, err } if domainAccountID != "" { return domainAccountID, nil, nil } log.WithContext(ctx).Debugf("no primary account found for domain %s, acquiring global lock", domain) cancel := am.Store.AcquireGlobalLock(ctx) // check again if the domain has a primary account because of simultaneous requests domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) if handleNotFound(err) != nil { cancel() log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) return "", nil, err } return domainAccountID, cancel, nil } func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) { userAccountID, err := am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, claims.UserId) if err != nil { log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) return "", err } if userAccountID != claims.AccountId { return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId) } accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId) if handleNotFound(err) != nil { log.WithContext(ctx).Errorf("error getting account domain and category: %v", err) return "", err } if domainIsUpToDate(accountDomain, domainCategory, claims) { return claims.AccountId, nil } // We checked if the domain has a primary account already domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain) if handleNotFound(err) != nil { log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) return "", err } err = am.handleExistingUserAccount(ctx, claims.AccountId, domainAccountID, claims) if err != nil { return "", err } return claims.AccountId, nil } func handleNotFound(err error) error { if err == nil { return nil } e, ok := status.FromError(err) if !ok || e.Type() != status.NotFound { return err } return nil } func domainIsUpToDate(domain string, domainCategory string, claims jwtclaims.AuthorizationClaims) bool { return domainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || domain != claims.Domain } func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID) defer accountUnlock() peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey) defer peerUnlock() peer, netMap, postureChecks, err := am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID) if err != nil { return nil, nil, nil, fmt.Errorf("error syncing peer: %w", err) } err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID) if err != nil { log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err) } return peer, netMap, postureChecks, nil } func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error { accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID) defer accountUnlock() peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey) defer peerUnlock() err := am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID) if err != nil { log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err) } return nil } func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error { accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peerPubKey) if err != nil { return err } unlock := am.Store.AcquireReadLockByUID(ctx, accountID) defer unlock() _, _, _, err = am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID) if err != nil { return mapError(ctx, err) } return nil } // GetAllConnectedPeers returns connected peers based on peersUpdateManager.GetAllConnectedPeers() func (am *DefaultAccountManager) GetAllConnectedPeers() (map[string]struct{}, error) { return am.peersUpdateManager.GetAllConnectedPeers(), nil } // HasConnectedChannel returns true if peers has channel in update manager, otherwise false func (am *DefaultAccountManager) HasConnectedChannel(peerID string) bool { return am.peersUpdateManager.HasChannel(peerID) } var invalidDomainRegexp = regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`) func isDomainValid(domain string) bool { return invalidDomainRegexp.MatchString(domain) } // GetDNSDomain returns the configured dnsDomain func (am *DefaultAccountManager) GetDNSDomain() string { return am.dnsDomain } // CheckUserAccessByJWTGroups checks if the user has access, particularly in cases where the admin enabled JWT // group propagation and set the list of groups with access permissions. func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error { accountID, _, err := am.GetAccountIDFromToken(ctx, claims) if err != nil { return err } settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) if err != nil { return err } // Ensures JWT group synchronization to the management is enabled before, // filtering access based on the allowed groups. if settings != nil && settings.JWTGroupsEnabled { if allowedGroups := settings.JWTAllowGroups; len(allowedGroups) > 0 { userJWTGroups := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims) if !userHasAllowedGroup(allowedGroups, userJWTGroups) { return fmt.Errorf("user does not belong to any of the allowed JWT groups") } } } return nil } func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) { log.WithContext(ctx).Debugf("validated peers has been invalidated for account %s", accountID) am.updateAccountPeers(ctx, accountID) } func (am *DefaultAccountManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) { return am.Store.GetPostureCheckByChecksDefinition(accountID, checks) } func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error) { return am.Store.GetAccountIDByPeerPubKey(ctx, peerKey) } func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction Store, peer *nbpeer.Peer, settings *Settings) (bool, error) { user, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID) if err != nil { return false, err } err = checkIfPeerOwnerIsBlocked(peer, user) if err != nil { return false, err } if peerLoginExpired(ctx, peer, settings) { err = am.handleExpiredPeer(ctx, transaction, user, peer) if err != nil { return false, err } return true, nil } return false, nil } func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, store Store, accountID string, peerHostName string) (string, error) { existingLabels, err := store.GetPeerLabelsInAccount(ctx, LockingStrengthShare, accountID) if err != nil { return "", fmt.Errorf("failed to get peer dns labels: %w", err) } labelMap := ConvertSliceToMap(existingLabels) newLabel, err := getPeerHostLabel(peerHostName, labelMap) if err != nil { return "", fmt.Errorf("failed to get new host label: %w", err) } if newLabel == "" { return "", fmt.Errorf("failed to get new host label: %w", err) } return newLabel, nil } func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error) { user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) { return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data") } return am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) } // addAllGroup to account object if it doesn't exist func addAllGroup(account *Account) error { if len(account.Groups) == 0 { allGroup := &nbgroup.Group{ ID: xid.New().String(), Name: "All", Issued: nbgroup.GroupIssuedAPI, } for _, peer := range account.Peers { allGroup.Peers = append(allGroup.Peers, peer.ID) } account.Groups = map[string]*nbgroup.Group{allGroup.ID: allGroup} id := xid.New().String() defaultPolicy := &Policy{ ID: id, Name: DefaultRuleName, Description: DefaultRuleDescription, Enabled: true, Rules: []*PolicyRule{ { ID: id, Name: DefaultRuleName, Description: DefaultRuleDescription, Enabled: true, Sources: []string{allGroup.ID}, Destinations: []string{allGroup.ID}, Bidirectional: true, Protocol: PolicyRuleProtocolALL, Action: PolicyTrafficActionAccept, }, }, } account.Policies = []*Policy{defaultPolicy} } return nil } // newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Account { log.WithContext(ctx).Debugf("creating new account") network := NewNetwork() peers := make(map[string]*nbpeer.Peer) users := make(map[string]*User) routes := make(map[route.ID]*route.Route) setupKeys := map[string]*SetupKey{} nameServersGroups := make(map[string]*nbdns.NameServerGroup) owner := NewOwnerUser(userID) owner.AccountID = accountID users[userID] = owner dnsSettings := DNSSettings{ DisabledManagementGroups: make([]string, 0), } log.WithContext(ctx).Debugf("created new account %s", accountID) acc := &Account{ Id: accountID, CreatedAt: time.Now().UTC(), SetupKeys: setupKeys, Network: network, Peers: peers, Users: users, CreatedBy: userID, Domain: domain, Routes: routes, NameServerGroups: nameServersGroups, DNSSettings: dnsSettings, Settings: &Settings{ PeerLoginExpirationEnabled: true, PeerLoginExpiration: DefaultPeerLoginExpiration, GroupsPropagationEnabled: true, RegularUsersViewBlocked: true, PeerInactivityExpirationEnabled: false, PeerInactivityExpiration: DefaultPeerInactivityExpiration, }, } if err := addAllGroup(acc); err != nil { log.WithContext(ctx).Errorf("error adding all group to account %s: %v", acc.Id, err) } return acc } // extractJWTGroups extracts the group names from a JWT token's claims. func extractJWTGroups(ctx context.Context, claimName string, claims jwtclaims.AuthorizationClaims) []string { userJWTGroups := make([]string, 0) if claim, ok := claims.Raw[claimName]; ok { if claimGroups, ok := claim.([]interface{}); ok { for _, g := range claimGroups { if group, ok := g.(string); ok { userJWTGroups = append(userJWTGroups, group) } else { log.WithContext(ctx).Debugf("JWT claim %q contains a non-string group (type: %T): %v", claimName, g, g) } } } } else { log.WithContext(ctx).Debugf("JWT claim %q is not a string array", claimName) } return userJWTGroups } // userHasAllowedGroup checks if a user belongs to any of the allowed groups. func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool { for _, userGroup := range userGroups { for _, allowedGroup := range allowedGroups { if userGroup == allowedGroup { return true } } } return false } // separateGroups separates user's auto groups into non-JWT and JWT groups. // Returns the list of standard auto groups and a map of JWT auto groups, // where the keys are the group names and the values are the group IDs. func separateGroups(autoGroups []string, allGroups []*nbgroup.Group) ([]string, map[string]string) { newAutoGroups := make([]string, 0) jwtAutoGroups := make(map[string]string) // map of group name to group ID allGroupsMap := make(map[string]*nbgroup.Group, len(allGroups)) for _, group := range allGroups { allGroupsMap[group.ID] = group } for _, id := range autoGroups { if group, ok := allGroupsMap[id]; ok { if group.Issued == nbgroup.GroupIssuedJWT { jwtAutoGroups[group.Name] = id } else { newAutoGroups = append(newAutoGroups, id) } } } return newAutoGroups, jwtAutoGroups }