diff --git a/management/server/management_refactor/server/access_control/access_control_manager.go b/management/server/management_refactor/server/access_control/access_control_manager.go new file mode 100644 index 000000000..60a9ef106 --- /dev/null +++ b/management/server/management_refactor/server/access_control/access_control_manager.go @@ -0,0 +1,7 @@ +package access_control + +type AccessControlManager interface { +} + +type DefaultAccessControlManager struct { +} diff --git a/management/server/management_refactor/server/access_control/policy.go b/management/server/management_refactor/server/access_control/policy.go new file mode 100644 index 000000000..690deb966 --- /dev/null +++ b/management/server/management_refactor/server/access_control/policy.go @@ -0,0 +1 @@ +package access_control diff --git a/management/server/management_refactor/server/access_control/rule.go b/management/server/management_refactor/server/access_control/rule.go new file mode 100644 index 000000000..a23d1a171 --- /dev/null +++ b/management/server/management_refactor/server/access_control/rule.go @@ -0,0 +1,100 @@ +package access_control + +import "fmt" + +// TrafficFlowType defines allowed direction of the traffic in the rule +type TrafficFlowType int + +const ( + // TrafficFlowBidirect allows traffic to both direction + TrafficFlowBidirect TrafficFlowType = iota + // TrafficFlowBidirectString allows traffic to both direction + TrafficFlowBidirectString = "bidirect" + // DefaultRuleName is a name for the Default rule that is created for every account + DefaultRuleName = "Default" + // DefaultRuleDescription is a description for the Default rule that is created for every account + DefaultRuleDescription = "This is a default rule that allows connections between all the resources" + // DefaultPolicyName is a name for the Default policy that is created for every account + DefaultPolicyName = "Default" + // DefaultPolicyDescription is a description for the Default policy that is created for every account + DefaultPolicyDescription = "This is a default policy that allows connections between all the resources" +) + +// Rule of ACL for groups +type Rule struct { + // ID of the rule + ID string + + // AccountID is a reference to Account that this object belongs + AccountID string `json:"-" gorm:"index"` + + // Name of the rule visible in the UI + Name string + + // Description of the rule visible in the UI + Description string + + // Disabled status of rule in the system + Disabled bool + + // Source list of groups IDs of peers + Source []string `gorm:"serializer:json"` + + // Destination list of groups IDs of peers + Destination []string `gorm:"serializer:json"` + + // Flow of the traffic allowed by the rule + Flow TrafficFlowType +} + +func (r *Rule) Copy() *Rule { + rule := &Rule{ + ID: r.ID, + Name: r.Name, + Description: r.Description, + Disabled: r.Disabled, + Source: make([]string, len(r.Source)), + Destination: make([]string, len(r.Destination)), + Flow: r.Flow, + } + copy(rule.Source, r.Source) + copy(rule.Destination, r.Destination) + return rule +} + +// EventMeta returns activity event meta related to this rule +func (r *Rule) EventMeta() map[string]any { + return map[string]any{"name": r.Name} +} + +// ToPolicyRule converts a Rule to a PolicyRule object +func (r *Rule) ToPolicyRule() *PolicyRule { + if r == nil { + return nil + } + return &PolicyRule{ + ID: r.ID, + Name: r.Name, + Enabled: !r.Disabled, + Description: r.Description, + Destinations: r.Destination, + Sources: r.Source, + Bidirectional: true, + Protocol: PolicyRuleProtocolALL, + Action: PolicyTrafficActionAccept, + } +} + +// RuleToPolicy converts a Rule to a Policy query object +func RuleToPolicy(rule *Rule) (*Policy, error) { + if rule == nil { + return nil, fmt.Errorf("rule is empty") + } + return &Policy{ + ID: rule.ID, + Name: rule.Name, + Description: rule.Description, + Enabled: !rule.Disabled, + Rules: []*PolicyRule{rule.ToPolicyRule()}, + }, nil +} diff --git a/management/server/management_refactor/server/accounts/account.go b/management/server/management_refactor/server/accounts/account.go new file mode 100644 index 000000000..7f0640ed4 --- /dev/null +++ b/management/server/management_refactor/server/accounts/account.go @@ -0,0 +1,69 @@ +package accounts + +import ( + "time" + + nbdns "github.com/netbirdio/netbird/dns" +) + +// Settings represents Account settings structure that can be modified via API and Dashboard +type Settings struct { + // PeerLoginExpirationEnabled globally enables or disables peer login expiration + PeerLoginExpirationEnabled bool + + // PeerLoginExpiration is a setting that indicates when peer login expires. + // Applies to all peers that have Peer.LoginExpirationEnabled set to true. + PeerLoginExpiration time.Duration + + // GroupsPropagationEnabled allows to propagate auto groups from the user to the peer + GroupsPropagationEnabled bool + + // JWTGroupsEnabled allows extract groups from JWT claim, which name defined in the JWTGroupsClaimName + // and add it to account groups. + JWTGroupsEnabled bool + + // JWTGroupsClaimName from which we extract groups name to add it to account groups + JWTGroupsClaimName string +} + +// Copy copies the Settings struct +func (s *Settings) Copy() *Settings { + return &Settings{ + PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled, + PeerLoginExpiration: s.PeerLoginExpiration, + JWTGroupsEnabled: s.JWTGroupsEnabled, + JWTGroupsClaimName: s.JWTGroupsClaimName, + GroupsPropagationEnabled: s.GroupsPropagationEnabled, + } +} + +// Account represents a unique account of the system +type Account struct { + // we have to name column to aid as it collides with Network.Id when work with associations + Id string `gorm:"primaryKey"` + + // User.Id it was created by + CreatedBy string + Domain string `gorm:"index"` + DomainCategory string + IsDomainPrimaryAccount bool + SetupKeys map[string]*SetupKey `gorm:"-"` + SetupKeysG []SetupKey `json:"-" gorm:"foreignKey:AccountID;references:id"` + Network *Network `gorm:"embedded;embeddedPrefix:network_"` + Peers map[string]*Peer `gorm:"-"` + PeersG []Peer `json:"-" gorm:"foreignKey:AccountID;references:id"` + Users map[string]*User `gorm:"-"` + UsersG []User `json:"-" gorm:"foreignKey:AccountID;references:id"` + Groups map[string]*Group `gorm:"-"` + GroupsG []Group `json:"-" gorm:"foreignKey:AccountID;references:id"` + Rules map[string]*Rule `gorm:"-"` + RulesG []Rule `json:"-" gorm:"foreignKey:AccountID;references:id"` + Policies []*Policy `gorm:"foreignKey:AccountID;references:id"` + Routes map[string]*route.Route `gorm:"-"` + RoutesG []route.Route `json:"-" gorm:"foreignKey:AccountID;references:id"` + NameServerGroups map[string]*nbdns.NameServerGroup `gorm:"-"` + NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"` + DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"` + // Settings is a dictionary of Account settings + Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` +} diff --git a/management/server/management_refactor/server/accounts/account_manager.go b/management/server/management_refactor/server/accounts/account_manager.go new file mode 100644 index 000000000..3b4605965 --- /dev/null +++ b/management/server/management_refactor/server/accounts/account_manager.go @@ -0,0 +1,22 @@ +package accounts + +type AccountManager interface { + GetAccount(accountID string) (Account, error) + GetDNSDomain() string +} + +type DefaultAccountManager struct { + repository AccountRepository + + // dnsDomain is used for peer resolution. This is appended to the peer's name + dnsDomain string +} + +func (am *DefaultAccountManager) GetAccount(accountID string) (Account, error) { + return am.repository.findAccountByID(accountID) +} + +// GetDNSDomain returns the configured dnsDomain +func (am *DefaultAccountManager) GetDNSDomain() string { + return am.dnsDomain +} diff --git a/management/server/management_refactor/server/accounts/account_repository.go b/management/server/management_refactor/server/accounts/account_repository.go new file mode 100644 index 000000000..6489c8fa5 --- /dev/null +++ b/management/server/management_refactor/server/accounts/account_repository.go @@ -0,0 +1,5 @@ +package accounts + +type AccountRepository interface { + findAccountByID(accountID string) (Account, error) +} diff --git a/management/server/management_refactor/server/config.go b/management/server/management_refactor/server/config.go new file mode 100644 index 000000000..4fed93bba --- /dev/null +++ b/management/server/management_refactor/server/config.go @@ -0,0 +1,150 @@ +package server + +import ( + "net/url" + + "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/util" +) + +type ( + // Protocol type + Protocol string + + // Provider authorization flow type + Provider string +) + +const ( + UDP Protocol = "udp" + DTLS Protocol = "dtls" + TCP Protocol = "tcp" + HTTP Protocol = "http" + HTTPS Protocol = "https" + NONE Provider = "none" +) + +const ( + // DefaultDeviceAuthFlowScope defines the bare minimum scope to request in the device authorization flow + DefaultDeviceAuthFlowScope string = "openid" +) + +// Config of the Management service +type Config struct { + Stuns []*Host + TURNConfig *TURNConfig + Signal *Host + + Datadir string + DataStoreEncryptionKey string + + HttpConfig *HttpServerConfig + + IdpManagerConfig *idp.Config + + DeviceAuthorizationFlow *DeviceAuthorizationFlow + + PKCEAuthorizationFlow *PKCEAuthorizationFlow + + StoreConfig StoreConfig +} + +// GetAuthAudiences returns the audience from the http config and device authorization flow config +func (c Config) GetAuthAudiences() []string { + audiences := []string{c.HttpConfig.AuthAudience} + + if c.DeviceAuthorizationFlow != nil && c.DeviceAuthorizationFlow.ProviderConfig.Audience != "" { + audiences = append(audiences, c.DeviceAuthorizationFlow.ProviderConfig.Audience) + } + + return audiences +} + +// TURNConfig is a config of the TURNCredentialsManager +type TURNConfig struct { + TimeBasedCredentials bool + CredentialsTTL util.Duration + Secret string + Turns []*Host +} + +// HttpServerConfig is a config of the HTTP Management service server +type HttpServerConfig struct { + LetsEncryptDomain string + // CertFile is the location of the certificate + CertFile string + // CertKey is the location of the certificate private key + CertKey string + // AuthAudience identifies the recipients that the JWT is intended for (aud in JWT) + AuthAudience string + // AuthIssuer identifies principal that issued the JWT + AuthIssuer string + // AuthUserIDClaim is the name of the claim that used as user ID + AuthUserIDClaim string + // AuthKeysLocation is a location of JWT key set containing the public keys used to verify JWT + AuthKeysLocation string + // OIDCConfigEndpoint is the endpoint of an IDP manager to get OIDC configuration + OIDCConfigEndpoint string + // IdpSignKeyRefreshEnabled identifies the signing key is currently being rotated or not + IdpSignKeyRefreshEnabled bool +} + +// Host represents a Wiretrustee host (e.g. STUN, TURN, Signal) +type Host struct { + Proto Protocol + // URI e.g. turns://stun.wiretrustee.com:4430 or signal.wiretrustee.com:10000 + URI string + Username string + Password string +} + +// DeviceAuthorizationFlow represents Device Authorization Flow information +// that can be used by the client to login initiate a Oauth 2.0 device authorization grant flow +// see https://datatracker.ietf.org/doc/html/rfc8628 +type DeviceAuthorizationFlow struct { + Provider string + ProviderConfig ProviderConfig +} + +// PKCEAuthorizationFlow represents Authorization Code Flow information +// that can be used by the client to login initiate a Oauth 2.0 authorization code grant flow +// with Proof Key for Code Exchange (PKCE). See https://datatracker.ietf.org/doc/html/rfc7636 +type PKCEAuthorizationFlow struct { + ProviderConfig ProviderConfig +} + +// ProviderConfig has all attributes needed to initiate a device/pkce authorization flow +type ProviderConfig struct { + // ClientID An IDP application client id + ClientID string + // ClientSecret An IDP application client secret + ClientSecret string + // Domain An IDP API domain + // Deprecated. Use TokenEndpoint and DeviceAuthEndpoint + Domain string + // Audience An Audience for to authorization validation + Audience string + // TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token + TokenEndpoint string + // DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code + DeviceAuthEndpoint string + // AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code + AuthorizationEndpoint string + // Scopes provides the scopes to be included in the token request + Scope string + // UseIDToken indicates if the id token should be used for authentication + UseIDToken bool + // RedirectURL handles authorization code from IDP manager + RedirectURLs []string +} + +// StoreConfig contains Store configuration +type StoreConfig struct { + Engine StoreEngine +} + +// validateURL validates input http url +func validateURL(httpURL string) bool { + _, err := url.ParseRequestURI(httpURL) + return err == nil +} diff --git a/management/server/management_refactor/server/dns/dns.go b/management/server/management_refactor/server/dns/dns.go new file mode 100644 index 000000000..1ffe03d57 --- /dev/null +++ b/management/server/management_refactor/server/dns/dns.go @@ -0,0 +1 @@ +package dns diff --git a/management/server/management_refactor/server/events/events_manager.go b/management/server/management_refactor/server/events/events_manager.go new file mode 100644 index 000000000..5d0c3ac9d --- /dev/null +++ b/management/server/management_refactor/server/events/events_manager.go @@ -0,0 +1,67 @@ +package events + +import ( + "fmt" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/server/activity" +) + +type EventsManager struct { + store activity.Store +} + +func NewEventsManager(store activity.Store) *EventsManager { + return &EventsManager{ + store: store, + } +} + +// GetEvents returns a list of activity events of an account +func (em *EventsManager) GetEvents(accountID, userID string) ([]*activity.Event, error) { + events, err := em.store.Get(accountID, 0, 10000, true) + if err != nil { + return nil, err + } + + // this is a workaround for duplicate activity.UserJoined events that might occur when a user redeems invite. + // we will need to find a better way to handle this. + filtered := make([]*activity.Event, 0) + dups := make(map[string]struct{}) + for _, event := range events { + if event.Activity == activity.UserJoined { + key := event.TargetID + event.InitiatorID + event.AccountID + fmt.Sprint(event.Activity) + _, duplicate := dups[key] + if duplicate { + continue + } else { + dups[key] = struct{}{} + } + } + filtered = append(filtered, event) + } + + return filtered, nil +} + +func (em *EventsManager) StoreEvent(initiatorID, targetID, accountID string, activityID activity.Activity, + meta map[string]any) { + + go func() { + _, err := em.store.Save(&activity.Event{ + Timestamp: time.Now().UTC(), + Activity: activityID, + InitiatorID: initiatorID, + TargetID: targetID, + AccountID: accountID, + Meta: meta, + }) + if err != nil { + // todo add metric + log.Errorf("received an error while storing an activity event, error: %s", err) + } + }() + +} diff --git a/management/server/management_refactor/server/events/events_manager_test.go b/management/server/management_refactor/server/events/events_manager_test.go new file mode 100644 index 000000000..a32c6dd8d --- /dev/null +++ b/management/server/management_refactor/server/events/events_manager_test.go @@ -0,0 +1,76 @@ +package events + +// import ( +// "testing" +// "time" +// +// "github.com/stretchr/testify/assert" +// +// "github.com/netbirdio/netbird/management/server/activity" +// ) +// +// func generateAndStoreEvents(t *testing.T, manager *EventsManager, typ activity.Activity, initiatorID, targetID, +// accountID string, count int) { +// t.Helper() +// for i := 0; i < count; i++ { +// _, err := manager.store.Save(&activity.Event{ +// Timestamp: time.Now().UTC(), +// Activity: typ, +// InitiatorID: initiatorID, +// TargetID: targetID, +// AccountID: accountID, +// }) +// if err != nil { +// t.Fatal(err) +// } +// } +// } +// +// func TestDefaultAccountManager_GetEvents(t *testing.T) { +// manager, err := createManager(t) +// if err != nil { +// return +// } +// +// accountID := "accountID" +// +// t.Run("get empty events list", func(t *testing.T) { +// events, err := manager.GetEvents(accountID, userID) +// if err != nil { +// return +// } +// assert.Len(t, events, 0) +// _ = manager.eventStore.Close() //nolint +// }) +// +// t.Run("get events", func(t *testing.T) { +// generateAndStoreEvents(t, manager, activity.PeerAddedByUser, userID, "peer", accountID, 10) +// events, err := manager.GetEvents(accountID, userID) +// if err != nil { +// return +// } +// +// assert.Len(t, events, 10) +// _ = manager.eventStore.Close() //nolint +// }) +// +// t.Run("get events without duplicates", func(t *testing.T) { +// generateAndStoreEvents(t, manager, activity.UserJoined, userID, "", accountID, 10) +// events, err := manager.GetEvents(accountID, userID) +// if err != nil { +// return +// } +// assert.Len(t, events, 1) +// _ = manager.eventStore.Close() //nolint +// }) +// } +// +// func createManager(t *testing.T) (*EventsManager, error) { +// t.Helper() +// store, err := createStore(t) +// if err != nil { +// return nil, err +// } +// eventStore := &activity.InMemoryEventStore{} +// return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, false) +// } diff --git a/management/server/management_refactor/server/groups/group.go b/management/server/management_refactor/server/groups/group.go new file mode 100644 index 000000000..0624edf5e --- /dev/null +++ b/management/server/management_refactor/server/groups/group.go @@ -0,0 +1,333 @@ +package groups + +import ( + "fmt" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/status" +) + +type GroupLinkError struct { + Resource string + Name string +} + +func (e *GroupLinkError) Error() string { + return fmt.Sprintf("group has been linked to %s: %s", e.Resource, e.Name) +} + +// Group of the peers for ACL +type Group struct { + // ID of the group + ID string + + // AccountID is a reference to Account that this object belongs + AccountID string `json:"-" gorm:"index"` + + // Name visible in the UI + Name string + + // Issued of the group + Issued string + + // Peers list of the group + Peers []string `gorm:"serializer:json"` + + IntegrationReference IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"` +} + +// EventMeta returns activity event meta related to the group +func (g *Group) EventMeta() map[string]any { + return map[string]any{"name": g.Name} +} + +func (g *Group) Copy() *Group { + group := &Group{ + ID: g.ID, + Name: g.Name, + Issued: g.Issued, + Peers: make([]string, len(g.Peers)), + IntegrationReference: g.IntegrationReference, + } + copy(group.Peers, g.Peers) + return group +} + +// GetGroup object of the peers +func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, error) { + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() + + account, err := am.Store.GetAccount(accountID) + if err != nil { + return nil, err + } + + group, ok := account.Groups[groupID] + if ok { + return group, nil + } + + return nil, status.Errorf(status.NotFound, "group with ID %s not found", groupID) +} + +// SaveGroup object of the peers +func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *Group) error { + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() + + account, err := am.Store.GetAccount(accountID) + if err != nil { + return err + } + oldGroup, exists := account.Groups[newGroup.ID] + account.Groups[newGroup.ID] = newGroup + + account.Network.IncSerial() + if err = am.Store.SaveAccount(account); err != nil { + return err + } + + am.updateAccountPeers(account) + + // the following snippet tracks the activity and stores the group events in the event store. + // It has to happen after all the operations have been successfully performed. + addedPeers := make([]string, 0) + removedPeers := make([]string, 0) + if exists { + addedPeers = difference(newGroup.Peers, oldGroup.Peers) + removedPeers = difference(oldGroup.Peers, newGroup.Peers) + } else { + addedPeers = append(addedPeers, newGroup.Peers...) + am.StoreEvent(userID, newGroup.ID, accountID, activity.GroupCreated, newGroup.EventMeta()) + } + + for _, p := range addedPeers { + peer := account.Peers[p] + if peer == nil { + log.Errorf("peer %s not found under account %s while saving group", p, accountID) + continue + } + am.StoreEvent(userID, peer.ID, accountID, activity.GroupAddedToPeer, + map[string]any{ + "group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peer.IP.String(), + "peer_fqdn": peer.FQDN(am.GetDNSDomain()), + }) + } + + for _, p := range removedPeers { + peer := account.Peers[p] + if peer == nil { + log.Errorf("peer %s not found under account %s while saving group", p, accountID) + continue + } + am.StoreEvent(userID, peer.ID, accountID, activity.GroupRemovedFromPeer, + map[string]any{ + "group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peer.IP.String(), + "peer_fqdn": peer.FQDN(am.GetDNSDomain()), + }) + } + + return nil +} + +// difference returns the elements in `a` that aren't in `b`. +func difference(a, b []string) []string { + mb := make(map[string]struct{}, len(b)) + for _, x := range b { + mb[x] = struct{}{} + } + var diff []string + for _, x := range a { + if _, found := mb[x]; !found { + diff = append(diff, x) + } + } + return diff +} + +// DeleteGroup object of the peers +func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) error { + unlock := am.Store.AcquireAccountLock(accountId) + defer unlock() + + account, err := am.Store.GetAccount(accountId) + if err != nil { + return err + } + + g, ok := account.Groups[groupID] + if !ok { + return nil + } + + // disable a deleting integration group if the initiator is not an admin service user + if g.Issued == GroupIssuedIntegration { + executingUser := account.Users[userId] + if executingUser == nil { + return status.Errorf(status.NotFound, "user not found") + } + if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser { + return status.Errorf(status.PermissionDenied, "only admins service user can delete integration group") + } + } + + // check route links + for _, r := range account.Routes { + for _, g := range r.Groups { + if g == groupID { + return &GroupLinkError{"route", r.NetID} + } + } + } + + // check DNS links + for _, dns := range account.NameServerGroups { + for _, g := range dns.Groups { + if g == groupID { + return &GroupLinkError{"name server groups", dns.Name} + } + } + } + + // check ACL links + for _, policy := range account.Policies { + for _, rule := range policy.Rules { + for _, src := range rule.Sources { + if src == groupID { + return &GroupLinkError{"policy", policy.Name} + } + } + + for _, dst := range rule.Destinations { + if dst == groupID { + return &GroupLinkError{"policy", policy.Name} + } + } + } + } + + // check setup key links + for _, setupKey := range account.SetupKeys { + for _, grp := range setupKey.AutoGroups { + if grp == groupID { + return &GroupLinkError{"setup key", setupKey.Name} + } + } + } + + // check user links + for _, user := range account.Users { + for _, grp := range user.AutoGroups { + if grp == groupID { + return &GroupLinkError{"user", user.Id} + } + } + } + + // check DisabledManagementGroups + for _, disabledMgmGrp := range account.DNSSettings.DisabledManagementGroups { + if disabledMgmGrp == groupID { + return &GroupLinkError{"disabled DNS management groups", g.Name} + } + } + + delete(account.Groups, groupID) + + account.Network.IncSerial() + if err = am.Store.SaveAccount(account); err != nil { + return err + } + + am.StoreEvent(userId, groupID, accountId, activity.GroupDeleted, g.EventMeta()) + + am.updateAccountPeers(account) + + return nil +} + +// ListGroups objects of the peers +func (am *DefaultAccountManager) ListGroups(accountID string) ([]*Group, error) { + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() + + account, err := am.Store.GetAccount(accountID) + if err != nil { + return nil, err + } + + groups := make([]*Group, 0, len(account.Groups)) + for _, item := range account.Groups { + groups = append(groups, item) + } + + return groups, nil +} + +// GroupAddPeer appends peer to the group +func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerID string) error { + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() + + account, err := am.Store.GetAccount(accountID) + if err != nil { + return err + } + + group, ok := account.Groups[groupID] + if !ok { + return status.Errorf(status.NotFound, "group with ID %s not found", groupID) + } + + add := true + for _, itemID := range group.Peers { + if itemID == peerID { + add = false + break + } + } + if add { + group.Peers = append(group.Peers, peerID) + } + + account.Network.IncSerial() + if err = am.Store.SaveAccount(account); err != nil { + return err + } + + am.updateAccountPeers(account) + + return nil +} + +// GroupDeletePeer removes peer from the group +func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerID string) error { + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() + + account, err := am.Store.GetAccount(accountID) + if err != nil { + return err + } + + group, ok := account.Groups[groupID] + if !ok { + return status.Errorf(status.NotFound, "group with ID %s not found", groupID) + } + + account.Network.IncSerial() + for i, itemID := range group.Peers { + if itemID == peerID { + group.Peers = append(group.Peers[:i], group.Peers[i+1:]...) + if err := am.Store.SaveAccount(account); err != nil { + return err + } + } + } + + am.updateAccountPeers(account) + + return nil +} diff --git a/management/server/management_refactor/server/grpcserver.go b/management/server/management_refactor/server/grpcserver.go new file mode 100644 index 000000000..6ebd6192c --- /dev/null +++ b/management/server/management_refactor/server/grpcserver.go @@ -0,0 +1,607 @@ +package server + +import ( + "context" + "fmt" + "strings" + "time" + + pb "github.com/golang/protobuf/proto" // nolint + "github.com/golang/protobuf/ptypes/timestamp" + log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "google.golang.org/grpc/codes" + gRPCPeer "google.golang.org/grpc/peer" + "google.golang.org/grpc/status" + + "github.com/netbirdio/netbird/encryption" + "github.com/netbirdio/netbird/management/proto" + "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/management_refactor/server/accounts" + "github.com/netbirdio/netbird/management/server/management_refactor/server/peers" + internalStatus "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/telemetry" +) + +// GRPCServer an instance of a Management gRPC API server +type GRPCServer struct { + accountManager accounts.AccountManager + wgKey wgtypes.Key + proto.UnimplementedManagementServiceServer + peersUpdateManager *peers.PeersUpdateManager + config *Config + turnCredentialsManager TURNCredentialsManager + jwtValidator *jwtclaims.JWTValidator + jwtClaimsExtractor *jwtclaims.ClaimsExtractor + appMetrics telemetry.AppMetrics + ephemeralManager *peers.EphemeralManager +} + +// NewServer creates a new Management server +func NewServer(config *Config, accountManager accounts.AccountManager, peersUpdateManager *peers.PeersUpdateManager, turnCredentialsManager TURNCredentialsManager, appMetrics telemetry.AppMetrics, ephemeralManager *peers.EphemeralManager) (*GRPCServer, error) { + key, err := wgtypes.GeneratePrivateKey() + if err != nil { + return nil, err + } + + var jwtValidator *jwtclaims.JWTValidator + + if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) { + jwtValidator, err = jwtclaims.NewJWTValidator( + config.HttpConfig.AuthIssuer, + config.GetAuthAudiences(), + config.HttpConfig.AuthKeysLocation, + config.HttpConfig.IdpSignKeyRefreshEnabled, + ) + if err != nil { + return nil, status.Errorf(codes.Internal, "unable to create new jwt middleware, err: %v", err) + } + } else { + log.Debug("unable to use http config to create new jwt middleware") + } + + if appMetrics != nil { + // update gauge based on number of connected peers which is equal to open gRPC streams + err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 { + return int64(len(peersUpdateManager.PeerChannels)) + }) + if err != nil { + return nil, err + } + } + + var audience, userIDClaim string + if config.HttpConfig != nil { + audience = config.HttpConfig.AuthAudience + userIDClaim = config.HttpConfig.AuthUserIDClaim + } + jwtClaimsExtractor := jwtclaims.NewClaimsExtractor( + jwtclaims.WithAudience(audience), + jwtclaims.WithUserIDClaim(userIDClaim), + ) + + return &GRPCServer{ + wgKey: key, + // peerKey -> event channel + peersUpdateManager: peersUpdateManager, + accountManager: accountManager, + config: config, + turnCredentialsManager: turnCredentialsManager, + jwtValidator: jwtValidator, + jwtClaimsExtractor: jwtClaimsExtractor, + appMetrics: appMetrics, + ephemeralManager: ephemeralManager, + }, nil +} + +func (s *GRPCServer) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) { + // todo introduce something more meaningful with the key expiration/rotation + if s.appMetrics != nil { + s.appMetrics.GRPCMetrics().CountGetKeyRequest() + } + now := time.Now().Add(24 * time.Hour) + secs := int64(now.Second()) + nanos := int32(now.Nanosecond()) + expiresAt := ×tamp.Timestamp{Seconds: secs, Nanos: nanos} + + return &proto.ServerKeyResponse{ + Key: s.wgKey.PublicKey().String(), + ExpiresAt: expiresAt, + }, nil +} + +// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and +// notifies the connected peer of any updates (e.g. new peers under the same account) +func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error { + reqStart := time.Now() + if s.appMetrics != nil { + s.appMetrics.GRPCMetrics().CountSyncRequest() + } + p, ok := gRPCPeer.FromContext(srv.Context()) + if ok { + log.Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, p.Addr.String()) + } + + syncReq := &proto.SyncRequest{} + peerKey, err := s.parseRequest(req, syncReq) + if err != nil { + return err + } + + peer, netMap, err := s.accountManager.SyncPeer(PeerSync{WireGuardPubKey: peerKey.String()}) + if err != nil { + return mapError(err) + } + + err = s.sendInitialSync(peerKey, peer, netMap, srv) + if err != nil { + log.Debugf("error while sending initial sync for %s: %v", peerKey.String(), err) + return err + } + + updates := s.peersUpdateManager.CreateChannel(peer.ID) + + s.ephemeralManager.OnPeerConnected(peer) + + err = s.accountManager.MarkPeerConnected(peerKey.String(), true) + if err != nil { + log.Warnf("failed marking peer as connected %s %v", peerKey, err) + } + + if s.config.TURNConfig.TimeBasedCredentials { + s.turnCredentialsManager.SetupRefresh(peer.ID) + } + + if s.appMetrics != nil { + s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart)) + } + + // keep a connection to the peer and send updates when available + for { + select { + // condition when there are some updates + case update, open := <-updates: + + if s.appMetrics != nil { + s.appMetrics.GRPCMetrics().UpdateChannelQueueLength(len(updates) + 1) + } + + if !open { + log.Debugf("updates channel for peer %s was closed", peerKey.String()) + s.cancelPeerRoutines(peer) + return nil + } + log.Debugf("received an update for peer %s", peerKey.String()) + + encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update) + if err != nil { + s.cancelPeerRoutines(peer) + return status.Errorf(codes.Internal, "failed processing update message") + } + + err = srv.SendMsg(&proto.EncryptedMessage{ + WgPubKey: s.wgKey.PublicKey().String(), + Body: encryptedResp, + }) + if err != nil { + s.cancelPeerRoutines(peer) + return status.Errorf(codes.Internal, "failed sending update message") + } + log.Debugf("sent an update to peer %s", peerKey.String()) + // condition when client <-> server connection has been terminated + case <-srv.Context().Done(): + // happens when connection drops, e.g. client disconnects + log.Debugf("stream of peer %s has been closed", peerKey.String()) + s.cancelPeerRoutines(peer) + return srv.Context().Err() + } + } +} + +func (s *GRPCServer) cancelPeerRoutines(peer *Peer) { + s.peersUpdateManager.CloseChannel(peer.ID) + s.turnCredentialsManager.CancelRefresh(peer.ID) + _ = s.accountManager.MarkPeerConnected(peer.Key, false) + s.ephemeralManager.OnPeerDisconnected(peer) +} + +func (s *GRPCServer) validateToken(jwtToken string) (string, error) { + if s.jwtValidator == nil { + return "", status.Error(codes.Internal, "no jwt validator set") + } + + token, err := s.jwtValidator.ValidateAndParse(jwtToken) + if err != nil { + return "", status.Errorf(codes.InvalidArgument, "invalid jwt token, err: %v", err) + } + claims := s.jwtClaimsExtractor.FromToken(token) + // we need to call this method because if user is new, we will automatically add it to existing or create a new account + _, _, err = s.accountManager.GetAccountFromToken(claims) + if err != nil { + return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err) + } + + return claims.UserId, nil +} + +// maps internal internalStatus.Error to gRPC status.Error +func mapError(err error) error { + if e, ok := internalStatus.FromError(err); ok { + switch e.Type() { + case internalStatus.PermissionDenied: + return status.Errorf(codes.PermissionDenied, e.Message) + case internalStatus.Unauthorized: + return status.Errorf(codes.PermissionDenied, e.Message) + case internalStatus.Unauthenticated: + return status.Errorf(codes.PermissionDenied, e.Message) + case internalStatus.PreconditionFailed: + return status.Errorf(codes.FailedPrecondition, e.Message) + case internalStatus.NotFound: + return status.Errorf(codes.NotFound, e.Message) + default: + } + } + log.Errorf("got an unhandled error: %s", err) + return status.Errorf(codes.Internal, "failed handling request") +} + +func extractPeerMeta(loginReq *proto.LoginRequest) PeerSystemMeta { + return PeerSystemMeta{ + Hostname: loginReq.GetMeta().GetHostname(), + GoOS: loginReq.GetMeta().GetGoOS(), + Kernel: loginReq.GetMeta().GetKernel(), + Core: loginReq.GetMeta().GetCore(), + Platform: loginReq.GetMeta().GetPlatform(), + OS: loginReq.GetMeta().GetOS(), + WtVersion: loginReq.GetMeta().GetWiretrusteeVersion(), + UIVersion: loginReq.GetMeta().GetUiVersion(), + } +} + +func (s *GRPCServer) parseRequest(req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) { + peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) + if err != nil { + log.Warnf("error while parsing peer's WireGuard public key %s.", req.WgPubKey) + return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", req.WgPubKey) + } + + err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, parsed) + if err != nil { + return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "invalid request message") + } + + return peerKey, nil +} + +// Login endpoint first checks whether peer is registered under any account +// In case it is, the login is successful +// In case it isn't, the endpoint checks whether setup key is provided within the request and tries to register a peer. +// In case of the successful registration login is also successful +func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { + reqStart := time.Now() + defer func() { + if s.appMetrics != nil { + s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart)) + } + }() + if s.appMetrics != nil { + s.appMetrics.GRPCMetrics().CountLoginRequest() + } + p, ok := gRPCPeer.FromContext(ctx) + if ok { + log.Debugf("Login request from peer [%s] [%s]", req.WgPubKey, p.Addr.String()) + } + + loginReq := &proto.LoginRequest{} + peerKey, err := s.parseRequest(req, loginReq) + if err != nil { + return nil, err + } + + if loginReq.GetMeta() == nil { + msg := status.Errorf(codes.FailedPrecondition, + "peer system meta has to be provided to log in. Peer %s, remote addr %s", peerKey.String(), + p.Addr.String()) + log.Warn(msg) + return nil, msg + } + + userID := "" + // JWT token is not always provided, it is fine for userID to be empty cuz it might be that peer is already registered, + // or it uses a setup key to register. + if loginReq.GetJwtToken() != "" { + userID, err = s.validateToken(loginReq.GetJwtToken()) + if err != nil { + log.Warnf("failed validating JWT token sent from peer %s", peerKey) + return nil, mapError(err) + } + } + var sshKey []byte + if loginReq.GetPeerKeys() != nil { + sshKey = loginReq.GetPeerKeys().GetSshPubKey() + } + + peer, netMap, err := s.accountManager.LoginPeer(PeerLogin{ + WireGuardPubKey: peerKey.String(), + SSHKey: string(sshKey), + Meta: extractPeerMeta(loginReq), + UserID: userID, + SetupKey: loginReq.GetSetupKey(), + }) + + if err != nil { + log.Warnf("failed logging in peer %s", peerKey) + return nil, mapError(err) + } + + // if the login request contains setup key then it is a registration request + if loginReq.GetSetupKey() != "" { + s.ephemeralManager.OnPeerDisconnected(peer) + } + + // if peer has reached this point then it has logged in + loginResp := &proto.LoginResponse{ + WiretrusteeConfig: toWiretrusteeConfig(s.config, nil), + PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain()), + } + encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp) + if err != nil { + log.Warnf("failed encrypting peer %s message", peer.ID) + return nil, status.Errorf(codes.Internal, "failed logging in peer") + } + + return &proto.EncryptedMessage{ + WgPubKey: s.wgKey.PublicKey().String(), + Body: encryptedResp, + }, nil +} + +func ToResponseProto(configProto Protocol) proto.HostConfig_Protocol { + switch configProto { + case UDP: + return proto.HostConfig_UDP + case DTLS: + return proto.HostConfig_DTLS + case HTTP: + return proto.HostConfig_HTTP + case HTTPS: + return proto.HostConfig_HTTPS + case TCP: + return proto.HostConfig_TCP + default: + panic(fmt.Errorf("unexpected config protocol type %v", configProto)) + } +} + +func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *proto.WiretrusteeConfig { + if config == nil { + return nil + } + var stuns []*proto.HostConfig + for _, stun := range config.Stuns { + stuns = append(stuns, &proto.HostConfig{ + Uri: stun.URI, + Protocol: ToResponseProto(stun.Proto), + }) + } + var turns []*proto.ProtectedHostConfig + for _, turn := range config.TURNConfig.Turns { + var username string + var password string + if turnCredentials != nil { + username = turnCredentials.Username + password = turnCredentials.Password + } else { + username = turn.Username + password = turn.Password + } + turns = append(turns, &proto.ProtectedHostConfig{ + HostConfig: &proto.HostConfig{ + Uri: turn.URI, + Protocol: ToResponseProto(turn.Proto), + }, + User: username, + Password: password, + }) + } + + return &proto.WiretrusteeConfig{ + Stuns: stuns, + Turns: turns, + Signal: &proto.HostConfig{ + Uri: config.Signal.URI, + Protocol: ToResponseProto(config.Signal.Proto), + }, + } +} + +func toPeerConfig(peer *Peer, network *Network, dnsName string) *proto.PeerConfig { + netmask, _ := network.Net.Mask.Size() + fqdn := peer.FQDN(dnsName) + return &proto.PeerConfig{ + Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask), // take it from the network + SshConfig: &proto.SSHConfig{SshEnabled: peer.SSHEnabled}, + Fqdn: fqdn, + } +} + +func toRemotePeerConfig(peers []*Peer, dnsName string) []*proto.RemotePeerConfig { + remotePeers := []*proto.RemotePeerConfig{} + for _, rPeer := range peers { + fqdn := rPeer.FQDN(dnsName) + remotePeers = append(remotePeers, &proto.RemotePeerConfig{ + WgPubKey: rPeer.Key, + AllowedIps: []string{fmt.Sprintf(AllowedIPsFormat, rPeer.IP)}, + SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)}, + Fqdn: fqdn, + }) + } + return remotePeers +} + +func toSyncResponse(config *Config, peer *Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string) *proto.SyncResponse { + wtConfig := toWiretrusteeConfig(config, turnCredentials) + + pConfig := toPeerConfig(peer, networkMap.Network, dnsName) + + remotePeers := toRemotePeerConfig(networkMap.Peers, dnsName) + + routesUpdate := toProtocolRoutes(networkMap.Routes) + + dnsUpdate := toProtocolDNSConfig(networkMap.DNSConfig) + + offlinePeers := toRemotePeerConfig(networkMap.OfflinePeers, dnsName) + + firewallRules := toProtocolFirewallRules(networkMap.FirewallRules) + + return &proto.SyncResponse{ + WiretrusteeConfig: wtConfig, + PeerConfig: pConfig, + RemotePeers: remotePeers, + RemotePeersIsEmpty: len(remotePeers) == 0, + NetworkMap: &proto.NetworkMap{ + Serial: networkMap.Network.CurrentSerial(), + PeerConfig: pConfig, + RemotePeers: remotePeers, + OfflinePeers: offlinePeers, + RemotePeersIsEmpty: len(remotePeers) == 0, + Routes: routesUpdate, + DNSConfig: dnsUpdate, + FirewallRules: firewallRules, + FirewallRulesIsEmpty: len(firewallRules) == 0, + }, + } +} + +// IsHealthy indicates whether the service is healthy +func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Empty, error) { + return &proto.Empty{}, nil +} + +// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization +func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *Peer, networkMap *NetworkMap, srv proto.ManagementService_SyncServer) error { + // make secret time based TURN credentials optional + var turnCredentials *TURNCredentials + if s.config.TURNConfig.TimeBasedCredentials { + creds := s.turnCredentialsManager.GenerateCredentials() + turnCredentials = &creds + } else { + turnCredentials = nil + } + plainResp := toSyncResponse(s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain()) + + encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp) + if err != nil { + return status.Errorf(codes.Internal, "error handling request") + } + + err = srv.Send(&proto.EncryptedMessage{ + WgPubKey: s.wgKey.PublicKey().String(), + Body: encryptedResp, + }) + + if err != nil { + log.Errorf("failed sending SyncResponse %v", err) + return status.Errorf(codes.Internal, "error handling request") + } + + return nil +} + +// GetDeviceAuthorizationFlow returns a device authorization flow information +// This is used for initiating an Oauth 2 device authorization grant flow +// which will be used by our clients to Login +func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { + peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) + if err != nil { + errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetDeviceAuthorizationFlow request.", req.WgPubKey) + log.Warn(errMSG) + return nil, status.Error(codes.InvalidArgument, errMSG) + } + + err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.DeviceAuthorizationFlowRequest{}) + if err != nil { + errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey) + log.Warn(errMSG) + return nil, status.Error(codes.InvalidArgument, errMSG) + } + + if s.config.DeviceAuthorizationFlow == nil || s.config.DeviceAuthorizationFlow.Provider == string(NONE) { + return nil, status.Error(codes.NotFound, "no device authorization flow information available") + } + + provider, ok := proto.DeviceAuthorizationFlowProvider_value[strings.ToUpper(s.config.DeviceAuthorizationFlow.Provider)] + if !ok { + return nil, status.Errorf(codes.InvalidArgument, "no provider found in the protocol for %s", s.config.DeviceAuthorizationFlow.Provider) + } + + flowInfoResp := &proto.DeviceAuthorizationFlow{ + Provider: proto.DeviceAuthorizationFlowProvider(provider), + ProviderConfig: &proto.ProviderConfig{ + ClientID: s.config.DeviceAuthorizationFlow.ProviderConfig.ClientID, + ClientSecret: s.config.DeviceAuthorizationFlow.ProviderConfig.ClientSecret, + Domain: s.config.DeviceAuthorizationFlow.ProviderConfig.Domain, + Audience: s.config.DeviceAuthorizationFlow.ProviderConfig.Audience, + DeviceAuthEndpoint: s.config.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint, + TokenEndpoint: s.config.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint, + Scope: s.config.DeviceAuthorizationFlow.ProviderConfig.Scope, + UseIDToken: s.config.DeviceAuthorizationFlow.ProviderConfig.UseIDToken, + }, + } + + encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp) + if err != nil { + return nil, status.Error(codes.Internal, "failed to encrypt no device authorization flow information") + } + + return &proto.EncryptedMessage{ + WgPubKey: s.wgKey.PublicKey().String(), + Body: encryptedResp, + }, nil +} + +// GetPKCEAuthorizationFlow returns a pkce authorization flow information +// This is used for initiating an Oauth 2 pkce authorization grant flow +// which will be used by our clients to Login +func (s *GRPCServer) GetPKCEAuthorizationFlow(_ context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { + peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) + if err != nil { + errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetPKCEAuthorizationFlow request.", req.WgPubKey) + log.Warn(errMSG) + return nil, status.Error(codes.InvalidArgument, errMSG) + } + + err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.PKCEAuthorizationFlowRequest{}) + if err != nil { + errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey) + log.Warn(errMSG) + return nil, status.Error(codes.InvalidArgument, errMSG) + } + + if s.config.PKCEAuthorizationFlow == nil { + return nil, status.Error(codes.NotFound, "no pkce authorization flow information available") + } + + flowInfoResp := &proto.PKCEAuthorizationFlow{ + ProviderConfig: &proto.ProviderConfig{ + Audience: s.config.PKCEAuthorizationFlow.ProviderConfig.Audience, + ClientID: s.config.PKCEAuthorizationFlow.ProviderConfig.ClientID, + ClientSecret: s.config.PKCEAuthorizationFlow.ProviderConfig.ClientSecret, + TokenEndpoint: s.config.PKCEAuthorizationFlow.ProviderConfig.TokenEndpoint, + AuthorizationEndpoint: s.config.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint, + Scope: s.config.PKCEAuthorizationFlow.ProviderConfig.Scope, + RedirectURLs: s.config.PKCEAuthorizationFlow.ProviderConfig.RedirectURLs, + UseIDToken: s.config.PKCEAuthorizationFlow.ProviderConfig.UseIDToken, + }, + } + + encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp) + if err != nil { + return nil, status.Error(codes.Internal, "failed to encrypt no pkce authorization flow information") + } + + return &proto.EncryptedMessage{ + WgPubKey: s.wgKey.PublicKey().String(), + Body: encryptedResp, + }, nil +} diff --git a/management/server/management_refactor/server/nameservers/nameserver.go b/management/server/management_refactor/server/nameservers/nameserver.go new file mode 100644 index 000000000..50e9ec68b --- /dev/null +++ b/management/server/management_refactor/server/nameservers/nameserver.go @@ -0,0 +1,286 @@ +package nameservers + +import ( + "errors" + "regexp" + "unicode/utf8" + + "github.com/miekg/dns" + "github.com/rs/xid" + + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/status" +) + +const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$` + +// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs +func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) { + + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() + + account, err := am.Store.GetAccount(accountID) + if err != nil { + return nil, err + } + + nsGroup, found := account.NameServerGroups[nsGroupID] + if found { + return nsGroup.Copy(), nil + } + + return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID) +} + +// CreateNameServerGroup creates and saves a new nameserver group +func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) { + + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() + + account, err := am.Store.GetAccount(accountID) + if err != nil { + return nil, err + } + + newNSGroup := &nbdns.NameServerGroup{ + ID: xid.New().String(), + Name: name, + Description: description, + NameServers: nameServerList, + Groups: groups, + Enabled: enabled, + Primary: primary, + Domains: domains, + SearchDomainsEnabled: searchDomainEnabled, + } + + err = validateNameServerGroup(false, newNSGroup, account) + if err != nil { + return nil, err + } + + if account.NameServerGroups == nil { + account.NameServerGroups = make(map[string]*nbdns.NameServerGroup) + } + + account.NameServerGroups[newNSGroup.ID] = newNSGroup + + account.Network.IncSerial() + err = am.Store.SaveAccount(account) + if err != nil { + return nil, err + } + + am.updateAccountPeers(account) + + am.StoreEvent(userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) + + return newNSGroup.Copy(), nil +} + +// SaveNameServerGroup saves nameserver group +func (am *DefaultAccountManager) SaveNameServerGroup(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error { + + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() + + if nsGroupToSave == nil { + return status.Errorf(status.InvalidArgument, "nameserver group provided is nil") + } + + account, err := am.Store.GetAccount(accountID) + if err != nil { + return err + } + + err = validateNameServerGroup(true, nsGroupToSave, account) + if err != nil { + return err + } + + account.NameServerGroups[nsGroupToSave.ID] = nsGroupToSave + + account.Network.IncSerial() + err = am.Store.SaveAccount(account) + if err != nil { + return err + } + + am.updateAccountPeers(account) + + am.StoreEvent(userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) + + return nil +} + +// DeleteNameServerGroup deletes nameserver group with nsGroupID +func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID, userID string) error { + + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() + + account, err := am.Store.GetAccount(accountID) + if err != nil { + return err + } + + nsGroup := account.NameServerGroups[nsGroupID] + if nsGroup == nil { + return status.Errorf(status.NotFound, "nameserver group %s wasn't found", nsGroupID) + } + delete(account.NameServerGroups, nsGroupID) + + account.Network.IncSerial() + err = am.Store.SaveAccount(account) + if err != nil { + return err + } + + am.updateAccountPeers(account) + + am.StoreEvent(userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) + + return nil +} + +// ListNameServerGroups returns a list of nameserver groups from account +func (am *DefaultAccountManager) ListNameServerGroups(accountID string) ([]*nbdns.NameServerGroup, error) { + + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() + + account, err := am.Store.GetAccount(accountID) + if err != nil { + return nil, err + } + + nsGroups := make([]*nbdns.NameServerGroup, 0, len(account.NameServerGroups)) + for _, item := range account.NameServerGroups { + nsGroups = append(nsGroups, item.Copy()) + } + + return nsGroups, nil +} + +func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error { + nsGroupID := "" + if existingGroup { + nsGroupID = nameserverGroup.ID + _, found := account.NameServerGroups[nsGroupID] + if !found { + return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupID) + } + } + + err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains, nameserverGroup.SearchDomainsEnabled) + if err != nil { + return err + } + + err = validateNSGroupName(nameserverGroup.Name, nsGroupID, account.NameServerGroups) + if err != nil { + return err + } + + err = validateNSList(nameserverGroup.NameServers) + if err != nil { + return err + } + + err = validateGroups(nameserverGroup.Groups, account.Groups) + if err != nil { + return err + } + + return nil +} + +func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bool) error { + if !primary && len(domains) == 0 { + return status.Errorf(status.InvalidArgument, "nameserver group primary status is false and domains are empty,"+ + " it should be primary or have at least one domain") + } + if primary && len(domains) != 0 { + return status.Errorf(status.InvalidArgument, "nameserver group primary status is true and domains are not empty,"+ + " you should set either primary or domain") + } + + if primary && searchDomainsEnabled { + return status.Errorf(status.InvalidArgument, "nameserver group primary status is true and search domains is enabled,"+ + " you should not set search domains for primary nameservers") + } + + for _, domain := range domains { + if err := validateDomain(domain); err != nil { + return status.Errorf(status.InvalidArgument, "nameserver group got an invalid domain: %s %q", domain, err) + } + } + return nil +} + +func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.NameServerGroup) error { + if utf8.RuneCountInString(name) > nbdns.MaxGroupNameChar || name == "" { + return status.Errorf(status.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar) + } + + for _, nsGroup := range nsGroupMap { + if name == nsGroup.Name && nsGroup.ID != nsGroupID { + return status.Errorf(status.InvalidArgument, "a nameserver group with name %s already exist", name) + } + } + + return nil +} + +func validateNSList(list []nbdns.NameServer) error { + nsListLenght := len(list) + if nsListLenght == 0 || nsListLenght > 2 { + return status.Errorf(status.InvalidArgument, "the list of nameservers should be 1 or 2, got %d", len(list)) + } + return nil +} + +func validateGroups(list []string, groups map[string]*Group) error { + if len(list) == 0 { + return status.Errorf(status.InvalidArgument, "the list of group IDs should not be empty") + } + + for _, id := range list { + if id == "" { + return status.Errorf(status.InvalidArgument, "group ID should not be empty string") + } + found := false + for groupID := range groups { + if id == groupID { + found = true + break + } + } + if !found { + return status.Errorf(status.InvalidArgument, "group id %s not found", id) + } + } + + return nil +} + +func validateDomain(domain string) error { + domainMatcher := regexp.MustCompile(domainPattern) + if !domainMatcher.MatchString(domain) { + return errors.New("domain should consists of only letters, numbers, and hyphens with no leading, trailing hyphens, or spaces") + } + + labels, valid := dns.IsDomainName(domain) + if !valid { + return errors.New("invalid domain name") + } + + if labels < 2 { + return errors.New("domain should consists of a minimum of two labels") + } + + return nil +} diff --git a/management/server/management_refactor/server/network/network.go b/management/server/management_refactor/server/network/network.go new file mode 100644 index 000000000..f2fdadde3 --- /dev/null +++ b/management/server/management_refactor/server/network/network.go @@ -0,0 +1,148 @@ +package network + +import ( + "math/rand" + "net" + "sync" + "time" + + "github.com/c-robinson/iplib" + "github.com/rs/xid" + + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/server/management_refactor/server/peers" + "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/route" +) + +const ( + // SubnetSize is a size of the subnet of the global network, e.g. 100.77.0.0/16 + SubnetSize = 16 + // NetSize is a global network size 100.64.0.0/10 + NetSize = 10 + + // AllowedIPsFormat generates Wireguard AllowedIPs format (e.g. 100.64.30.1/32) + AllowedIPsFormat = "%s/32" +) + +type NetworkMap struct { + Peers []*peers.Peer + Network *Network + Routes []*route.Route + DNSConfig nbdns.Config + OfflinePeers []*peers.Peer + FirewallRules []*FirewallRule +} + +type Network struct { + Identifier string `json:"id"` + Net net.IPNet `gorm:"serializer:gob"` + Dns string + // Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added). + // Used to synchronize state to the client apps. + Serial uint64 + + mu sync.Mutex `json:"-" gorm:"-"` +} + +// NewNetwork creates a new Network initializing it with a Serial=0 +// It takes a random /16 subnet from 100.64.0.0/10 (64 different subnets) +func NewNetwork() *Network { + + n := iplib.NewNet4(net.ParseIP("100.64.0.0"), NetSize) + sub, _ := n.Subnet(SubnetSize) + + s := rand.NewSource(time.Now().Unix()) + r := rand.New(s) + intn := r.Intn(len(sub)) + + return &Network{ + Identifier: xid.New().String(), + Net: sub[intn].IPNet, + Dns: "", + Serial: 0} +} + +// IncSerial increments Serial by 1 reflecting that the network state has been changed +func (n *Network) IncSerial() { + n.mu.Lock() + defer n.mu.Unlock() + n.Serial = n.Serial + 1 +} + +// CurrentSerial returns the Network.Serial of the network (latest state id) +func (n *Network) CurrentSerial() uint64 { + n.mu.Lock() + defer n.mu.Unlock() + return n.Serial +} + +func (n *Network) Copy() *Network { + return &Network{ + Identifier: n.Identifier, + Net: n.Net, + Dns: n.Dns, + Serial: n.Serial, + } +} + +// AllocatePeerIP pics an available IP from an net.IPNet. +// This method considers already taken IPs and reuses IPs if there are gaps in takenIps +// E.g. if ipNet=100.30.0.0/16 and takenIps=[100.30.0.1, 100.30.0.4] then the result would be 100.30.0.2 or 100.30.0.3 +func AllocatePeerIP(ipNet net.IPNet, takenIps []net.IP) (net.IP, error) { + takenIPMap := make(map[string]struct{}) + takenIPMap[ipNet.IP.String()] = struct{}{} + for _, ip := range takenIps { + takenIPMap[ip.String()] = struct{}{} + } + + ips, _ := generateIPs(&ipNet, takenIPMap) + + if len(ips) == 0 { + return nil, status.Errorf(status.PreconditionFailed, "failed allocating new IP for the ipNet %s - network is out of IPs", ipNet.String()) + } + + // pick a random IP + s := rand.NewSource(time.Now().Unix()) + r := rand.New(s) + intn := r.Intn(len(ips)) + + return ips[intn], nil +} + +// generateIPs generates a list of all possible IPs of the given network excluding IPs specified in the exclusion list +func generateIPs(ipNet *net.IPNet, exclusions map[string]struct{}) ([]net.IP, int) { + + var ips []net.IP + for ip := ipNet.IP.Mask(ipNet.Mask); ipNet.Contains(ip); incIP(ip) { + if _, ok := exclusions[ip.String()]; !ok && ip[3] != 0 { + ips = append(ips, copyIP(ip)) + } + } + + // remove network address, broadcast and Fake DNS resolver address + lenIPs := len(ips) + switch { + case lenIPs < 2: + return ips, lenIPs + case lenIPs < 3: + return ips[1 : len(ips)-1], lenIPs - 2 + default: + return ips[1 : len(ips)-2], lenIPs - 3 + } +} + +func copyIP(ip net.IP) net.IP { + dup := make(net.IP, len(ip)) + copy(dup, ip) + return dup +} + +func incIP(ip net.IP) { + for j := len(ip) - 1; j >= 0; j-- { + ip[j]++ + if ip[j] > 0 { + break + } + } +} diff --git a/management/server/management_refactor/server/network/network_manager.go b/management/server/management_refactor/server/network/network_manager.go new file mode 100644 index 000000000..c66a168bd --- /dev/null +++ b/management/server/management_refactor/server/network/network_manager.go @@ -0,0 +1,161 @@ +package network + +import ( + "strconv" + "strings" + + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/server/management_refactor/server/access_control" + "github.com/netbirdio/netbird/management/server/management_refactor/server/peers" +) + +type NetworkManager interface { + GetPeerNetworkMap(peerID, dnsDomain string) *NetworkMap +} + +type DefaultNetworkManager struct { + accessControlManager access_control.AccessControlManager +} + +func (nm *DefaultNetworkManager) GetPeerNetworkMap(peerID, dnsDomain string) *NetworkMap { + aclPeers, firewallRules := getPeerConnectionResources(peerID) + // exclude expired peers + var peersToConnect []*peers.Peer + var expiredPeers []*peers.Peer + for _, p := range aclPeers { + expired, _ := p.LoginExpired(a.Settings.PeerLoginExpiration) + if a.Settings.PeerLoginExpirationEnabled && expired { + expiredPeers = append(expiredPeers, p) + continue + } + peersToConnect = append(peersToConnect, p) + } + + routesUpdate := a.getRoutesToSync(peerID, peersToConnect) + + dnsManagementStatus := a.getPeerDNSManagementStatus(peerID) + dnsUpdate := nbdns.Config{ + ServiceEnable: dnsManagementStatus, + } + + if dnsManagementStatus { + var zones []nbdns.CustomZone + peersCustomZone := getPeersCustomZone(a, dnsDomain) + if peersCustomZone.Domain != "" { + zones = append(zones, peersCustomZone) + } + dnsUpdate.CustomZones = zones + dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID) + } + + return &NetworkMap{ + Peers: peersToConnect, + Network: a.Network.Copy(), + Routes: routesUpdate, + DNSConfig: dnsUpdate, + OfflinePeers: expiredPeers, + FirewallRules: firewallRules, + } +} + +// getPeerConnectionResources for a given peer +// +// This function returns the list of peers and firewall rules that are applicable to a given peer. +func (nm *DefaultNetworkManager) getPeerConnectionResources(peerID string) ([]*Peer, []*FirewallRule) { + generateResources, getAccumulatedResources := a.connResourcesGenerator() + for _, policy := range nm.accessControlManager.Policies { + if !policy.Enabled { + continue + } + + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + sourcePeers, peerInSources := getAllPeersFromGroups(a, rule.Sources, peerID) + destinationPeers, peerInDestinations := getAllPeersFromGroups(a, rule.Destinations, peerID) + + if rule.Bidirectional { + if peerInSources { + generateResources(rule, destinationPeers, firewallRuleDirectionIN) + } + if peerInDestinations { + generateResources(rule, sourcePeers, firewallRuleDirectionOUT) + } + } + + if peerInSources { + generateResources(rule, destinationPeers, firewallRuleDirectionOUT) + } + + if peerInDestinations { + generateResources(rule, sourcePeers, firewallRuleDirectionIN) + } + } + } + + return getAccumulatedResources() +} + +// connResourcesGenerator returns generator and accumulator function which returns the result of generator calls +// +// The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer. +// It safe to call the generator function multiple times for same peer and different rules no duplicates will be +// generated. The accumulator function returns the result of all the generator calls. +func (nm *DefaultNetworkManager) connResourcesGenerator() (func(*access_control.PolicyRule, []*peers.Peer, int), func() ([]*peers.Peer, []*access_control.FirewallRule)) { + rulesExists := make(map[string]struct{}) + peersExists := make(map[string]struct{}) + rules := make([]*FirewallRule, 0) + peers := make([]*peers.Peer, 0) + + all, err := a.GetGroupAll() + if err != nil { + log.Errorf("failed to get group all: %v", err) + all = &Group{} + } + + return func(rule *PolicyRule, groupPeers []*Peer, direction int) { + isAll := (len(all.Peers) - 1) == len(groupPeers) + for _, peer := range groupPeers { + if peer == nil { + continue + } + if _, ok := peersExists[peer.ID]; !ok { + peers = append(peers, peer) + peersExists[peer.ID] = struct{}{} + } + + fr := FirewallRule{ + PeerIP: peer.IP.String(), + Direction: direction, + Action: string(rule.Action), + Protocol: string(rule.Protocol), + } + + if isAll { + fr.PeerIP = "0.0.0.0" + } + + ruleID := (rule.ID + fr.PeerIP + strconv.Itoa(direction) + + fr.Protocol + fr.Action + strings.Join(rule.Ports, ",")) + if _, ok := rulesExists[ruleID]; ok { + continue + } + rulesExists[ruleID] = struct{}{} + + if len(rule.Ports) == 0 { + rules = append(rules, &fr) + continue + } + + for _, port := range rule.Ports { + pr := fr // clone rule and add set new port + pr.Port = port + rules = append(rules, &pr) + } + } + }, func() ([]*Peer, []*FirewallRule) { + return peers, rules + } +} diff --git a/management/server/management_refactor/server/peers/ephemeral.go b/management/server/management_refactor/server/peers/ephemeral.go new file mode 100644 index 000000000..d8e5f9e23 --- /dev/null +++ b/management/server/management_refactor/server/peers/ephemeral.go @@ -0,0 +1,225 @@ +package peers + +import ( + "sync" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/management_refactor/accounts" +) + +const ( + ephemeralLifeTime = 10 * time.Minute +) + +var ( + timeNow = time.Now +) + +type ephemeralPeer struct { + id string + account *accounts.Account + deadline time.Time + next *ephemeralPeer +} + +// todo: consider to remove peer from ephemeral list when the peer has been deleted via API. If we do not do it +// in worst case we will get invalid error message in this manager. + +// EphemeralManager keep a list of ephemeral peers. After ephemeralLifeTime inactivity the peer will be deleted +// automatically. Inactivity means the peer disconnected from the Management server. +type EphemeralManager struct { + store Store + accountManager AccountManager + + headPeer *ephemeralPeer + tailPeer *ephemeralPeer + peersLock sync.Mutex + timer *time.Timer +} + +// NewEphemeralManager instantiate new EphemeralManager +func NewEphemeralManager(store Store, accountManager AccountManager) *EphemeralManager { + return &EphemeralManager{ + store: store, + accountManager: accountManager, + } +} + +// LoadInitialPeers load from the database the ephemeral type of peers and schedule a cleanup procedure to the head +// of the linked list (to the most deprecated peer). At the end of cleanup it schedules the next cleanup to the new +// head. +func (e *EphemeralManager) LoadInitialPeers() { + e.peersLock.Lock() + defer e.peersLock.Unlock() + + e.loadEphemeralPeers() + if e.headPeer != nil { + e.timer = time.AfterFunc(ephemeralLifeTime, e.cleanup) + } +} + +// Stop timer +func (e *EphemeralManager) Stop() { + e.peersLock.Lock() + defer e.peersLock.Unlock() + + if e.timer != nil { + e.timer.Stop() + } +} + +// OnPeerConnected remove the peer from the linked list of ephemeral peers. Because it has been called when the peer +// is active the manager will not delete it while it is active. +func (e *EphemeralManager) OnPeerConnected(peer *Peer) { + if !peer.Ephemeral { + return + } + + log.Tracef("remove peer from ephemeral list: %s", peer.ID) + + e.peersLock.Lock() + defer e.peersLock.Unlock() + + e.removePeer(peer.ID) + + // stop the unnecessary timer + if e.headPeer == nil && e.timer != nil { + e.timer.Stop() + e.timer = nil + } +} + +// OnPeerDisconnected add the peer to the linked list of ephemeral peers. Because of the peer +// is inactive it will be deleted after the ephemeralLifeTime period. +func (e *EphemeralManager) OnPeerDisconnected(peer *Peer) { + if !peer.Ephemeral { + return + } + + log.Tracef("add peer to ephemeral list: %s", peer.ID) + + a, err := e.store.GetAccountByPeerID(peer.ID) + if err != nil { + log.Errorf("failed to add peer to ephemeral list: %s", err) + return + } + + e.peersLock.Lock() + defer e.peersLock.Unlock() + + if e.isPeerOnList(peer.ID) { + return + } + + e.addPeer(peer.ID, a, newDeadLine()) + if e.timer == nil { + e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), e.cleanup) + } +} + +func (e *EphemeralManager) loadEphemeralPeers() { + accounts := e.store.GetAllAccounts() + t := newDeadLine() + count := 0 + for _, a := range accounts { + for id, p := range a.Peers { + if p.Ephemeral { + count++ + e.addPeer(id, a, t) + } + } + } + log.Debugf("loaded ephemeral peer(s): %d", count) +} + +func (e *EphemeralManager) cleanup() { + log.Tracef("on ephemeral cleanup") + deletePeers := make(map[string]*ephemeralPeer) + + e.peersLock.Lock() + now := timeNow() + for p := e.headPeer; p != nil; p = p.next { + if now.Before(p.deadline) { + break + } + + deletePeers[p.id] = p + e.headPeer = p.next + if p.next == nil { + e.tailPeer = nil + } + } + + if e.headPeer != nil { + e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), e.cleanup) + } else { + e.timer = nil + } + + e.peersLock.Unlock() + + for id, p := range deletePeers { + log.Debugf("delete ephemeral peer: %s", id) + err := e.accountManager.DeletePeer(p.account.Id, id, activity.SystemInitiator) + if err != nil { + log.Tracef("failed to delete ephemeral peer: %s", err) + } + } +} + +func (e *EphemeralManager) addPeer(id string, account *Account, deadline time.Time) { + ep := &ephemeralPeer{ + id: id, + account: account, + deadline: deadline, + } + + if e.headPeer == nil { + e.headPeer = ep + } + if e.tailPeer != nil { + e.tailPeer.next = ep + } + e.tailPeer = ep +} + +func (e *EphemeralManager) removePeer(id string) { + if e.headPeer == nil { + return + } + + if e.headPeer.id == id { + e.headPeer = e.headPeer.next + if e.tailPeer.id == id { + e.tailPeer = nil + } + return + } + + for p := e.headPeer; p.next != nil; p = p.next { + if p.next.id == id { + // if we remove the last element from the chain then set the last-1 as tail + if e.tailPeer.id == id { + e.tailPeer = p + } + p.next = p.next.next + return + } + } +} + +func (e *EphemeralManager) isPeerOnList(id string) bool { + for p := e.headPeer; p != nil; p = p.next { + if p.id == id { + return true + } + } + return false +} + +func newDeadLine() time.Time { + return timeNow().Add(ephemeralLifeTime) +} diff --git a/management/server/management_refactor/server/peers/peer.go b/management/server/management_refactor/server/peers/peer.go new file mode 100644 index 000000000..ec28f5208 --- /dev/null +++ b/management/server/management_refactor/server/peers/peer.go @@ -0,0 +1,186 @@ +package peers + +import ( + "fmt" + "net" + "time" + + log "github.com/sirupsen/logrus" +) + +// Peer represents a machine connected to the network. +// The Peer is a WireGuard peer identified by a public key +type Peer struct { + // ID is an internal ID of the peer + ID string `gorm:"primaryKey"` + // AccountID is a reference to Account that this object belongs + AccountID string `json:"-" gorm:"index;uniqueIndex:idx_peers_account_id_ip"` + // WireGuard public key + Key string `gorm:"index"` + // A setup key this peer was registered with + SetupKey string + // IP address of the Peer + IP net.IP `gorm:"uniqueIndex:idx_peers_account_id_ip"` + // Meta is a Peer system meta data + Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"` + // Name is peer's name (machine name) + Name string + // DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's + // domain to the peer label. e.g. peer-dns-label.netbird.cloud + DNSLabel string + // Status peer's management connection status + Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"` + // The user ID that registered the peer + UserID string + // SSHKey is a public SSH key of the peer + SSHKey string + // SSHEnabled indicates whether SSH server is enabled on the peer + SSHEnabled bool + // LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login. + // Works with LastLogin + LoginExpirationEnabled bool + // LastLogin the time when peer performed last login operation + LastLogin time.Time + // Indicate ephemeral peer attribute + Ephemeral bool +} + +// PeerSystemMeta is a metadata of a Peer machine system +type PeerSystemMeta struct { + Hostname string + GoOS string + Kernel string + Core string + Platform string + OS string + WtVersion string + UIVersion string +} + +func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool { + return p.Hostname == other.Hostname && + p.GoOS == other.GoOS && + p.Kernel == other.Kernel && + p.Core == other.Core && + p.Platform == other.Platform && + p.OS == other.OS && + p.WtVersion == other.WtVersion && + p.UIVersion == other.UIVersion +} + +// AddedWithSSOLogin indicates whether this peer has been added with an SSO login by a user. +func (p *Peer) AddedWithSSOLogin() bool { + return p.UserID != "" +} + +// Copy copies Peer object +func (p *Peer) Copy() *Peer { + peerStatus := p.Status + if peerStatus != nil { + peerStatus = p.Status.Copy() + } + return &Peer{ + ID: p.ID, + AccountID: p.AccountID, + Key: p.Key, + SetupKey: p.SetupKey, + IP: p.IP, + Meta: p.Meta, + Name: p.Name, + DNSLabel: p.DNSLabel, + Status: peerStatus, + UserID: p.UserID, + SSHKey: p.SSHKey, + SSHEnabled: p.SSHEnabled, + LoginExpirationEnabled: p.LoginExpirationEnabled, + LastLogin: p.LastLogin, + Ephemeral: p.Ephemeral, + } +} + +// UpdateMetaIfNew updates peer's system metadata if new information is provided +// returns true if meta was updated, false otherwise +func (p *Peer) UpdateMetaIfNew(meta PeerSystemMeta) bool { + // Avoid overwriting UIVersion if the update was triggered sole by the CLI client + if meta.UIVersion == "" { + meta.UIVersion = p.Meta.UIVersion + } + + if p.Meta.isEqual(meta) { + return false + } + p.Meta = meta + return true +} + +// MarkLoginExpired marks peer's status expired or not +func (p *Peer) MarkLoginExpired(expired bool) { + newStatus := p.Status.Copy() + newStatus.LoginExpired = expired + if expired { + newStatus.Connected = false + } + p.Status = newStatus +} + +// LoginExpired indicates whether the peer's login has expired or not. +// If Peer.LastLogin plus the expiresIn duration has happened already; then login has expired. +// Return true if a login has expired, false otherwise, and time left to expiration (negative when expired). +// Login expiration can be disabled/enabled on a Peer level via Peer.LoginExpirationEnabled property. +// Login expiration can also be disabled/enabled globally on the Account level via Settings.PeerLoginExpirationEnabled. +// Only peers added by interactive SSO login can be expired. +func (p *Peer) LoginExpired(expiresIn time.Duration) (bool, time.Duration) { + if !p.AddedWithSSOLogin() || !p.LoginExpirationEnabled { + return false, 0 + } + expiresAt := p.LastLogin.Add(expiresIn) + now := time.Now() + timeLeft := expiresAt.Sub(now) + return timeLeft <= 0, timeLeft +} + +// FQDN returns peers FQDN combined of the peer's DNS label and the system's DNS domain +func (p *Peer) FQDN(dnsDomain string) string { + if dnsDomain == "" { + return "" + } + return fmt.Sprintf("%s.%s", p.DNSLabel, dnsDomain) +} + +// EventMeta returns activity event meta related to the peer +func (p *Peer) EventMeta(dnsDomain string) map[string]any { + return map[string]any{"name": p.Name, "fqdn": p.FQDN(dnsDomain), "ip": p.IP} +} + +// Copy PeerStatus +func (p *PeerStatus) Copy() *PeerStatus { + return &PeerStatus{ + LastSeen: p.LastSeen, + Connected: p.Connected, + LoginExpired: p.LoginExpired, + } +} + +// UpdateLastLogin and set login expired false +func (p *Peer) UpdateLastLogin() { + p.LastLogin = time.Now().UTC() + newStatus := p.Status.Copy() + newStatus.LoginExpired = false + p.Status = newStatus +} + +func (p *Peer) CheckAndUpdatePeerSSHKey(newSSHKey string) bool { + if len(newSSHKey) == 0 { + log.Debugf("no new SSH key provided for peer %s, skipping update", p.ID) + return false + } + + if p.SSHKey == newSSHKey { + log.Debugf("same SSH key provided for peer %s, skipping update", p.ID) + return false + } + + p.SSHKey = newSSHKey + + return true +} diff --git a/management/server/management_refactor/server/peers/peer_repository.go b/management/server/management_refactor/server/peers/peer_repository.go new file mode 100644 index 000000000..9229db8dd --- /dev/null +++ b/management/server/management_refactor/server/peers/peer_repository.go @@ -0,0 +1,6 @@ +package peers + +type PeerRepository interface { + findPeerByPubKey(pubKey string) (Peer, error) + updatePeer(peer Peer) error +} diff --git a/management/server/management_refactor/server/peers/peers_manager.go b/management/server/management_refactor/server/peers/peers_manager.go new file mode 100644 index 000000000..8888835f3 --- /dev/null +++ b/management/server/management_refactor/server/peers/peers_manager.go @@ -0,0 +1,214 @@ +package peers + +import ( + "time" + + log "github.com/sirupsen/logrus" + + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/management_refactor/server/accounts" + "github.com/netbirdio/netbird/management/server/management_refactor/server/events" + "github.com/netbirdio/netbird/management/server/management_refactor/server/users" + "github.com/netbirdio/netbird/management/server/status" +) + +// PeerLogin used as a data object between the gRPC API and AccountManager on Login request. +type PeerLogin struct { + // WireGuardPubKey is a peers WireGuard public key + WireGuardPubKey string + // SSHKey is a peer's ssh key. Can be empty (e.g., old version do not provide it, or this feature is disabled) + SSHKey string + // Meta is the system information passed by peer, must be always present. + Meta PeerSystemMeta + // UserID indicates that JWT was used to log in, and it was valid. Can be empty when SetupKey is used or auth is not required. + UserID string + // AccountID indicates that JWT was used to log in, and it was valid. Can be empty when SetupKey is used or auth is not required. + AccountID string + // SetupKey references to a server.SetupKey to log in. Can be empty when UserID is used or auth is not required. + SetupKey string +} + +type PeerStatus struct { + // LastSeen is the last time peer was connected to the management service + LastSeen time.Time + // Connected indicates whether peer is connected to the management service or not + Connected bool + // LoginExpired + LoginExpired bool +} + +// PeerSync used as a data object between the gRPC API and AccountManager on Sync request. +type PeerSync struct { + // WireGuardPubKey is a peers WireGuard public key + WireGuardPubKey string +} + +type PeersManager interface { + LoginPeer(login PeerLogin) (*Peer, *accounts.NetworkMap, error) +} + +type DefaultPeersManager struct { + repository PeerRepository + userManager users.UserManager + accountManager accounts.AccountManager + eventsManager events.EventsManager +} + +// LoginPeer logs in or registers a peer. +// If peer doesn't exist the function checks whether a setup key or a user is present and registers a new peer if so. +func (pm *DefaultPeersManager) LoginPeer(login PeerLogin) (*Peer, *accounts.NetworkMap, error) { + + peer, err := pm.repository.findPeerByPubKey(login.WireGuardPubKey) + if err != nil { + return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered") + } + + if peer.AddedWithSSOLogin() { + user, err := pm.userManager.GetUser(peer.UserID) + if err != nil { + return nil, nil, err + } + if user.IsBlocked() { + return nil, nil, status.Errorf(status.PermissionDenied, "user is blocked") + } + } + + account, err := pm.accountManager.GetAccount(peer.AccountID) + if err != nil { + return nil, nil, err + } + + // this flag prevents unnecessary calls to the persistent store. + shouldStorePeer := false + updateRemotePeers := false + if peerLoginExpired(peer, account) { + err = checkAuth(login.UserID, peer) + if err != nil { + return nil, nil, err + } + // If peer was expired before and if it reached this point, it is re-authenticated. + // UserID is present, meaning that JWT validation passed successfully in the API layer. + peer.UpdateLastLogin() + updateRemotePeers = true + shouldStorePeer = true + + pm.eventsManager.StoreEvent(login.UserID, peer.ID, account.Id, activity.UserLoggedInPeer, peer.EventMeta(pm.accountManager.GetDNSDomain())) + } + + if peer.UpdateMetaIfNew(login.Meta) { + shouldStorePeer = true + } + + if peer.CheckAndUpdatePeerSSHKey(login.SSHKey) { + shouldStorePeer = true + } + + if shouldStorePeer { + err := pm.repository.updatePeer(peer) + if err != nil { + return nil, nil, err + } + } + + if updateRemotePeers { + am.updateAccountPeers(account) + } + return peer, account.GetPeerNetworkMap(peer.ID, pm.accountManager.GetDNSDomain()), nil +} + +// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible +func (pm *DefaultPeersManager) SyncPeer(sync PeerSync) (*Peer, *accounts.NetworkMap, error) { + // we found the peer, and we follow a normal login flow + // unlock := am.Store.AcquireAccountLock(account.Id) + // defer unlock() + + peer, err := pm.repository.findPeerByPubKey(sync.WireGuardPubKey) + if err != nil { + return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered") + } + + if peer.AddedWithSSOLogin() { + user, err := pm.userManager.GetUser(peer.UserID) + if err != nil { + return nil, nil, err + } + if user.IsBlocked() { + return nil, nil, status.Errorf(status.PermissionDenied, "user is blocked") + } + } + + account, err := pm.accountManager.GetAccount(peer.AccountID) + if err != nil { + return nil, nil, err + } + + if peerLoginExpired(peer, account) { + return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more") + } + + return &peer, account.GetPeerNetworkMap(peer.ID, pm.accountManager.GetDNSDomain()), nil +} + +func (pm *DefaultPeersManager) GetNetworkMap(peerID string, dnsDomain string) (*accounts.NetworkMap, error) { + aclPeers, firewallRules := a.getPeerConnectionResources(peerID) + // exclude expired peers + var peersToConnect []*Peer + var expiredPeers []*Peer + for _, p := range aclPeers { + expired, _ := p.LoginExpired(a.Settings.PeerLoginExpiration) + if a.Settings.PeerLoginExpirationEnabled && expired { + expiredPeers = append(expiredPeers, p) + continue + } + peersToConnect = append(peersToConnect, p) + } + + routesUpdate := a.getRoutesToSync(peerID, peersToConnect) + + dnsManagementStatus := a.getPeerDNSManagementStatus(peerID) + dnsUpdate := nbdns.Config{ + ServiceEnable: dnsManagementStatus, + } + + if dnsManagementStatus { + var zones []nbdns.CustomZone + peersCustomZone := getPeersCustomZone(a, dnsDomain) + if peersCustomZone.Domain != "" { + zones = append(zones, peersCustomZone) + } + dnsUpdate.CustomZones = zones + dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID) + } + + return &NetworkMap{ + Peers: peersToConnect, + Network: a.Network.Copy(), + Routes: routesUpdate, + DNSConfig: dnsUpdate, + OfflinePeers: expiredPeers, + FirewallRules: firewallRules, + } +} + +func peerLoginExpired(peer Peer, account accounts.Account) bool { + expired, expiresIn := peer.LoginExpired(account.Settings.PeerLoginExpiration) + expired = account.Settings.PeerLoginExpirationEnabled && expired + if expired || peer.Status.LoginExpired { + log.Debugf("peer's %s login expired %v ago", peer.ID, expiresIn) + return true + } + return false +} + +func checkAuth(loginUserID string, peer Peer) error { + if loginUserID == "" { + // absence of a user ID indicates that JWT wasn't provided. + return status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more") + } + if peer.UserID != loginUserID { + log.Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, loginUserID) + return status.Errorf(status.Unauthenticated, "can't login") + } + return nil +} diff --git a/management/server/management_refactor/server/peers/updatechannel.go b/management/server/management_refactor/server/peers/updatechannel.go new file mode 100644 index 000000000..ce70dd844 --- /dev/null +++ b/management/server/management_refactor/server/peers/updatechannel.go @@ -0,0 +1,153 @@ +package peers + +import ( + "sync" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/proto" + "github.com/netbirdio/netbird/management/server/telemetry" +) + +const channelBufferSize = 100 + +type UpdateMessage struct { + Update *proto.SyncResponse +} + +type PeersUpdateManager struct { + // PeerChannels is an update channel indexed by Peer.ID + PeerChannels map[string]chan *UpdateMessage + // channelsMux keeps the mutex to access PeerChannels + channelsMux *sync.Mutex + // metrics provides method to collect application metrics + metrics telemetry.AppMetrics +} + +// NewPeersUpdateManager returns a new instance of PeersUpdateManager +func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager { + return &PeersUpdateManager{ + PeerChannels: make(map[string]chan *UpdateMessage), + channelsMux: &sync.Mutex{}, + metrics: metrics, + } +} + +// SendUpdate sends update message to the peer's channel +func (p *PeersUpdateManager) SendUpdate(peerID string, update *UpdateMessage) { + start := time.Now() + var found, dropped bool + + p.channelsMux.Lock() + defer func() { + p.channelsMux.Unlock() + if p.metrics != nil { + p.metrics.UpdateChannelMetrics().CountSendUpdateDuration(time.Since(start), found, dropped) + } + }() + + if channel, ok := p.PeerChannels[peerID]; ok { + found = true + select { + case channel <- update: + log.Debugf("update was sent to channel for peer %s", peerID) + default: + dropped = true + log.Warnf("channel for peer %s is %d full", peerID, len(channel)) + } + } else { + log.Debugf("peer %s has no channel", peerID) + } +} + +// CreateChannel creates a go channel for a given peer used to deliver updates relevant to the peer. +func (p *PeersUpdateManager) CreateChannel(peerID string) chan *UpdateMessage { + start := time.Now() + + closed := false + + p.channelsMux.Lock() + defer func() { + p.channelsMux.Unlock() + if p.metrics != nil { + p.metrics.UpdateChannelMetrics().CountCreateChannelDuration(time.Since(start), closed) + } + }() + + if channel, ok := p.PeerChannels[peerID]; ok { + closed = true + delete(p.PeerChannels, peerID) + close(channel) + } + // mbragin: todo shouldn't it be more? or configurable? + channel := make(chan *UpdateMessage, channelBufferSize) + p.PeerChannels[peerID] = channel + + log.Debugf("opened updates channel for a peer %s", peerID) + + return channel +} + +func (p *PeersUpdateManager) closeChannel(peerID string) { + if channel, ok := p.PeerChannels[peerID]; ok { + delete(p.PeerChannels, peerID) + close(channel) + } + + log.Debugf("closed updates channel of a peer %s", peerID) +} + +// CloseChannels closes updates channel for each given peer +func (p *PeersUpdateManager) CloseChannels(peerIDs []string) { + start := time.Now() + + p.channelsMux.Lock() + defer func() { + p.channelsMux.Unlock() + if p.metrics != nil { + p.metrics.UpdateChannelMetrics().CountCloseChannelsDuration(time.Since(start), len(peerIDs)) + } + }() + + for _, id := range peerIDs { + p.closeChannel(id) + } +} + +// CloseChannel closes updates channel of a given peer +func (p *PeersUpdateManager) CloseChannel(peerID string) { + start := time.Now() + + p.channelsMux.Lock() + defer func() { + p.channelsMux.Unlock() + if p.metrics != nil { + p.metrics.UpdateChannelMetrics().CountCloseChannelDuration(time.Since(start)) + } + }() + + p.closeChannel(peerID) +} + +// GetAllConnectedPeers returns a copy of the connected peers map +func (p *PeersUpdateManager) GetAllConnectedPeers() map[string]struct{} { + start := time.Now() + + p.channelsMux.Lock() + + m := make(map[string]struct{}) + + defer func() { + p.channelsMux.Unlock() + if p.metrics != nil { + p.metrics.UpdateChannelMetrics().CountGetAllConnectedPeersDuration(time.Since(start), len(m)) + } + }() + + for ID := range p.PeerChannels { + m[ID] = struct{}{} + } + + return m +} diff --git a/management/server/management_refactor/server/routes/route.go b/management/server/management_refactor/server/routes/route.go new file mode 100644 index 000000000..0db51ae52 --- /dev/null +++ b/management/server/management_refactor/server/routes/route.go @@ -0,0 +1 @@ +package routes diff --git a/management/server/management_refactor/server/scheduler.go b/management/server/management_refactor/server/scheduler.go new file mode 100644 index 000000000..95d393e7d --- /dev/null +++ b/management/server/management_refactor/server/scheduler.go @@ -0,0 +1,115 @@ +package server + +import ( + "sync" + "time" + + log "github.com/sirupsen/logrus" +) + +// Scheduler is an interface which implementations can schedule and cancel jobs +type Scheduler interface { + Cancel(IDs []string) + Schedule(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) +} + +// MockScheduler is a mock implementation of Scheduler +type MockScheduler struct { + CancelFunc func(IDs []string) + ScheduleFunc func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) +} + +// Cancel mocks the Cancel function of the Scheduler interface +func (mock *MockScheduler) Cancel(IDs []string) { + if mock.CancelFunc != nil { + mock.CancelFunc(IDs) + return + } + log.Errorf("MockScheduler doesn't have Cancel function defined ") +} + +// Schedule mocks the Schedule function of the Scheduler interface +func (mock *MockScheduler) Schedule(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) { + if mock.ScheduleFunc != nil { + mock.ScheduleFunc(in, ID, job) + return + } + log.Errorf("MockScheduler doesn't have Schedule function defined") +} + +// DefaultScheduler is a generic structure that allows to schedule jobs (functions) to run in the future and cancel them. +type DefaultScheduler struct { + // jobs map holds cancellation channels indexed by the job ID + jobs map[string]chan struct{} + mu *sync.Mutex +} + +// NewDefaultScheduler creates an instance of a DefaultScheduler +func NewDefaultScheduler() *DefaultScheduler { + return &DefaultScheduler{ + jobs: make(map[string]chan struct{}), + mu: &sync.Mutex{}, + } +} + +func (wm *DefaultScheduler) cancel(ID string) bool { + cancel, ok := wm.jobs[ID] + if ok { + delete(wm.jobs, ID) + select { + case cancel <- struct{}{}: + log.Debugf("cancelled scheduled job %s", ID) + default: + log.Warnf("couldn't cancel job %s because there was no routine listening on the cancel event", ID) + return false + } + + } + return ok +} + +// Cancel cancels the scheduled job by ID if present. +// If job wasn't found the function returns false. +func (wm *DefaultScheduler) Cancel(IDs []string) { + wm.mu.Lock() + defer wm.mu.Unlock() + + for _, id := range IDs { + wm.cancel(id) + } +} + +// Schedule a job to run in some time in the future. If job returns true then it will be scheduled one more time. +// If job with the provided ID already exists, a new one won't be scheduled. +func (wm *DefaultScheduler) Schedule(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) { + wm.mu.Lock() + defer wm.mu.Unlock() + cancel := make(chan struct{}) + if _, ok := wm.jobs[ID]; ok { + log.Debugf("couldn't schedule a job %s because it already exists. There are %d total jobs scheduled.", + ID, len(wm.jobs)) + return + } + + wm.jobs[ID] = cancel + log.Debugf("scheduled a job %s to run in %s. There are %d total jobs scheduled.", ID, in.String(), len(wm.jobs)) + go func() { + select { + case <-time.After(in): + log.Debugf("time to do a scheduled job %s", ID) + runIn, reschedule := job() + wm.mu.Lock() + defer wm.mu.Unlock() + delete(wm.jobs, ID) + if reschedule { + go wm.Schedule(runIn, ID, job) + } + case <-cancel: + log.Debugf("stopped scheduled job %s ", ID) + wm.mu.Lock() + defer wm.mu.Unlock() + delete(wm.jobs, ID) + return + } + }() +} diff --git a/management/server/management_refactor/server/setupkey/setupkey.go b/management/server/management_refactor/server/setupkey/setupkey.go new file mode 100644 index 000000000..ddd9196df --- /dev/null +++ b/management/server/management_refactor/server/setupkey/setupkey.go @@ -0,0 +1 @@ +package setupkey diff --git a/management/server/management_refactor/server/store/sqlite_store.go b/management/server/management_refactor/server/store/sqlite_store.go new file mode 100644 index 000000000..704bce590 --- /dev/null +++ b/management/server/management_refactor/server/store/sqlite_store.go @@ -0,0 +1,459 @@ +package store + +import ( + "path/filepath" + "runtime" + "strings" + "sync" + "time" + + log "github.com/sirupsen/logrus" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" + + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/route" +) + +// SqliteStore represents an account storage backed by a Sqlite DB persisted to disk +type SqliteStore struct { + db *gorm.DB + storeFile string + accountLocks sync.Map + globalAccountLock sync.Mutex + metrics telemetry.AppMetrics + installationPK int +} + +type installation struct { + ID uint `gorm:"primaryKey"` + InstallationIDValue string +} + +// NewSqliteStore restores a store from the file located in the datadir +func NewSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*SqliteStore, error) { + storeStr := "store.db?cache=shared" + if runtime.GOOS == "windows" { + // Vo avoid `The process cannot access the file because it is being used by another process` on Windows + storeStr = "store.db" + } + + file := filepath.Join(dataDir, storeStr) + db, err := gorm.Open(sqlite.Open(file), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + PrepareStmt: true, + }) + if err != nil { + return nil, err + } + + sql, err := db.DB() + if err != nil { + return nil, err + } + conns := runtime.NumCPU() + sql.SetMaxOpenConns(conns) // TODO: make it configurable + + err = db.AutoMigrate( + &SetupKey{}, &Peer{}, &User{}, &PersonalAccessToken{}, &Group{}, &Rule{}, + &Account{}, &Policy{}, &PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, + &installation{}, + ) + if err != nil { + return nil, err + } + + return &SqliteStore{db: db, storeFile: file, metrics: metrics, installationPK: 1}, nil +} + +// NewSqliteStoreFromFileStore restores a store from FileStore and stores SQLite DB in the file located in datadir +func NewSqliteStoreFromFileStore(filestore *FileStore, dataDir string, metrics telemetry.AppMetrics) (*SqliteStore, error) { + store, err := NewSqliteStore(dataDir, metrics) + if err != nil { + return nil, err + } + + err = store.SaveInstallationID(filestore.InstallationID) + if err != nil { + return nil, err + } + + for _, account := range filestore.GetAllAccounts() { + err := store.SaveAccount(account) + if err != nil { + return nil, err + } + } + + return store, nil +} + +// AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock +func (s *SqliteStore) AcquireGlobalLock() (unlock func()) { + log.Debugf("acquiring global lock") + start := time.Now() + s.globalAccountLock.Lock() + + unlock = func() { + s.globalAccountLock.Unlock() + log.Debugf("released global lock in %v", time.Since(start)) + } + + took := time.Since(start) + log.Debugf("took %v to acquire global lock", took) + if s.metrics != nil { + s.metrics.StoreMetrics().CountGlobalLockAcquisitionDuration(took) + } + + return unlock +} + +func (s *SqliteStore) AcquireAccountLock(accountID string) (unlock func()) { + log.Debugf("acquiring lock for account %s", accountID) + + start := time.Now() + value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{}) + mtx := value.(*sync.Mutex) + mtx.Lock() + + unlock = func() { + mtx.Unlock() + log.Debugf("released lock for account %s in %v", accountID, time.Since(start)) + } + + return unlock +} + +func (s *SqliteStore) SaveAccount(account *Account) error { + start := time.Now() + + for _, key := range account.SetupKeys { + account.SetupKeysG = append(account.SetupKeysG, *key) + } + + for id, peer := range account.Peers { + peer.ID = id + account.PeersG = append(account.PeersG, *peer) + } + + for id, user := range account.Users { + user.Id = id + for id, pat := range user.PATs { + pat.ID = id + user.PATsG = append(user.PATsG, *pat) + } + account.UsersG = append(account.UsersG, *user) + } + + for id, group := range account.Groups { + group.ID = id + account.GroupsG = append(account.GroupsG, *group) + } + + for id, rule := range account.Rules { + rule.ID = id + account.RulesG = append(account.RulesG, *rule) + } + + for id, route := range account.Routes { + route.ID = id + account.RoutesG = append(account.RoutesG, *route) + } + + for id, ns := range account.NameServerGroups { + ns.ID = id + account.NameServerGroupsG = append(account.NameServerGroupsG, *ns) + } + + err := s.db.Transaction(func(tx *gorm.DB) error { + result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id) + if result.Error != nil { + return result.Error + } + + result = tx.Select(clause.Associations).Delete(account.UsersG, "account_id = ?", account.Id) + if result.Error != nil { + return result.Error + } + + result = tx.Select(clause.Associations).Delete(account) + if result.Error != nil { + return result.Error + } + + result = tx. + Session(&gorm.Session{FullSaveAssociations: true}). + Clauses(clause.OnConflict{UpdateAll: true}).Create(account) + if result.Error != nil { + return result.Error + } + return nil + }) + + took := time.Since(start) + if s.metrics != nil { + s.metrics.StoreMetrics().CountPersistenceDuration(took) + } + log.Debugf("took %d ms to persist an account to the SQLite", took.Milliseconds()) + + return err +} + +func (s *SqliteStore) SaveInstallationID(ID string) error { + installation := installation{InstallationIDValue: ID} + installation.ID = uint(s.installationPK) + + return s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&installation).Error +} + +func (s *SqliteStore) GetInstallationID() string { + var installation installation + + if result := s.db.First(&installation, "id = ?", s.installationPK); result.Error != nil { + return "" + } + + return installation.InstallationIDValue +} + +func (s *SqliteStore) SavePeerStatus(accountID, peerID string, peerStatus PeerStatus) error { + var peer Peer + + result := s.db.First(&peer, "account_id = ? and id = ?", accountID, peerID) + if result.Error != nil { + return status.Errorf(status.NotFound, "peer %s not found", peerID) + } + + peer.Status = &peerStatus + + return s.db.Save(peer).Error +} + +// DeleteHashedPAT2TokenIDIndex is noop in Sqlite +func (s *SqliteStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error { + return nil +} + +// DeleteTokenID2UserIDIndex is noop in Sqlite +func (s *SqliteStore) DeleteTokenID2UserIDIndex(tokenID string) error { + return nil +} + +func (s *SqliteStore) GetAccountByPrivateDomain(domain string) (*Account, error) { + var account Account + + result := s.db.First(&account, "domain = ? and is_domain_primary_account = ? and domain_category = ?", + strings.ToLower(domain), true, PrivateCategory) + if result.Error != nil { + return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private") + } + + // TODO: rework to not call GetAccount + return s.GetAccount(account.Id) +} + +func (s *SqliteStore) GetAccountBySetupKey(setupKey string) (*Account, error) { + var key SetupKey + result := s.db.Select("account_id").First(&key, "key = ?", strings.ToUpper(setupKey)) + if result.Error != nil { + return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + } + + if key.AccountID == "" { + return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + } + + return s.GetAccount(key.AccountID) +} + +func (s *SqliteStore) GetTokenIDByHashedToken(hashedToken string) (string, error) { + var token PersonalAccessToken + result := s.db.First(&token, "hashed_token = ?", hashedToken) + if result.Error != nil { + return "", status.Errorf(status.NotFound, "account not found: index lookup failed") + } + + return token.ID, nil +} + +func (s *SqliteStore) GetUserByTokenID(tokenID string) (*User, error) { + var token PersonalAccessToken + result := s.db.First(&token, "id = ?", tokenID) + if result.Error != nil { + return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + } + + if token.UserID == "" { + return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + } + + var user User + result = s.db.Preload("PATsG").First(&user, "id = ?", token.UserID) + if result.Error != nil { + return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + } + + user.PATs = make(map[string]*PersonalAccessToken, len(user.PATsG)) + for _, pat := range user.PATsG { + user.PATs[pat.ID] = pat.Copy() + } + + return &user, nil +} + +func (s *SqliteStore) GetAllAccounts() (all []*Account) { + var accounts []Account + result := s.db.Find(&accounts) + if result.Error != nil { + return all + } + + for _, account := range accounts { + if acc, err := s.GetAccount(account.Id); err == nil { + all = append(all, acc) + } + } + + return all +} + +func (s *SqliteStore) GetAccount(accountID string) (*Account, error) { + var account Account + + result := s.db.Model(&account). + Preload("UsersG.PATsG"). // have to be specifies as this is nester reference + Preload(clause.Associations). + First(&account, "id = ?", accountID) + if result.Error != nil { + return nil, status.Errorf(status.NotFound, "account not found") + } + + // we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us + for i, policy := range account.Policies { + var rules []*PolicyRule + err := s.db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error + if err != nil { + return nil, status.Errorf(status.NotFound, "account not found") + } + account.Policies[i].Rules = rules + } + + account.SetupKeys = make(map[string]*SetupKey, len(account.SetupKeysG)) + for _, key := range account.SetupKeysG { + account.SetupKeys[key.Key] = key.Copy() + } + account.SetupKeysG = nil + + account.Peers = make(map[string]*Peer, len(account.PeersG)) + for _, peer := range account.PeersG { + account.Peers[peer.ID] = peer.Copy() + } + account.PeersG = nil + + account.Users = make(map[string]*User, len(account.UsersG)) + for _, user := range account.UsersG { + user.PATs = make(map[string]*PersonalAccessToken, len(user.PATs)) + for _, pat := range user.PATsG { + user.PATs[pat.ID] = pat.Copy() + } + account.Users[user.Id] = user.Copy() + } + account.UsersG = nil + + account.Groups = make(map[string]*Group, len(account.GroupsG)) + for _, group := range account.GroupsG { + account.Groups[group.ID] = group.Copy() + } + account.GroupsG = nil + + account.Rules = make(map[string]*Rule, len(account.RulesG)) + for _, rule := range account.RulesG { + account.Rules[rule.ID] = rule.Copy() + } + account.RulesG = nil + + account.Routes = make(map[string]*route.Route, len(account.RoutesG)) + for _, route := range account.RoutesG { + account.Routes[route.ID] = route.Copy() + } + account.RoutesG = nil + + account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) + for _, ns := range account.NameServerGroupsG { + account.NameServerGroups[ns.ID] = ns.Copy() + } + account.NameServerGroupsG = nil + + return &account, nil +} + +func (s *SqliteStore) GetAccountByUser(userID string) (*Account, error) { + var user User + result := s.db.Select("account_id").First(&user, "id = ?", userID) + if result.Error != nil { + return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + } + + if user.AccountID == "" { + return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + } + + return s.GetAccount(user.AccountID) +} + +func (s *SqliteStore) GetAccountByPeerID(peerID string) (*Account, error) { + var peer Peer + result := s.db.Select("account_id").First(&peer, "id = ?", peerID) + if result.Error != nil { + return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + } + + if peer.AccountID == "" { + return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + } + + return s.GetAccount(peer.AccountID) +} + +func (s *SqliteStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) { + var peer Peer + + result := s.db.Select("account_id").First(&peer, "key = ?", peerKey) + if result.Error != nil { + return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + } + + if peer.AccountID == "" { + return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + } + + return s.GetAccount(peer.AccountID) +} + +// SaveUserLastLogin stores the last login time for a user in DB. +func (s *SqliteStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error { + var user User + + result := s.db.First(&user, "account_id = ? and id = ?", accountID, userID) + if result.Error != nil { + return status.Errorf(status.NotFound, "user %s not found", userID) + } + + user.LastLogin = lastLogin + + return s.db.Save(user).Error +} + +// Close is noop in Sqlite +func (s *SqliteStore) Close() error { + return nil +} + +// GetStoreEngine returns SqliteStoreEngine +func (s *SqliteStore) GetStoreEngine() StoreEngine { + return SqliteStoreEngine +} diff --git a/management/server/management_refactor/server/store/store.go b/management/server/management_refactor/server/store/store.go new file mode 100644 index 000000000..4583e9f00 --- /dev/null +++ b/management/server/management_refactor/server/store/store.go @@ -0,0 +1,98 @@ +package store + +import ( + "fmt" + "os" + "strings" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/server/telemetry" +) + +type Store interface { + GetAllAccounts() []*Account + GetAccount(accountID string) (*Account, error) + GetAccountByUser(userID string) (*Account, error) + GetAccountByPeerPubKey(peerKey string) (*Account, error) + GetAccountByPeerID(peerID string) (*Account, error) + GetAccountBySetupKey(setupKey string) (*Account, error) // todo use key hash later + GetAccountByPrivateDomain(domain string) (*Account, error) + GetTokenIDByHashedToken(secret string) (string, error) + GetUserByTokenID(tokenID string) (*User, error) + SaveAccount(account *Account) error + DeleteHashedPAT2TokenIDIndex(hashedToken string) error + DeleteTokenID2UserIDIndex(tokenID string) error + GetInstallationID() string + SaveInstallationID(ID string) error + // AcquireAccountLock should attempt to acquire account lock and return a function that releases the lock + AcquireAccountLock(accountID string) func() + // AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock + AcquireGlobalLock() func() + SavePeerStatus(accountID, peerID string, status PeerStatus) error + SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error + // Close should close the store persisting all unsaved data. + Close() error + // GetStoreEngine should return StoreEngine of the current store implementation. + // This is also a method of metrics.DataSource interface. + GetStoreEngine() StoreEngine +} + +type StoreEngine string + +const ( + FileStoreEngine StoreEngine = "jsonfile" + SqliteStoreEngine StoreEngine = "sqlite" +) + +func getStoreEngineFromEnv() StoreEngine { + // NETBIRD_STORE_ENGINE supposed to be used in tests. Otherwise rely on the config file. + kind, ok := os.LookupEnv("NETBIRD_STORE_ENGINE") + if !ok { + return FileStoreEngine + } + + value := StoreEngine(strings.ToLower(kind)) + + if value == FileStoreEngine || value == SqliteStoreEngine { + return value + } + + return FileStoreEngine +} + +func NewStore(kind StoreEngine, dataDir string, metrics telemetry.AppMetrics) (Store, error) { + if kind == "" { + // fallback to env. Normally this only should be used from tests + kind = getStoreEngineFromEnv() + } + switch kind { + case FileStoreEngine: + log.Info("using JSON file store engine") + return NewFileStore(dataDir, metrics) + case SqliteStoreEngine: + log.Info("using SQLite store engine") + return NewSqliteStore(dataDir, metrics) + default: + return nil, fmt.Errorf("unsupported kind of store %s", kind) + } +} + +func NewStoreFromJson(dataDir string, metrics telemetry.AppMetrics) (Store, error) { + fstore, err := NewFileStore(dataDir, nil) + if err != nil { + return nil, err + } + + kind := getStoreEngineFromEnv() + + switch kind { + case FileStoreEngine: + return fstore, nil + case SqliteStoreEngine: + return NewSqliteStoreFromFileStore(fstore, dataDir, metrics) + default: + return nil, fmt.Errorf("unsupported store engine %s", kind) + } +} diff --git a/management/server/management_refactor/server/turncredentials.go b/management/server/management_refactor/server/turncredentials.go new file mode 100644 index 000000000..aedcf2ee1 --- /dev/null +++ b/management/server/management_refactor/server/turncredentials.go @@ -0,0 +1,125 @@ +package server + +import ( + "crypto/hmac" + "crypto/sha1" + "encoding/base64" + "fmt" + "sync" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/proto" +) + +// TURNCredentialsManager used to manage TURN credentials +type TURNCredentialsManager interface { + GenerateCredentials() TURNCredentials + SetupRefresh(peerKey string) + CancelRefresh(peerKey string) +} + +// TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server +type TimeBasedAuthSecretsManager struct { + mux sync.Mutex + config *TURNConfig + updateManager *PeersUpdateManager + cancelMap map[string]chan struct{} +} + +type TURNCredentials struct { + Username string + Password string +} + +func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, config *TURNConfig) *TimeBasedAuthSecretsManager { + return &TimeBasedAuthSecretsManager{ + mux: sync.Mutex{}, + config: config, + updateManager: updateManager, + cancelMap: make(map[string]chan struct{}), + } +} + +// GenerateCredentials generates new time-based secret credentials - basically username is a unix timestamp and password is a HMAC hash of a timestamp with a preshared TURN secret +func (m *TimeBasedAuthSecretsManager) GenerateCredentials() TURNCredentials { + mac := hmac.New(sha1.New, []byte(m.config.Secret)) + + timeAuth := time.Now().Add(m.config.CredentialsTTL.Duration).Unix() + + username := fmt.Sprint(timeAuth) + + _, err := mac.Write([]byte(username)) + if err != nil { + log.Errorln("Generating turn password failed with error: ", err) + } + + bytePassword := mac.Sum(nil) + password := base64.StdEncoding.EncodeToString(bytePassword) + + return TURNCredentials{ + Username: username, + Password: password, + } + +} + +func (m *TimeBasedAuthSecretsManager) cancel(peerID string) { + if channel, ok := m.cancelMap[peerID]; ok { + close(channel) + delete(m.cancelMap, peerID) + } +} + +// CancelRefresh cancels scheduled peer credentials refresh +func (m *TimeBasedAuthSecretsManager) CancelRefresh(peerID string) { + m.mux.Lock() + defer m.mux.Unlock() + m.cancel(peerID) +} + +// SetupRefresh starts peer credentials refresh. Since credentials are expiring (TTL) it is necessary to always generate them and send to the peer. +// A goroutine is created and put into TimeBasedAuthSecretsManager.cancelMap. This routine should be cancelled if peer is gone. +func (m *TimeBasedAuthSecretsManager) SetupRefresh(peerID string) { + m.mux.Lock() + defer m.mux.Unlock() + m.cancel(peerID) + cancel := make(chan struct{}, 1) + m.cancelMap[peerID] = cancel + log.Debugf("starting turn refresh for %s", peerID) + + go func() { + // we don't want to regenerate credentials right on expiration, so we do it slightly before (at 3/4 of TTL) + ticker := time.NewTicker(m.config.CredentialsTTL.Duration / 4 * 3) + + for { + select { + case <-cancel: + log.Debugf("stopping turn refresh for %s", peerID) + return + case <-ticker.C: + c := m.GenerateCredentials() + var turns []*proto.ProtectedHostConfig + for _, host := range m.config.Turns { + turns = append(turns, &proto.ProtectedHostConfig{ + HostConfig: &proto.HostConfig{ + Uri: host.URI, + Protocol: ToResponseProto(host.Proto), + }, + User: c.Username, + Password: c.Password, + }) + } + + update := &proto.SyncResponse{ + WiretrusteeConfig: &proto.WiretrusteeConfig{ + Turns: turns, + }, + } + log.Debugf("sending new TURN credentials to peer %s", peerID) + m.updateManager.SendUpdate(peerID, &UpdateMessage{Update: update}) + } + } + }() +} diff --git a/management/server/management_refactor/server/users/personal_access_token.go b/management/server/management_refactor/server/users/personal_access_token.go new file mode 100644 index 000000000..de28d56da --- /dev/null +++ b/management/server/management_refactor/server/users/personal_access_token.go @@ -0,0 +1,95 @@ +package users + +import ( + "crypto/sha256" + b64 "encoding/base64" + "fmt" + "hash/crc32" + "time" + + b "github.com/hashicorp/go-secure-stdlib/base62" + "github.com/rs/xid" + + "github.com/netbirdio/netbird/base62" +) + +const ( + // PATPrefix is the globally used, 4 char prefix for personal access tokens + PATPrefix = "nbp_" + // PATSecretLength number of characters used for the secret inside the token + PATSecretLength = 30 + // PATChecksumLength number of characters used for the encoded checksum of the secret inside the token + PATChecksumLength = 6 + // PATLength total number of characters used for the token + PATLength = 40 +) + +// PersonalAccessToken holds all information about a PAT including a hashed version of it for verification +type PersonalAccessToken struct { + ID string `gorm:"primaryKey"` + // User is a reference to Account that this object belongs + UserID string `gorm:"index"` + Name string + HashedToken string + ExpirationDate time.Time + // scope could be added in future + CreatedBy string + CreatedAt time.Time + LastUsed time.Time +} + +func (t *PersonalAccessToken) Copy() *PersonalAccessToken { + return &PersonalAccessToken{ + ID: t.ID, + Name: t.Name, + HashedToken: t.HashedToken, + ExpirationDate: t.ExpirationDate, + CreatedBy: t.CreatedBy, + CreatedAt: t.CreatedAt, + LastUsed: t.LastUsed, + } +} + +// PersonalAccessTokenGenerated holds the new PersonalAccessToken and the plain text version of it +type PersonalAccessTokenGenerated struct { + PlainToken string + PersonalAccessToken +} + +// CreateNewPAT will generate a new PersonalAccessToken that can be assigned to a User. +// Additionally, it will return the token in plain text once, to give to the user and only save a hashed version +func CreateNewPAT(name string, expirationInDays int, createdBy string) (*PersonalAccessTokenGenerated, error) { + hashedToken, plainToken, err := generateNewToken() + if err != nil { + return nil, err + } + currentTime := time.Now() + return &PersonalAccessTokenGenerated{ + PersonalAccessToken: PersonalAccessToken{ + ID: xid.New().String(), + Name: name, + HashedToken: hashedToken, + ExpirationDate: currentTime.AddDate(0, 0, expirationInDays), + CreatedBy: createdBy, + CreatedAt: currentTime, + LastUsed: time.Time{}, + }, + PlainToken: plainToken, + }, nil + +} + +func generateNewToken() (string, string, error) { + secret, err := b.Random(PATSecretLength) + if err != nil { + return "", "", err + } + + checksum := crc32.ChecksumIEEE([]byte(secret)) + encodedChecksum := base62.Encode(checksum) + paddedChecksum := fmt.Sprintf("%06s", encodedChecksum) + plainToken := PATPrefix + secret + paddedChecksum + hashedToken := sha256.Sum256([]byte(plainToken)) + encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) + return encodedHashedToken, plainToken, nil +} diff --git a/management/server/management_refactor/server/users/user.go b/management/server/management_refactor/server/users/user.go new file mode 100644 index 000000000..8e244300b --- /dev/null +++ b/management/server/management_refactor/server/users/user.go @@ -0,0 +1,203 @@ +package users + +import ( + "fmt" + "strings" + "time" + + "github.com/netbirdio/netbird/management/server/idp" +) + +const ( + UserRoleAdmin UserRole = "admin" + UserRoleUser UserRole = "user" + UserRoleUnknown UserRole = "unknown" + + UserStatusActive UserStatus = "active" + UserStatusDisabled UserStatus = "disabled" + UserStatusInvited UserStatus = "invited" + + UserIssuedAPI = "api" + UserIssuedIntegration = "integration" +) + +type UserInfo struct { + ID string `json:"id"` + Email string `json:"email"` + Name string `json:"name"` + Role string `json:"role"` + AutoGroups []string `json:"auto_groups"` + Status string `json:"-"` + IsServiceUser bool `json:"is_service_user"` + IsBlocked bool `json:"is_blocked"` + NonDeletable bool `json:"non_deletable"` + LastLogin time.Time `json:"last_login"` + Issued string `json:"issued"` + IntegrationReference IntegrationReference `json:"-"` +} + +// StrRoleToUserRole returns UserRole for a given strRole or UserRoleUnknown if the specified role is unknown +func StrRoleToUserRole(strRole string) UserRole { + switch strings.ToLower(strRole) { + case "admin": + return UserRoleAdmin + case "user": + return UserRoleUser + default: + return UserRoleUnknown + } +} + +// UserStatus is the status of a User +type UserStatus string + +// UserRole is the role of a User +type UserRole string + +// IntegrationReference holds the reference to a particular integration +type IntegrationReference struct { + ID int + IntegrationType string +} + +func (ir IntegrationReference) String() string { + return fmt.Sprintf("%s:%d", ir.IntegrationType, ir.ID) +} + +func (ir IntegrationReference) CacheKey(path ...string) string { + if len(path) == 0 { + return ir.String() + } + return fmt.Sprintf("%s:%s", ir.String(), strings.Join(path, ":")) +} + +// User represents a user of the system +type User struct { + Id string `gorm:"primaryKey"` + // AccountID is a reference to Account that this object belongs + AccountID string `json:"-" gorm:"index"` + Role UserRole + IsServiceUser bool + // NonDeletable indicates whether the service user can be deleted + NonDeletable bool + // ServiceUserName is only set if IsServiceUser is true + ServiceUserName string + // AutoGroups is a list of Group IDs to auto-assign to peers registered by this user + AutoGroups []string `gorm:"serializer:json"` + PATs map[string]*PersonalAccessToken `gorm:"-"` + PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id"` + // Blocked indicates whether the user is blocked. Blocked users can't use the system. + Blocked bool + // LastLogin is the last time the user logged in to IdP + LastLogin time.Time + + // Issued of the user + Issued string `gorm:"default:api"` + + IntegrationReference IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"` +} + +// IsBlocked returns true if the user is blocked, false otherwise +func (u *User) IsBlocked() bool { + return u.Blocked +} + +func (u *User) LastDashboardLoginChanged(LastLogin time.Time) bool { + return LastLogin.After(u.LastLogin) && !u.LastLogin.IsZero() +} + +// IsAdmin returns true if the user is an admin, false otherwise +func (u *User) IsAdmin() bool { + return u.Role == UserRoleAdmin +} + +// ToUserInfo converts a User object to a UserInfo object. +func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) { + autoGroups := u.AutoGroups + if autoGroups == nil { + autoGroups = []string{} + } + + if userData == nil { + return &UserInfo{ + ID: u.Id, + Email: "", + Name: u.ServiceUserName, + Role: string(u.Role), + AutoGroups: u.AutoGroups, + Status: string(UserStatusActive), + IsServiceUser: u.IsServiceUser, + IsBlocked: u.Blocked, + LastLogin: u.LastLogin, + Issued: u.Issued, + }, nil + } + if userData.ID != u.Id { + return nil, fmt.Errorf("wrong UserData provided for user %s", u.Id) + } + + userStatus := UserStatusActive + if userData.AppMetadata.WTPendingInvite != nil && *userData.AppMetadata.WTPendingInvite { + userStatus = UserStatusInvited + } + + return &UserInfo{ + ID: u.Id, + Email: userData.Email, + Name: userData.Name, + Role: string(u.Role), + AutoGroups: autoGroups, + Status: string(userStatus), + IsServiceUser: u.IsServiceUser, + IsBlocked: u.Blocked, + LastLogin: u.LastLogin, + Issued: u.Issued, + }, nil +} + +// Copy the user +func (u *User) Copy() *User { + autoGroups := make([]string, len(u.AutoGroups)) + copy(autoGroups, u.AutoGroups) + pats := make(map[string]*PersonalAccessToken, len(u.PATs)) + for k, v := range u.PATs { + pats[k] = v.Copy() + } + return &User{ + Id: u.Id, + AccountID: u.AccountID, + Role: u.Role, + AutoGroups: autoGroups, + IsServiceUser: u.IsServiceUser, + NonDeletable: u.NonDeletable, + ServiceUserName: u.ServiceUserName, + PATs: pats, + Blocked: u.Blocked, + LastLogin: u.LastLogin, + Issued: u.Issued, + IntegrationReference: u.IntegrationReference, + } +} + +// NewUser creates a new user +func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string) *User { + return &User{ + Id: id, + Role: role, + IsServiceUser: isServiceUser, + NonDeletable: nonDeletable, + ServiceUserName: serviceUserName, + AutoGroups: autoGroups, + Issued: issued, + } +} + +// NewRegularUser creates a new user with role UserRoleUser +func NewRegularUser(id string) *User { + return NewUser(id, UserRoleUser, false, false, "", []string{}, UserIssuedAPI) +} + +// NewAdminUser creates a new user with role UserRoleAdmin +func NewAdminUser(id string) *User { + return NewUser(id, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI) +} diff --git a/management/server/management_refactor/server/users/user_manager.go b/management/server/management_refactor/server/users/user_manager.go new file mode 100644 index 000000000..49a95c19b --- /dev/null +++ b/management/server/management_refactor/server/users/user_manager.go @@ -0,0 +1,19 @@ +package users + +type UserManager interface { + GetUser(userID string) (User, error) +} + +type DefaultUserManager struct { + repository UserRepository +} + +func NewUserManager(repository UserRepository) *DefaultUserManager { + return &DefaultUserManager{ + repository: repository, + } +} + +func (um *DefaultUserManager) GetUser(userID string) (User, error) { + return um.repository.findUserByID(userID) +} diff --git a/management/server/management_refactor/server/users/user_repository.go b/management/server/management_refactor/server/users/user_repository.go new file mode 100644 index 000000000..b10b5c5bc --- /dev/null +++ b/management/server/management_refactor/server/users/user_repository.go @@ -0,0 +1,5 @@ +package users + +type UserRepository interface { + findUserByID(userID string) (User, error) +}