From 312bfd9bd789c10c17e9c6d068df7dabefc6618f Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Wed, 23 Apr 2025 19:36:53 +0200 Subject: [PATCH] [management] support custom domains per account (#3726) --- client/cmd/testutil_test.go | 5 ++ management/server/account.go | 21 +++++++- management/server/account/manager.go | 2 +- management/server/activity/codes.go | 4 ++ management/server/group.go | 11 +++- management/server/grpcserver.go | 44 ++++++++++----- management/server/http/api/openapi.yml | 4 ++ management/server/http/api/types.gen.go | 4 +- .../handlers/accounts/accounts_handler.go | 4 ++ .../accounts/accounts_handler_test.go | 4 ++ .../http/handlers/peers/peers_handler.go | 25 +++++++-- .../http/handlers/peers/peers_handler_test.go | 5 +- management/server/mock_server/account_mock.go | 6 +-- management/server/peer.go | 53 ++++++++++++------- management/server/types/settings.go | 4 ++ management/server/user.go | 8 ++- 16 files changed, 158 insertions(+), 46 deletions(-) diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index 70abe4abe..258a8daff 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -98,6 +98,11 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc settingsMockManager := settings.NewMockManager(ctrl) permissionsManagerMock := permissions.NewMockManager(ctrl) + settingsMockManager.EXPECT(). + GetSettings(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&types.Settings{}, nil). + AnyTimes() + accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock) if err != nil { t.Fatal(err) diff --git a/management/server/account.go b/management/server/account.go index fb0a9b65e..cc5ca309a 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -275,6 +275,10 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour") } + if newSettings.DNSDomain != "" && !isDomainValid(newSettings.DNSDomain) { + return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain) + } + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() @@ -325,6 +329,12 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco account.Network.Serial++ } + if oldSettings.DNSDomain != newSettings.DNSDomain { + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountDNSDomainUpdated, nil) + updateAccountPeers = true + account.Network.Serial++ + } + err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID) if err != nil { return nil, err @@ -1493,8 +1503,15 @@ func isDomainValid(domain string) bool { } // GetDNSDomain returns the configured dnsDomain -func (am *DefaultAccountManager) GetDNSDomain() string { - return am.dnsDomain +func (am *DefaultAccountManager) GetDNSDomain(settings *types.Settings) string { + if settings == nil { + return am.dnsDomain + } + if settings.DNSDomain == "" { + return am.dnsDomain + } + + return settings.DNSDomain } func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) { diff --git a/management/server/account/manager.go b/management/server/account/manager.go index b6eb7de05..aed83349f 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -81,7 +81,7 @@ type Manager interface { SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) - GetDNSDomain() string + GetDNSDomain(settings *types.Settings) string StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index 46ae754cf..ed4be82e2 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -169,6 +169,8 @@ const ( ResourceAddedToGroup Activity = 82 ResourceRemovedFromGroup Activity = 83 + + AccountDNSDomainUpdated Activity = 84 ) var activityMap = map[Activity]Code{ @@ -264,6 +266,8 @@ var activityMap = map[Activity]Code{ ResourceAddedToGroup: {"Resource added to group", "resource.group.add"}, ResourceRemovedFromGroup: {"Resource removed from group", "resource.group.delete"}, + + AccountDNSDomainUpdated: {"Account DNS domain updated", "account.dns.domain.update"}, } // StringCode returns a string code of the activity diff --git a/management/server/group.go b/management/server/group.go index 0bd840798..87d649228 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -158,6 +158,13 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac return nil } + settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Debugf("failed to get account settings for group events: %v", err) + return nil + } + dnsDomain := am.GetDNSDomain(settings) + for _, peerID := range addedPeers { peer, ok := peers[peerID] if !ok { @@ -168,7 +175,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac eventsToStore = append(eventsToStore, func() { meta := map[string]any{ "group": newGroup.Name, "group_id": newGroup.ID, - "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), + "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(dnsDomain), } am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer, meta) }) @@ -184,7 +191,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac eventsToStore = append(eventsToStore, func() { meta := map[string]any{ "group": newGroup.Name, "group_id": newGroup.ID, - "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), + "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(dnsDomain), } am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer, meta) }) diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index a7ed639c3..43d35f643 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -480,20 +480,12 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p s.ephemeralManager.OnPeerDisconnected(ctx, peer) } - var relayToken *Token - if s.config.Relay != nil && len(s.config.Relay.Addresses) > 0 { - relayToken, err = s.secretsManager.GenerateRelayToken() - if err != nil { - log.Errorf("failed generating Relay token: %v", err) - } + loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks) + if err != nil { + log.WithContext(ctx).Warnf("failed preparing login response for peer %s: %s", peerKey, err) + return nil, status.Errorf(codes.Internal, "failed logging in peer") } - // if peer has reached this point then it has logged in - loginResp := &proto.LoginResponse{ - NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil), - PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain(), false), - Checks: toProtocolChecks(ctx, postureChecks), - } encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp) if err != nil { log.WithContext(ctx).Warnf("failed encrypting peer %s message", peer.ID) @@ -506,6 +498,32 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p }, nil } +func (s *GRPCServer) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) { + var relayToken *Token + var err error + if s.config.Relay != nil && len(s.config.Relay.Addresses) > 0 { + relayToken, err = s.secretsManager.GenerateRelayToken() + if err != nil { + log.Errorf("failed generating Relay token: %v", err) + } + } + + settings, err := s.settingsManager.GetSettings(ctx, peer.AccountID, activity.SystemInitiator) + if err != nil { + log.WithContext(ctx).Warnf("failed getting settings for peer %s: %s", peer.Key, err) + return nil, status.Errorf(codes.Internal, "failed getting settings") + } + + // if peer has reached this point then it has logged in + loginResp := &proto.LoginResponse{ + NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil), + PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain(settings), false), + Checks: toProtocolChecks(ctx, postureChecks), + } + + return loginResp, nil +} + // processJwtToken validates the existence of a JWT token in the login request, and returns the corresponding user ID if // the token is valid. // @@ -712,7 +730,7 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p return status.Errorf(codes.Internal, "error handling request") } - plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(), postureChecks, nil, settings.RoutingPeerDNSResolutionEnabled, settings.Extra) + plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(settings), postureChecks, nil, settings.RoutingPeerDNSResolutionEnabled, settings.Extra) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp) if err != nil { diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 1717c89ac..c0ce06daa 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -112,6 +112,10 @@ components: description: Enables or disables DNS resolution on the routing peers type: boolean example: true + dns_domain: + description: Allows to define a custom dns domain for the account + type: string + example: my-organization.org extra: $ref: '#/components/schemas/AccountExtraSettings' required: diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index 3fca40366..243f2fdf9 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -259,7 +259,9 @@ type AccountRequest struct { // AccountSettings defines model for AccountSettings. type AccountSettings struct { - Extra *AccountExtraSettings `json:"extra,omitempty"` + // DnsDomain Allows to define a custom dns domain for the account + DnsDomain *string `json:"dns_domain,omitempty"` + Extra *AccountExtraSettings `json:"extra,omitempty"` // GroupsPropagationEnabled Allows propagate the new user auto groups to peers that belongs to the user GroupsPropagationEnabled *bool `json:"groups_propagation_enabled,omitempty"` diff --git a/management/server/http/handlers/accounts/accounts_handler.go b/management/server/http/handlers/accounts/accounts_handler.go index c0851102f..7cad26bd6 100644 --- a/management/server/http/handlers/accounts/accounts_handler.go +++ b/management/server/http/handlers/accounts/accounts_handler.go @@ -119,6 +119,9 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { if req.Settings.RoutingPeerDnsResolutionEnabled != nil { settings.RoutingPeerDNSResolutionEnabled = *req.Settings.RoutingPeerDnsResolutionEnabled } + if req.Settings.DnsDomain != nil { + settings.DNSDomain = *req.Settings.DnsDomain + } updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings) if err != nil { @@ -178,6 +181,7 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A JwtAllowGroups: &jwtAllowGroups, RegularUsersViewBlocked: settings.RegularUsersViewBlocked, RoutingPeerDnsResolutionEnabled: &settings.RoutingPeerDNSResolutionEnabled, + DnsDomain: &settings.DNSDomain, } if settings.Extra != nil { diff --git a/management/server/http/handlers/accounts/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go index 2acca4f49..57bbffc7c 100644 --- a/management/server/http/handlers/accounts/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -108,6 +108,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { JwtAllowGroups: &[]string{}, RegularUsersViewBlocked: true, RoutingPeerDnsResolutionEnabled: br(false), + DnsDomain: sr(""), }, expectedArray: true, expectedID: accountID, @@ -128,6 +129,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { JwtAllowGroups: &[]string{}, RegularUsersViewBlocked: false, RoutingPeerDnsResolutionEnabled: br(false), + DnsDomain: sr(""), }, expectedArray: false, expectedID: accountID, @@ -148,6 +150,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { JwtAllowGroups: &[]string{"test"}, RegularUsersViewBlocked: true, RoutingPeerDnsResolutionEnabled: br(false), + DnsDomain: sr(""), }, expectedArray: false, expectedID: accountID, @@ -168,6 +171,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { JwtAllowGroups: &[]string{}, RegularUsersViewBlocked: true, RoutingPeerDnsResolutionEnabled: br(false), + DnsDomain: sr(""), }, expectedArray: false, expectedID: accountID, diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index fa78836d8..58ea06ea3 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -65,7 +65,13 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, util.WriteError(ctx, err, w) return } - dnsDomain := h.accountManager.GetDNSDomain() + settings, err := h.accountManager.GetAccountSettings(ctx, accountID, activity.SystemInitiator) + if err != nil { + util.WriteError(ctx, err, w) + return + } + + dnsDomain := h.accountManager.GetDNSDomain(settings) grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID) grpsInfoMap := groups.ToGroupsInfoMap(grps, 0) @@ -110,7 +116,13 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri util.WriteError(ctx, err, w) return } - dnsDomain := h.accountManager.GetDNSDomain() + + settings, err := h.accountManager.GetAccountSettings(ctx, accountID, activity.SystemInitiator) + if err != nil { + util.WriteError(ctx, err, w) + return + } + dnsDomain := h.accountManager.GetDNSDomain(settings) peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID) if err != nil { @@ -192,7 +204,12 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { return } - dnsDomain := h.accountManager.GetDNSDomain() + settings, err := h.accountManager.GetAccountSettings(r.Context(), accountID, activity.SystemInitiator) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + dnsDomain := h.accountManager.GetDNSDomain(settings) grps, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID) @@ -279,7 +296,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { return } - dnsDomain := h.accountManager.GetDNSDomain() + dnsDomain := h.accountManager.GetDNSDomain(account.Settings) customZone := account.GetPeersCustomZone(r.Context(), dnsDomain) netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) diff --git a/management/server/http/handlers/peers/peers_handler_test.go b/management/server/http/handlers/peers/peers_handler_test.go index a03c3c29d..a1fc13dd3 100644 --- a/management/server/http/handlers/peers/peers_handler_test.go +++ b/management/server/http/handlers/peers/peers_handler_test.go @@ -152,7 +152,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler { }, }, nil }, - GetDNSDomainFunc: func() string { + GetDNSDomainFunc: func(settings *types.Settings) string { return "netbird.selfhosted" }, GetAccountFunc: func(ctx context.Context, accountID string) (*types.Account, error) { @@ -172,6 +172,9 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler { _, ok := statuses[peerID] return ok }, + GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.Settings, error) { + return account.Settings, nil + }, }, } } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 804877a66..2b57e6888 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -83,7 +83,7 @@ type MockAccountManager struct { CreateUserFunc func(ctx context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error) GetAccountIDFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) DeleteAccountFunc func(ctx context.Context, accountID, userID string) error - GetDNSDomainFunc func() string + GetDNSDomainFunc func(settings *types.Settings) string StoreEventFunc func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) GetEventsFunc func(ctx context.Context, accountID, userID string) ([]*activity.Event, error) GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*types.DNSSettings, error) @@ -620,9 +620,9 @@ func (am *MockAccountManager) GetPeers(ctx context.Context, accountID, userID, n } // GetDNSDomain mocks GetDNSDomain of the AccountManager interface -func (am *MockAccountManager) GetDNSDomain() string { +func (am *MockAccountManager) GetDNSDomain(settings *types.Settings) string { if am.GetDNSDomainFunc != nil { - return am.GetDNSDomainFunc() + return am.GetDNSDomainFunc(settings) } return "" } diff --git a/management/server/peer.go b/management/server/peer.go index 27825a148..908610fbe 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -206,6 +206,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user var sshChanged bool var loginExpirationChanged bool var inactivityExpirationChanged bool + var dnsDomain string err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { peer, err = transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, update.ID) @@ -223,7 +224,9 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user return err } - update, requiresPeerUpdates, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), peerGroupList, settings.Extra) + dnsDomain = am.GetDNSDomain(settings) + + update, requiresPeerUpdates, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, dnsDomain, peerGroupList, settings.Extra) if err != nil { return err } @@ -276,11 +279,11 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user if !peer.SSHEnabled { event = activity.PeerSSHDisabled } - am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) + am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(dnsDomain)) } if peerLabelChanged { - am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRenamed, peer.EventMeta(am.GetDNSDomain())) + am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRenamed, peer.EventMeta(dnsDomain)) } if loginExpirationChanged { @@ -288,7 +291,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user if !peer.LoginExpirationEnabled { event = activity.PeerLoginExpirationDisabled } - am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) + am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(dnsDomain)) if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled { am.checkAndSchedulePeerLoginExpiration(ctx, accountID) @@ -300,7 +303,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user if !peer.InactivityExpirationEnabled { event = activity.PeerInactivityExpirationDisabled } - am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) + am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(dnsDomain)) if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled { am.checkAndSchedulePeerInactivityExpiration(ctx, accountID) @@ -413,7 +416,7 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin if err != nil { return nil, err } - customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) + customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings)) proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id) if err != nil { @@ -574,8 +577,13 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s ExtraDNSLabels: peer.ExtraDNSLabels, AllowExtraDNSLabels: allowExtraDNSLabels, } + settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return fmt.Errorf("failed to get account settings: %w", err) + } + opEvent.TargetID = newPeer.ID - opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain()) + opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain(settings)) if !addedByUser { opEvent.Meta["setup_key_name"] = setupKeyName } @@ -591,10 +599,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } } - settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) - if err != nil { - return fmt.Errorf("failed to get account settings: %w", err) - } newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra) err = transaction.AddPeerToAllGroup(ctx, store.LockingStrengthUpdate, accountID, newPeer.ID) @@ -1024,7 +1028,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is return nil, nil, nil, err } - customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) + customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings)) proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id) if err != nil { @@ -1060,7 +1064,12 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transact log.WithContext(ctx).Debugf("failed to update user last login: %v", err) } - am.StoreEvent(ctx, user.Id, peer.ID, user.AccountID, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain())) + settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, peer.AccountID) + if err != nil { + return fmt.Errorf("failed to get account settings: %w", err) + } + + am.StoreEvent(ctx, user.Id, peer.ID, user.AccountID, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain(settings))) return nil } @@ -1174,7 +1183,8 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account semaphore := make(chan struct{}, 10) dnsCache := &DNSConfigCache{} - customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) + dnsDomain := am.GetDNSDomain(account.Settings) + customZone := account.GetPeersCustomZone(ctx, dnsDomain) resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() @@ -1215,7 +1225,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account return } - update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled, extraSetting) + update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled, extraSetting) am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) }(peer) } @@ -1270,7 +1280,8 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI } dnsCache := &DNSConfigCache{} - customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) + dnsDomain := am.GetDNSDomain(account.Settings) + customZone := account.GetPeersCustomZone(ctx, dnsDomain) resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() @@ -1299,7 +1310,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI return } - update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled, extraSettings) + update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled, extraSettings) am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) } @@ -1484,6 +1495,12 @@ func isPeerInActiveGroup(ctx context.Context, transaction store.Store, accountID func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction store.Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) { var peerDeletedEvents []func() + settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + dnsDomain := am.GetDNSDomain(settings) + for _, peer := range peers { if err := am.integratedPeerValidator.PeerDeleted(ctx, accountID, peer.ID); err != nil { return nil, err @@ -1514,7 +1531,7 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto }) am.peersUpdateManager.CloseChannel(ctx, peer.ID) peerDeletedEvents = append(peerDeletedEvents, func() { - am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain())) + am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain)) }) } diff --git a/management/server/types/settings.go b/management/server/types/settings.go index 7054ede8c..c8de2a98c 100644 --- a/management/server/types/settings.go +++ b/management/server/types/settings.go @@ -39,6 +39,9 @@ type Settings struct { // RoutingPeerDNSResolutionEnabled enabled the DNS resolution on the routing peers RoutingPeerDNSResolutionEnabled bool + // DNSDomain is the custom domain for that account + DNSDomain string + // Extra is a dictionary of Account settings Extra *ExtraSettings `gorm:"embedded;embeddedPrefix:extra_"` } @@ -58,6 +61,7 @@ func (s *Settings) Copy() *Settings { PeerInactivityExpiration: s.PeerInactivityExpiration, RoutingPeerDNSResolutionEnabled: s.RoutingPeerDNSResolutionEnabled, + DNSDomain: s.DNSDomain, } if s.Extra != nil { settings.Extra = s.Extra.Copy() diff --git a/management/server/user.go b/management/server/user.go index 9ec16e72c..b46ed24cf 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -940,6 +940,12 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a // expireAndUpdatePeers expires all peers of the given user and updates them in the account func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accountID string, peers []*nbpeer.Peer) error { + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return err + } + dnsDomain := am.GetDNSDomain(settings) + var peerIDs []string for _, peer := range peers { // nolint:staticcheck @@ -957,7 +963,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou am.StoreEvent( ctx, peer.UserID, peer.ID, accountID, - activity.PeerLoginExpired, peer.EventMeta(am.GetDNSDomain()), + activity.PeerLoginExpired, peer.EventMeta(dnsDomain), ) }