[management] support custom domains per account (#3726)

This commit is contained in:
Pascal Fischer 2025-04-23 19:36:53 +02:00 committed by GitHub
parent 8db05838ca
commit 312bfd9bd7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 158 additions and 46 deletions

View File

@ -98,6 +98,11 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
settingsMockManager := settings.NewMockManager(ctrl) settingsMockManager := settings.NewMockManager(ctrl)
permissionsManagerMock := permissions.NewMockManager(ctrl) permissionsManagerMock := permissions.NewMockManager(ctrl)
settingsMockManager.EXPECT().
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
Return(&types.Settings{}, nil).
AnyTimes()
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock) accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View File

@ -275,6 +275,10 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour") return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
} }
if newSettings.DNSDomain != "" && !isDomainValid(newSettings.DNSDomain) {
return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
}
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
@ -325,6 +329,12 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
account.Network.Serial++ account.Network.Serial++
} }
if oldSettings.DNSDomain != newSettings.DNSDomain {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountDNSDomainUpdated, nil)
updateAccountPeers = true
account.Network.Serial++
}
err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID) err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
@ -1493,8 +1503,15 @@ func isDomainValid(domain string) bool {
} }
// GetDNSDomain returns the configured dnsDomain // GetDNSDomain returns the configured dnsDomain
func (am *DefaultAccountManager) GetDNSDomain() string { func (am *DefaultAccountManager) GetDNSDomain(settings *types.Settings) string {
if settings == nil {
return am.dnsDomain return am.dnsDomain
}
if settings.DNSDomain == "" {
return am.dnsDomain
}
return settings.DNSDomain
} }
func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) { func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) {

View File

@ -81,7 +81,7 @@ type Manager interface {
SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error
ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
GetDNSDomain() string GetDNSDomain(settings *types.Settings) string
StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error)
GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error)

View File

@ -169,6 +169,8 @@ const (
ResourceAddedToGroup Activity = 82 ResourceAddedToGroup Activity = 82
ResourceRemovedFromGroup Activity = 83 ResourceRemovedFromGroup Activity = 83
AccountDNSDomainUpdated Activity = 84
) )
var activityMap = map[Activity]Code{ var activityMap = map[Activity]Code{
@ -264,6 +266,8 @@ var activityMap = map[Activity]Code{
ResourceAddedToGroup: {"Resource added to group", "resource.group.add"}, ResourceAddedToGroup: {"Resource added to group", "resource.group.add"},
ResourceRemovedFromGroup: {"Resource removed from group", "resource.group.delete"}, ResourceRemovedFromGroup: {"Resource removed from group", "resource.group.delete"},
AccountDNSDomainUpdated: {"Account DNS domain updated", "account.dns.domain.update"},
} }
// StringCode returns a string code of the activity // StringCode returns a string code of the activity

View File

@ -158,6 +158,13 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
return nil return nil
} }
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Debugf("failed to get account settings for group events: %v", err)
return nil
}
dnsDomain := am.GetDNSDomain(settings)
for _, peerID := range addedPeers { for _, peerID := range addedPeers {
peer, ok := peers[peerID] peer, ok := peers[peerID]
if !ok { if !ok {
@ -168,7 +175,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
eventsToStore = append(eventsToStore, func() { eventsToStore = append(eventsToStore, func() {
meta := map[string]any{ meta := map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID, "group": newGroup.Name, "group_id": newGroup.ID,
"peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(dnsDomain),
} }
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer, meta) am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer, meta)
}) })
@ -184,7 +191,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
eventsToStore = append(eventsToStore, func() { eventsToStore = append(eventsToStore, func() {
meta := map[string]any{ meta := map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID, "group": newGroup.Name, "group_id": newGroup.ID,
"peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(dnsDomain),
} }
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer, meta) am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer, meta)
}) })

View File

@ -480,20 +480,12 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
s.ephemeralManager.OnPeerDisconnected(ctx, peer) s.ephemeralManager.OnPeerDisconnected(ctx, peer)
} }
var relayToken *Token loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks)
if s.config.Relay != nil && len(s.config.Relay.Addresses) > 0 {
relayToken, err = s.secretsManager.GenerateRelayToken()
if err != nil { if err != nil {
log.Errorf("failed generating Relay token: %v", err) log.WithContext(ctx).Warnf("failed preparing login response for peer %s: %s", peerKey, err)
} return nil, status.Errorf(codes.Internal, "failed logging in peer")
} }
// if peer has reached this point then it has logged in
loginResp := &proto.LoginResponse{
NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil),
PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain(), false),
Checks: toProtocolChecks(ctx, postureChecks),
}
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp)
if err != nil { if err != nil {
log.WithContext(ctx).Warnf("failed encrypting peer %s message", peer.ID) log.WithContext(ctx).Warnf("failed encrypting peer %s message", peer.ID)
@ -506,6 +498,32 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
}, nil }, nil
} }
func (s *GRPCServer) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) {
var relayToken *Token
var err error
if s.config.Relay != nil && len(s.config.Relay.Addresses) > 0 {
relayToken, err = s.secretsManager.GenerateRelayToken()
if err != nil {
log.Errorf("failed generating Relay token: %v", err)
}
}
settings, err := s.settingsManager.GetSettings(ctx, peer.AccountID, activity.SystemInitiator)
if err != nil {
log.WithContext(ctx).Warnf("failed getting settings for peer %s: %s", peer.Key, err)
return nil, status.Errorf(codes.Internal, "failed getting settings")
}
// if peer has reached this point then it has logged in
loginResp := &proto.LoginResponse{
NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil),
PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain(settings), false),
Checks: toProtocolChecks(ctx, postureChecks),
}
return loginResp, nil
}
// processJwtToken validates the existence of a JWT token in the login request, and returns the corresponding user ID if // processJwtToken validates the existence of a JWT token in the login request, and returns the corresponding user ID if
// the token is valid. // the token is valid.
// //
@ -712,7 +730,7 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p
return status.Errorf(codes.Internal, "error handling request") return status.Errorf(codes.Internal, "error handling request")
} }
plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(), postureChecks, nil, settings.RoutingPeerDNSResolutionEnabled, settings.Extra) plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(settings), postureChecks, nil, settings.RoutingPeerDNSResolutionEnabled, settings.Extra)
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
if err != nil { if err != nil {

View File

@ -112,6 +112,10 @@ components:
description: Enables or disables DNS resolution on the routing peers description: Enables or disables DNS resolution on the routing peers
type: boolean type: boolean
example: true example: true
dns_domain:
description: Allows to define a custom dns domain for the account
type: string
example: my-organization.org
extra: extra:
$ref: '#/components/schemas/AccountExtraSettings' $ref: '#/components/schemas/AccountExtraSettings'
required: required:

View File

@ -259,6 +259,8 @@ type AccountRequest struct {
// AccountSettings defines model for AccountSettings. // AccountSettings defines model for AccountSettings.
type AccountSettings struct { type AccountSettings struct {
// DnsDomain Allows to define a custom dns domain for the account
DnsDomain *string `json:"dns_domain,omitempty"`
Extra *AccountExtraSettings `json:"extra,omitempty"` Extra *AccountExtraSettings `json:"extra,omitempty"`
// GroupsPropagationEnabled Allows propagate the new user auto groups to peers that belongs to the user // GroupsPropagationEnabled Allows propagate the new user auto groups to peers that belongs to the user

View File

@ -119,6 +119,9 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
if req.Settings.RoutingPeerDnsResolutionEnabled != nil { if req.Settings.RoutingPeerDnsResolutionEnabled != nil {
settings.RoutingPeerDNSResolutionEnabled = *req.Settings.RoutingPeerDnsResolutionEnabled settings.RoutingPeerDNSResolutionEnabled = *req.Settings.RoutingPeerDnsResolutionEnabled
} }
if req.Settings.DnsDomain != nil {
settings.DNSDomain = *req.Settings.DnsDomain
}
updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings) updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings)
if err != nil { if err != nil {
@ -178,6 +181,7 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
JwtAllowGroups: &jwtAllowGroups, JwtAllowGroups: &jwtAllowGroups,
RegularUsersViewBlocked: settings.RegularUsersViewBlocked, RegularUsersViewBlocked: settings.RegularUsersViewBlocked,
RoutingPeerDnsResolutionEnabled: &settings.RoutingPeerDNSResolutionEnabled, RoutingPeerDnsResolutionEnabled: &settings.RoutingPeerDNSResolutionEnabled,
DnsDomain: &settings.DNSDomain,
} }
if settings.Extra != nil { if settings.Extra != nil {

View File

@ -108,6 +108,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
JwtAllowGroups: &[]string{}, JwtAllowGroups: &[]string{},
RegularUsersViewBlocked: true, RegularUsersViewBlocked: true,
RoutingPeerDnsResolutionEnabled: br(false), RoutingPeerDnsResolutionEnabled: br(false),
DnsDomain: sr(""),
}, },
expectedArray: true, expectedArray: true,
expectedID: accountID, expectedID: accountID,
@ -128,6 +129,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
JwtAllowGroups: &[]string{}, JwtAllowGroups: &[]string{},
RegularUsersViewBlocked: false, RegularUsersViewBlocked: false,
RoutingPeerDnsResolutionEnabled: br(false), RoutingPeerDnsResolutionEnabled: br(false),
DnsDomain: sr(""),
}, },
expectedArray: false, expectedArray: false,
expectedID: accountID, expectedID: accountID,
@ -148,6 +150,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
JwtAllowGroups: &[]string{"test"}, JwtAllowGroups: &[]string{"test"},
RegularUsersViewBlocked: true, RegularUsersViewBlocked: true,
RoutingPeerDnsResolutionEnabled: br(false), RoutingPeerDnsResolutionEnabled: br(false),
DnsDomain: sr(""),
}, },
expectedArray: false, expectedArray: false,
expectedID: accountID, expectedID: accountID,
@ -168,6 +171,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
JwtAllowGroups: &[]string{}, JwtAllowGroups: &[]string{},
RegularUsersViewBlocked: true, RegularUsersViewBlocked: true,
RoutingPeerDnsResolutionEnabled: br(false), RoutingPeerDnsResolutionEnabled: br(false),
DnsDomain: sr(""),
}, },
expectedArray: false, expectedArray: false,
expectedID: accountID, expectedID: accountID,

View File

@ -65,7 +65,13 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string,
util.WriteError(ctx, err, w) util.WriteError(ctx, err, w)
return return
} }
dnsDomain := h.accountManager.GetDNSDomain() settings, err := h.accountManager.GetAccountSettings(ctx, accountID, activity.SystemInitiator)
if err != nil {
util.WriteError(ctx, err, w)
return
}
dnsDomain := h.accountManager.GetDNSDomain(settings)
grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID) grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID)
grpsInfoMap := groups.ToGroupsInfoMap(grps, 0) grpsInfoMap := groups.ToGroupsInfoMap(grps, 0)
@ -110,7 +116,13 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri
util.WriteError(ctx, err, w) util.WriteError(ctx, err, w)
return return
} }
dnsDomain := h.accountManager.GetDNSDomain()
settings, err := h.accountManager.GetAccountSettings(ctx, accountID, activity.SystemInitiator)
if err != nil {
util.WriteError(ctx, err, w)
return
}
dnsDomain := h.accountManager.GetDNSDomain(settings)
peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID) peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID)
if err != nil { if err != nil {
@ -192,7 +204,12 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
return return
} }
dnsDomain := h.accountManager.GetDNSDomain() settings, err := h.accountManager.GetAccountSettings(r.Context(), accountID, activity.SystemInitiator)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
dnsDomain := h.accountManager.GetDNSDomain(settings)
grps, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID) grps, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
@ -279,7 +296,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
return return
} }
dnsDomain := h.accountManager.GetDNSDomain() dnsDomain := h.accountManager.GetDNSDomain(account.Settings)
customZone := account.GetPeersCustomZone(r.Context(), dnsDomain) customZone := account.GetPeersCustomZone(r.Context(), dnsDomain)
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)

View File

@ -152,7 +152,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler {
}, },
}, nil }, nil
}, },
GetDNSDomainFunc: func() string { GetDNSDomainFunc: func(settings *types.Settings) string {
return "netbird.selfhosted" return "netbird.selfhosted"
}, },
GetAccountFunc: func(ctx context.Context, accountID string) (*types.Account, error) { GetAccountFunc: func(ctx context.Context, accountID string) (*types.Account, error) {
@ -172,6 +172,9 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler {
_, ok := statuses[peerID] _, ok := statuses[peerID]
return ok return ok
}, },
GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.Settings, error) {
return account.Settings, nil
},
}, },
} }
} }

View File

@ -83,7 +83,7 @@ type MockAccountManager struct {
CreateUserFunc func(ctx context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error) CreateUserFunc func(ctx context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error)
GetAccountIDFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) GetAccountIDFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
DeleteAccountFunc func(ctx context.Context, accountID, userID string) error DeleteAccountFunc func(ctx context.Context, accountID, userID string) error
GetDNSDomainFunc func() string GetDNSDomainFunc func(settings *types.Settings) string
StoreEventFunc func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) StoreEventFunc func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
GetEventsFunc func(ctx context.Context, accountID, userID string) ([]*activity.Event, error) GetEventsFunc func(ctx context.Context, accountID, userID string) ([]*activity.Event, error)
GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*types.DNSSettings, error) GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*types.DNSSettings, error)
@ -620,9 +620,9 @@ func (am *MockAccountManager) GetPeers(ctx context.Context, accountID, userID, n
} }
// GetDNSDomain mocks GetDNSDomain of the AccountManager interface // GetDNSDomain mocks GetDNSDomain of the AccountManager interface
func (am *MockAccountManager) GetDNSDomain() string { func (am *MockAccountManager) GetDNSDomain(settings *types.Settings) string {
if am.GetDNSDomainFunc != nil { if am.GetDNSDomainFunc != nil {
return am.GetDNSDomainFunc() return am.GetDNSDomainFunc(settings)
} }
return "" return ""
} }

View File

@ -206,6 +206,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
var sshChanged bool var sshChanged bool
var loginExpirationChanged bool var loginExpirationChanged bool
var inactivityExpirationChanged bool var inactivityExpirationChanged bool
var dnsDomain string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
peer, err = transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, update.ID) peer, err = transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, update.ID)
@ -223,7 +224,9 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
return err return err
} }
update, requiresPeerUpdates, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), peerGroupList, settings.Extra) dnsDomain = am.GetDNSDomain(settings)
update, requiresPeerUpdates, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, dnsDomain, peerGroupList, settings.Extra)
if err != nil { if err != nil {
return err return err
} }
@ -276,11 +279,11 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
if !peer.SSHEnabled { if !peer.SSHEnabled {
event = activity.PeerSSHDisabled event = activity.PeerSSHDisabled
} }
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(dnsDomain))
} }
if peerLabelChanged { if peerLabelChanged {
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRenamed, peer.EventMeta(am.GetDNSDomain())) am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRenamed, peer.EventMeta(dnsDomain))
} }
if loginExpirationChanged { if loginExpirationChanged {
@ -288,7 +291,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
if !peer.LoginExpirationEnabled { if !peer.LoginExpirationEnabled {
event = activity.PeerLoginExpirationDisabled event = activity.PeerLoginExpirationDisabled
} }
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(dnsDomain))
if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled { if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled {
am.checkAndSchedulePeerLoginExpiration(ctx, accountID) am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
@ -300,7 +303,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
if !peer.InactivityExpirationEnabled { if !peer.InactivityExpirationEnabled {
event = activity.PeerInactivityExpirationDisabled event = activity.PeerInactivityExpirationDisabled
} }
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(dnsDomain))
if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled { if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled {
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID) am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
@ -413,7 +416,7 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin
if err != nil { if err != nil {
return nil, err return nil, err
} }
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings))
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id) proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id)
if err != nil { if err != nil {
@ -574,8 +577,13 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
ExtraDNSLabels: peer.ExtraDNSLabels, ExtraDNSLabels: peer.ExtraDNSLabels,
AllowExtraDNSLabels: allowExtraDNSLabels, AllowExtraDNSLabels: allowExtraDNSLabels,
} }
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return fmt.Errorf("failed to get account settings: %w", err)
}
opEvent.TargetID = newPeer.ID opEvent.TargetID = newPeer.ID
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain()) opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain(settings))
if !addedByUser { if !addedByUser {
opEvent.Meta["setup_key_name"] = setupKeyName opEvent.Meta["setup_key_name"] = setupKeyName
} }
@ -591,10 +599,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
} }
} }
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return fmt.Errorf("failed to get account settings: %w", err)
}
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra) newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra)
err = transaction.AddPeerToAllGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID) err = transaction.AddPeerToAllGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID)
@ -1024,7 +1028,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
return nil, nil, nil, err return nil, nil, nil, err
} }
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings))
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id) proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id)
if err != nil { if err != nil {
@ -1060,7 +1064,12 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transact
log.WithContext(ctx).Debugf("failed to update user last login: %v", err) log.WithContext(ctx).Debugf("failed to update user last login: %v", err)
} }
am.StoreEvent(ctx, user.Id, peer.ID, user.AccountID, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain())) settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, peer.AccountID)
if err != nil {
return fmt.Errorf("failed to get account settings: %w", err)
}
am.StoreEvent(ctx, user.Id, peer.ID, user.AccountID, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain(settings)))
return nil return nil
} }
@ -1174,7 +1183,8 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
semaphore := make(chan struct{}, 10) semaphore := make(chan struct{}, 10)
dnsCache := &DNSConfigCache{} dnsCache := &DNSConfigCache{}
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) dnsDomain := am.GetDNSDomain(account.Settings)
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap() resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap() routers := account.GetResourceRoutersMap()
@ -1215,7 +1225,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
return return
} }
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled, extraSetting) update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled, extraSetting)
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
}(peer) }(peer)
} }
@ -1270,7 +1280,8 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
} }
dnsCache := &DNSConfigCache{} dnsCache := &DNSConfigCache{}
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) dnsDomain := am.GetDNSDomain(account.Settings)
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap() resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap() routers := account.GetResourceRoutersMap()
@ -1299,7 +1310,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
return return
} }
update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled, extraSettings) update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled, extraSettings)
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
} }
@ -1484,6 +1495,12 @@ func isPeerInActiveGroup(ctx context.Context, transaction store.Store, accountID
func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction store.Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) { func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction store.Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) {
var peerDeletedEvents []func() var peerDeletedEvents []func()
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
dnsDomain := am.GetDNSDomain(settings)
for _, peer := range peers { for _, peer := range peers {
if err := am.integratedPeerValidator.PeerDeleted(ctx, accountID, peer.ID); err != nil { if err := am.integratedPeerValidator.PeerDeleted(ctx, accountID, peer.ID); err != nil {
return nil, err return nil, err
@ -1514,7 +1531,7 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
}) })
am.peersUpdateManager.CloseChannel(ctx, peer.ID) am.peersUpdateManager.CloseChannel(ctx, peer.ID)
peerDeletedEvents = append(peerDeletedEvents, func() { peerDeletedEvents = append(peerDeletedEvents, func() {
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain())) am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain))
}) })
} }

View File

@ -39,6 +39,9 @@ type Settings struct {
// RoutingPeerDNSResolutionEnabled enabled the DNS resolution on the routing peers // RoutingPeerDNSResolutionEnabled enabled the DNS resolution on the routing peers
RoutingPeerDNSResolutionEnabled bool RoutingPeerDNSResolutionEnabled bool
// DNSDomain is the custom domain for that account
DNSDomain string
// Extra is a dictionary of Account settings // Extra is a dictionary of Account settings
Extra *ExtraSettings `gorm:"embedded;embeddedPrefix:extra_"` Extra *ExtraSettings `gorm:"embedded;embeddedPrefix:extra_"`
} }
@ -58,6 +61,7 @@ func (s *Settings) Copy() *Settings {
PeerInactivityExpiration: s.PeerInactivityExpiration, PeerInactivityExpiration: s.PeerInactivityExpiration,
RoutingPeerDNSResolutionEnabled: s.RoutingPeerDNSResolutionEnabled, RoutingPeerDNSResolutionEnabled: s.RoutingPeerDNSResolutionEnabled,
DNSDomain: s.DNSDomain,
} }
if s.Extra != nil { if s.Extra != nil {
settings.Extra = s.Extra.Copy() settings.Extra = s.Extra.Copy()

View File

@ -940,6 +940,12 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
// expireAndUpdatePeers expires all peers of the given user and updates them in the account // expireAndUpdatePeers expires all peers of the given user and updates them in the account
func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accountID string, peers []*nbpeer.Peer) error { func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accountID string, peers []*nbpeer.Peer) error {
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return err
}
dnsDomain := am.GetDNSDomain(settings)
var peerIDs []string var peerIDs []string
for _, peer := range peers { for _, peer := range peers {
// nolint:staticcheck // nolint:staticcheck
@ -957,7 +963,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
am.StoreEvent( am.StoreEvent(
ctx, ctx,
peer.UserID, peer.ID, accountID, peer.UserID, peer.ID, accountID,
activity.PeerLoginExpired, peer.EventMeta(am.GetDNSDomain()), activity.PeerLoginExpired, peer.EventMeta(dnsDomain),
) )
} }