starting refactor

This commit is contained in:
Pascal Fischer 2023-11-23 18:40:03 +01:00
parent 96cdcf8e49
commit 70c24384ea
30 changed files with 3938 additions and 0 deletions

View File

@ -0,0 +1,7 @@
package access_control
type AccessControlManager interface {
}
type DefaultAccessControlManager struct {
}

View File

@ -0,0 +1 @@
package access_control

View File

@ -0,0 +1,100 @@
package access_control
import "fmt"
// TrafficFlowType defines allowed direction of the traffic in the rule
type TrafficFlowType int
const (
// TrafficFlowBidirect allows traffic to both direction
TrafficFlowBidirect TrafficFlowType = iota
// TrafficFlowBidirectString allows traffic to both direction
TrafficFlowBidirectString = "bidirect"
// DefaultRuleName is a name for the Default rule that is created for every account
DefaultRuleName = "Default"
// DefaultRuleDescription is a description for the Default rule that is created for every account
DefaultRuleDescription = "This is a default rule that allows connections between all the resources"
// DefaultPolicyName is a name for the Default policy that is created for every account
DefaultPolicyName = "Default"
// DefaultPolicyDescription is a description for the Default policy that is created for every account
DefaultPolicyDescription = "This is a default policy that allows connections between all the resources"
)
// Rule of ACL for groups
type Rule struct {
// ID of the rule
ID string
// AccountID is a reference to Account that this object belongs
AccountID string `json:"-" gorm:"index"`
// Name of the rule visible in the UI
Name string
// Description of the rule visible in the UI
Description string
// Disabled status of rule in the system
Disabled bool
// Source list of groups IDs of peers
Source []string `gorm:"serializer:json"`
// Destination list of groups IDs of peers
Destination []string `gorm:"serializer:json"`
// Flow of the traffic allowed by the rule
Flow TrafficFlowType
}
func (r *Rule) Copy() *Rule {
rule := &Rule{
ID: r.ID,
Name: r.Name,
Description: r.Description,
Disabled: r.Disabled,
Source: make([]string, len(r.Source)),
Destination: make([]string, len(r.Destination)),
Flow: r.Flow,
}
copy(rule.Source, r.Source)
copy(rule.Destination, r.Destination)
return rule
}
// EventMeta returns activity event meta related to this rule
func (r *Rule) EventMeta() map[string]any {
return map[string]any{"name": r.Name}
}
// ToPolicyRule converts a Rule to a PolicyRule object
func (r *Rule) ToPolicyRule() *PolicyRule {
if r == nil {
return nil
}
return &PolicyRule{
ID: r.ID,
Name: r.Name,
Enabled: !r.Disabled,
Description: r.Description,
Destinations: r.Destination,
Sources: r.Source,
Bidirectional: true,
Protocol: PolicyRuleProtocolALL,
Action: PolicyTrafficActionAccept,
}
}
// RuleToPolicy converts a Rule to a Policy query object
func RuleToPolicy(rule *Rule) (*Policy, error) {
if rule == nil {
return nil, fmt.Errorf("rule is empty")
}
return &Policy{
ID: rule.ID,
Name: rule.Name,
Description: rule.Description,
Enabled: !rule.Disabled,
Rules: []*PolicyRule{rule.ToPolicyRule()},
}, nil
}

View File

@ -0,0 +1,69 @@
package accounts
import (
"time"
nbdns "github.com/netbirdio/netbird/dns"
)
// Settings represents Account settings structure that can be modified via API and Dashboard
type Settings struct {
// PeerLoginExpirationEnabled globally enables or disables peer login expiration
PeerLoginExpirationEnabled bool
// PeerLoginExpiration is a setting that indicates when peer login expires.
// Applies to all peers that have Peer.LoginExpirationEnabled set to true.
PeerLoginExpiration time.Duration
// GroupsPropagationEnabled allows to propagate auto groups from the user to the peer
GroupsPropagationEnabled bool
// JWTGroupsEnabled allows extract groups from JWT claim, which name defined in the JWTGroupsClaimName
// and add it to account groups.
JWTGroupsEnabled bool
// JWTGroupsClaimName from which we extract groups name to add it to account groups
JWTGroupsClaimName string
}
// Copy copies the Settings struct
func (s *Settings) Copy() *Settings {
return &Settings{
PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled,
PeerLoginExpiration: s.PeerLoginExpiration,
JWTGroupsEnabled: s.JWTGroupsEnabled,
JWTGroupsClaimName: s.JWTGroupsClaimName,
GroupsPropagationEnabled: s.GroupsPropagationEnabled,
}
}
// Account represents a unique account of the system
type Account struct {
// we have to name column to aid as it collides with Network.Id when work with associations
Id string `gorm:"primaryKey"`
// User.Id it was created by
CreatedBy string
Domain string `gorm:"index"`
DomainCategory string
IsDomainPrimaryAccount bool
SetupKeys map[string]*SetupKey `gorm:"-"`
SetupKeysG []SetupKey `json:"-" gorm:"foreignKey:AccountID;references:id"`
Network *Network `gorm:"embedded;embeddedPrefix:network_"`
Peers map[string]*Peer `gorm:"-"`
PeersG []Peer `json:"-" gorm:"foreignKey:AccountID;references:id"`
Users map[string]*User `gorm:"-"`
UsersG []User `json:"-" gorm:"foreignKey:AccountID;references:id"`
Groups map[string]*Group `gorm:"-"`
GroupsG []Group `json:"-" gorm:"foreignKey:AccountID;references:id"`
Rules map[string]*Rule `gorm:"-"`
RulesG []Rule `json:"-" gorm:"foreignKey:AccountID;references:id"`
Policies []*Policy `gorm:"foreignKey:AccountID;references:id"`
Routes map[string]*route.Route `gorm:"-"`
RoutesG []route.Route `json:"-" gorm:"foreignKey:AccountID;references:id"`
NameServerGroups map[string]*nbdns.NameServerGroup `gorm:"-"`
NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"`
DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"`
// Settings is a dictionary of Account settings
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
}

View File

@ -0,0 +1,22 @@
package accounts
type AccountManager interface {
GetAccount(accountID string) (Account, error)
GetDNSDomain() string
}
type DefaultAccountManager struct {
repository AccountRepository
// dnsDomain is used for peer resolution. This is appended to the peer's name
dnsDomain string
}
func (am *DefaultAccountManager) GetAccount(accountID string) (Account, error) {
return am.repository.findAccountByID(accountID)
}
// GetDNSDomain returns the configured dnsDomain
func (am *DefaultAccountManager) GetDNSDomain() string {
return am.dnsDomain
}

View File

@ -0,0 +1,5 @@
package accounts
type AccountRepository interface {
findAccountByID(accountID string) (Account, error)
}

View File

@ -0,0 +1,150 @@
package server
import (
"net/url"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/util"
)
type (
// Protocol type
Protocol string
// Provider authorization flow type
Provider string
)
const (
UDP Protocol = "udp"
DTLS Protocol = "dtls"
TCP Protocol = "tcp"
HTTP Protocol = "http"
HTTPS Protocol = "https"
NONE Provider = "none"
)
const (
// DefaultDeviceAuthFlowScope defines the bare minimum scope to request in the device authorization flow
DefaultDeviceAuthFlowScope string = "openid"
)
// Config of the Management service
type Config struct {
Stuns []*Host
TURNConfig *TURNConfig
Signal *Host
Datadir string
DataStoreEncryptionKey string
HttpConfig *HttpServerConfig
IdpManagerConfig *idp.Config
DeviceAuthorizationFlow *DeviceAuthorizationFlow
PKCEAuthorizationFlow *PKCEAuthorizationFlow
StoreConfig StoreConfig
}
// GetAuthAudiences returns the audience from the http config and device authorization flow config
func (c Config) GetAuthAudiences() []string {
audiences := []string{c.HttpConfig.AuthAudience}
if c.DeviceAuthorizationFlow != nil && c.DeviceAuthorizationFlow.ProviderConfig.Audience != "" {
audiences = append(audiences, c.DeviceAuthorizationFlow.ProviderConfig.Audience)
}
return audiences
}
// TURNConfig is a config of the TURNCredentialsManager
type TURNConfig struct {
TimeBasedCredentials bool
CredentialsTTL util.Duration
Secret string
Turns []*Host
}
// HttpServerConfig is a config of the HTTP Management service server
type HttpServerConfig struct {
LetsEncryptDomain string
// CertFile is the location of the certificate
CertFile string
// CertKey is the location of the certificate private key
CertKey string
// AuthAudience identifies the recipients that the JWT is intended for (aud in JWT)
AuthAudience string
// AuthIssuer identifies principal that issued the JWT
AuthIssuer string
// AuthUserIDClaim is the name of the claim that used as user ID
AuthUserIDClaim string
// AuthKeysLocation is a location of JWT key set containing the public keys used to verify JWT
AuthKeysLocation string
// OIDCConfigEndpoint is the endpoint of an IDP manager to get OIDC configuration
OIDCConfigEndpoint string
// IdpSignKeyRefreshEnabled identifies the signing key is currently being rotated or not
IdpSignKeyRefreshEnabled bool
}
// Host represents a Wiretrustee host (e.g. STUN, TURN, Signal)
type Host struct {
Proto Protocol
// URI e.g. turns://stun.wiretrustee.com:4430 or signal.wiretrustee.com:10000
URI string
Username string
Password string
}
// DeviceAuthorizationFlow represents Device Authorization Flow information
// that can be used by the client to login initiate a Oauth 2.0 device authorization grant flow
// see https://datatracker.ietf.org/doc/html/rfc8628
type DeviceAuthorizationFlow struct {
Provider string
ProviderConfig ProviderConfig
}
// PKCEAuthorizationFlow represents Authorization Code Flow information
// that can be used by the client to login initiate a Oauth 2.0 authorization code grant flow
// with Proof Key for Code Exchange (PKCE). See https://datatracker.ietf.org/doc/html/rfc7636
type PKCEAuthorizationFlow struct {
ProviderConfig ProviderConfig
}
// ProviderConfig has all attributes needed to initiate a device/pkce authorization flow
type ProviderConfig struct {
// ClientID An IDP application client id
ClientID string
// ClientSecret An IDP application client secret
ClientSecret string
// Domain An IDP API domain
// Deprecated. Use TokenEndpoint and DeviceAuthEndpoint
Domain string
// Audience An Audience for to authorization validation
Audience string
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
TokenEndpoint string
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
DeviceAuthEndpoint string
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
AuthorizationEndpoint string
// Scopes provides the scopes to be included in the token request
Scope string
// UseIDToken indicates if the id token should be used for authentication
UseIDToken bool
// RedirectURL handles authorization code from IDP manager
RedirectURLs []string
}
// StoreConfig contains Store configuration
type StoreConfig struct {
Engine StoreEngine
}
// validateURL validates input http url
func validateURL(httpURL string) bool {
_, err := url.ParseRequestURI(httpURL)
return err == nil
}

View File

@ -0,0 +1 @@
package dns

View File

@ -0,0 +1,67 @@
package events
import (
"fmt"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity"
)
type EventsManager struct {
store activity.Store
}
func NewEventsManager(store activity.Store) *EventsManager {
return &EventsManager{
store: store,
}
}
// GetEvents returns a list of activity events of an account
func (em *EventsManager) GetEvents(accountID, userID string) ([]*activity.Event, error) {
events, err := em.store.Get(accountID, 0, 10000, true)
if err != nil {
return nil, err
}
// this is a workaround for duplicate activity.UserJoined events that might occur when a user redeems invite.
// we will need to find a better way to handle this.
filtered := make([]*activity.Event, 0)
dups := make(map[string]struct{})
for _, event := range events {
if event.Activity == activity.UserJoined {
key := event.TargetID + event.InitiatorID + event.AccountID + fmt.Sprint(event.Activity)
_, duplicate := dups[key]
if duplicate {
continue
} else {
dups[key] = struct{}{}
}
}
filtered = append(filtered, event)
}
return filtered, nil
}
func (em *EventsManager) StoreEvent(initiatorID, targetID, accountID string, activityID activity.Activity,
meta map[string]any) {
go func() {
_, err := em.store.Save(&activity.Event{
Timestamp: time.Now().UTC(),
Activity: activityID,
InitiatorID: initiatorID,
TargetID: targetID,
AccountID: accountID,
Meta: meta,
})
if err != nil {
// todo add metric
log.Errorf("received an error while storing an activity event, error: %s", err)
}
}()
}

View File

@ -0,0 +1,76 @@
package events
// import (
// "testing"
// "time"
//
// "github.com/stretchr/testify/assert"
//
// "github.com/netbirdio/netbird/management/server/activity"
// )
//
// func generateAndStoreEvents(t *testing.T, manager *EventsManager, typ activity.Activity, initiatorID, targetID,
// accountID string, count int) {
// t.Helper()
// for i := 0; i < count; i++ {
// _, err := manager.store.Save(&activity.Event{
// Timestamp: time.Now().UTC(),
// Activity: typ,
// InitiatorID: initiatorID,
// TargetID: targetID,
// AccountID: accountID,
// })
// if err != nil {
// t.Fatal(err)
// }
// }
// }
//
// func TestDefaultAccountManager_GetEvents(t *testing.T) {
// manager, err := createManager(t)
// if err != nil {
// return
// }
//
// accountID := "accountID"
//
// t.Run("get empty events list", func(t *testing.T) {
// events, err := manager.GetEvents(accountID, userID)
// if err != nil {
// return
// }
// assert.Len(t, events, 0)
// _ = manager.eventStore.Close() //nolint
// })
//
// t.Run("get events", func(t *testing.T) {
// generateAndStoreEvents(t, manager, activity.PeerAddedByUser, userID, "peer", accountID, 10)
// events, err := manager.GetEvents(accountID, userID)
// if err != nil {
// return
// }
//
// assert.Len(t, events, 10)
// _ = manager.eventStore.Close() //nolint
// })
//
// t.Run("get events without duplicates", func(t *testing.T) {
// generateAndStoreEvents(t, manager, activity.UserJoined, userID, "", accountID, 10)
// events, err := manager.GetEvents(accountID, userID)
// if err != nil {
// return
// }
// assert.Len(t, events, 1)
// _ = manager.eventStore.Close() //nolint
// })
// }
//
// func createManager(t *testing.T) (*EventsManager, error) {
// t.Helper()
// store, err := createStore(t)
// if err != nil {
// return nil, err
// }
// eventStore := &activity.InMemoryEventStore{}
// return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, false)
// }

View File

@ -0,0 +1,333 @@
package groups
import (
"fmt"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/status"
)
type GroupLinkError struct {
Resource string
Name string
}
func (e *GroupLinkError) Error() string {
return fmt.Sprintf("group has been linked to %s: %s", e.Resource, e.Name)
}
// Group of the peers for ACL
type Group struct {
// ID of the group
ID string
// AccountID is a reference to Account that this object belongs
AccountID string `json:"-" gorm:"index"`
// Name visible in the UI
Name string
// Issued of the group
Issued string
// Peers list of the group
Peers []string `gorm:"serializer:json"`
IntegrationReference IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"`
}
// EventMeta returns activity event meta related to the group
func (g *Group) EventMeta() map[string]any {
return map[string]any{"name": g.Name}
}
func (g *Group) Copy() *Group {
group := &Group{
ID: g.ID,
Name: g.Name,
Issued: g.Issued,
Peers: make([]string, len(g.Peers)),
IntegrationReference: g.IntegrationReference,
}
copy(group.Peers, g.Peers)
return group
}
// GetGroup object of the peers
func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, err
}
group, ok := account.Groups[groupID]
if ok {
return group, nil
}
return nil, status.Errorf(status.NotFound, "group with ID %s not found", groupID)
}
// SaveGroup object of the peers
func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *Group) error {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return err
}
oldGroup, exists := account.Groups[newGroup.ID]
account.Groups[newGroup.ID] = newGroup
account.Network.IncSerial()
if err = am.Store.SaveAccount(account); err != nil {
return err
}
am.updateAccountPeers(account)
// the following snippet tracks the activity and stores the group events in the event store.
// It has to happen after all the operations have been successfully performed.
addedPeers := make([]string, 0)
removedPeers := make([]string, 0)
if exists {
addedPeers = difference(newGroup.Peers, oldGroup.Peers)
removedPeers = difference(oldGroup.Peers, newGroup.Peers)
} else {
addedPeers = append(addedPeers, newGroup.Peers...)
am.StoreEvent(userID, newGroup.ID, accountID, activity.GroupCreated, newGroup.EventMeta())
}
for _, p := range addedPeers {
peer := account.Peers[p]
if peer == nil {
log.Errorf("peer %s not found under account %s while saving group", p, accountID)
continue
}
am.StoreEvent(userID, peer.ID, accountID, activity.GroupAddedToPeer,
map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peer.IP.String(),
"peer_fqdn": peer.FQDN(am.GetDNSDomain()),
})
}
for _, p := range removedPeers {
peer := account.Peers[p]
if peer == nil {
log.Errorf("peer %s not found under account %s while saving group", p, accountID)
continue
}
am.StoreEvent(userID, peer.ID, accountID, activity.GroupRemovedFromPeer,
map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peer.IP.String(),
"peer_fqdn": peer.FQDN(am.GetDNSDomain()),
})
}
return nil
}
// difference returns the elements in `a` that aren't in `b`.
func difference(a, b []string) []string {
mb := make(map[string]struct{}, len(b))
for _, x := range b {
mb[x] = struct{}{}
}
var diff []string
for _, x := range a {
if _, found := mb[x]; !found {
diff = append(diff, x)
}
}
return diff
}
// DeleteGroup object of the peers
func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) error {
unlock := am.Store.AcquireAccountLock(accountId)
defer unlock()
account, err := am.Store.GetAccount(accountId)
if err != nil {
return err
}
g, ok := account.Groups[groupID]
if !ok {
return nil
}
// disable a deleting integration group if the initiator is not an admin service user
if g.Issued == GroupIssuedIntegration {
executingUser := account.Users[userId]
if executingUser == nil {
return status.Errorf(status.NotFound, "user not found")
}
if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser {
return status.Errorf(status.PermissionDenied, "only admins service user can delete integration group")
}
}
// check route links
for _, r := range account.Routes {
for _, g := range r.Groups {
if g == groupID {
return &GroupLinkError{"route", r.NetID}
}
}
}
// check DNS links
for _, dns := range account.NameServerGroups {
for _, g := range dns.Groups {
if g == groupID {
return &GroupLinkError{"name server groups", dns.Name}
}
}
}
// check ACL links
for _, policy := range account.Policies {
for _, rule := range policy.Rules {
for _, src := range rule.Sources {
if src == groupID {
return &GroupLinkError{"policy", policy.Name}
}
}
for _, dst := range rule.Destinations {
if dst == groupID {
return &GroupLinkError{"policy", policy.Name}
}
}
}
}
// check setup key links
for _, setupKey := range account.SetupKeys {
for _, grp := range setupKey.AutoGroups {
if grp == groupID {
return &GroupLinkError{"setup key", setupKey.Name}
}
}
}
// check user links
for _, user := range account.Users {
for _, grp := range user.AutoGroups {
if grp == groupID {
return &GroupLinkError{"user", user.Id}
}
}
}
// check DisabledManagementGroups
for _, disabledMgmGrp := range account.DNSSettings.DisabledManagementGroups {
if disabledMgmGrp == groupID {
return &GroupLinkError{"disabled DNS management groups", g.Name}
}
}
delete(account.Groups, groupID)
account.Network.IncSerial()
if err = am.Store.SaveAccount(account); err != nil {
return err
}
am.StoreEvent(userId, groupID, accountId, activity.GroupDeleted, g.EventMeta())
am.updateAccountPeers(account)
return nil
}
// ListGroups objects of the peers
func (am *DefaultAccountManager) ListGroups(accountID string) ([]*Group, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, err
}
groups := make([]*Group, 0, len(account.Groups))
for _, item := range account.Groups {
groups = append(groups, item)
}
return groups, nil
}
// GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerID string) error {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return err
}
group, ok := account.Groups[groupID]
if !ok {
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
}
add := true
for _, itemID := range group.Peers {
if itemID == peerID {
add = false
break
}
}
if add {
group.Peers = append(group.Peers, peerID)
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(account); err != nil {
return err
}
am.updateAccountPeers(account)
return nil
}
// GroupDeletePeer removes peer from the group
func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerID string) error {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return err
}
group, ok := account.Groups[groupID]
if !ok {
return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
}
account.Network.IncSerial()
for i, itemID := range group.Peers {
if itemID == peerID {
group.Peers = append(group.Peers[:i], group.Peers[i+1:]...)
if err := am.Store.SaveAccount(account); err != nil {
return err
}
}
}
am.updateAccountPeers(account)
return nil
}

View File

@ -0,0 +1,607 @@
package server
import (
"context"
"fmt"
"strings"
"time"
pb "github.com/golang/protobuf/proto" // nolint
"github.com/golang/protobuf/ptypes/timestamp"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
gRPCPeer "google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/management_refactor/server/accounts"
"github.com/netbirdio/netbird/management/server/management_refactor/server/peers"
internalStatus "github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry"
)
// GRPCServer an instance of a Management gRPC API server
type GRPCServer struct {
accountManager accounts.AccountManager
wgKey wgtypes.Key
proto.UnimplementedManagementServiceServer
peersUpdateManager *peers.PeersUpdateManager
config *Config
turnCredentialsManager TURNCredentialsManager
jwtValidator *jwtclaims.JWTValidator
jwtClaimsExtractor *jwtclaims.ClaimsExtractor
appMetrics telemetry.AppMetrics
ephemeralManager *peers.EphemeralManager
}
// NewServer creates a new Management server
func NewServer(config *Config, accountManager accounts.AccountManager, peersUpdateManager *peers.PeersUpdateManager, turnCredentialsManager TURNCredentialsManager, appMetrics telemetry.AppMetrics, ephemeralManager *peers.EphemeralManager) (*GRPCServer, error) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, err
}
var jwtValidator *jwtclaims.JWTValidator
if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) {
jwtValidator, err = jwtclaims.NewJWTValidator(
config.HttpConfig.AuthIssuer,
config.GetAuthAudiences(),
config.HttpConfig.AuthKeysLocation,
config.HttpConfig.IdpSignKeyRefreshEnabled,
)
if err != nil {
return nil, status.Errorf(codes.Internal, "unable to create new jwt middleware, err: %v", err)
}
} else {
log.Debug("unable to use http config to create new jwt middleware")
}
if appMetrics != nil {
// update gauge based on number of connected peers which is equal to open gRPC streams
err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 {
return int64(len(peersUpdateManager.PeerChannels))
})
if err != nil {
return nil, err
}
}
var audience, userIDClaim string
if config.HttpConfig != nil {
audience = config.HttpConfig.AuthAudience
userIDClaim = config.HttpConfig.AuthUserIDClaim
}
jwtClaimsExtractor := jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(audience),
jwtclaims.WithUserIDClaim(userIDClaim),
)
return &GRPCServer{
wgKey: key,
// peerKey -> event channel
peersUpdateManager: peersUpdateManager,
accountManager: accountManager,
config: config,
turnCredentialsManager: turnCredentialsManager,
jwtValidator: jwtValidator,
jwtClaimsExtractor: jwtClaimsExtractor,
appMetrics: appMetrics,
ephemeralManager: ephemeralManager,
}, nil
}
func (s *GRPCServer) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) {
// todo introduce something more meaningful with the key expiration/rotation
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountGetKeyRequest()
}
now := time.Now().Add(24 * time.Hour)
secs := int64(now.Second())
nanos := int32(now.Nanosecond())
expiresAt := &timestamp.Timestamp{Seconds: secs, Nanos: nanos}
return &proto.ServerKeyResponse{
Key: s.wgKey.PublicKey().String(),
ExpiresAt: expiresAt,
}, nil
}
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
// notifies the connected peer of any updates (e.g. new peers under the same account)
func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
reqStart := time.Now()
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSyncRequest()
}
p, ok := gRPCPeer.FromContext(srv.Context())
if ok {
log.Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, p.Addr.String())
}
syncReq := &proto.SyncRequest{}
peerKey, err := s.parseRequest(req, syncReq)
if err != nil {
return err
}
peer, netMap, err := s.accountManager.SyncPeer(PeerSync{WireGuardPubKey: peerKey.String()})
if err != nil {
return mapError(err)
}
err = s.sendInitialSync(peerKey, peer, netMap, srv)
if err != nil {
log.Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
return err
}
updates := s.peersUpdateManager.CreateChannel(peer.ID)
s.ephemeralManager.OnPeerConnected(peer)
err = s.accountManager.MarkPeerConnected(peerKey.String(), true)
if err != nil {
log.Warnf("failed marking peer as connected %s %v", peerKey, err)
}
if s.config.TURNConfig.TimeBasedCredentials {
s.turnCredentialsManager.SetupRefresh(peer.ID)
}
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart))
}
// keep a connection to the peer and send updates when available
for {
select {
// condition when there are some updates
case update, open := <-updates:
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().UpdateChannelQueueLength(len(updates) + 1)
}
if !open {
log.Debugf("updates channel for peer %s was closed", peerKey.String())
s.cancelPeerRoutines(peer)
return nil
}
log.Debugf("received an update for peer %s", peerKey.String())
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update)
if err != nil {
s.cancelPeerRoutines(peer)
return status.Errorf(codes.Internal, "failed processing update message")
}
err = srv.SendMsg(&proto.EncryptedMessage{
WgPubKey: s.wgKey.PublicKey().String(),
Body: encryptedResp,
})
if err != nil {
s.cancelPeerRoutines(peer)
return status.Errorf(codes.Internal, "failed sending update message")
}
log.Debugf("sent an update to peer %s", peerKey.String())
// condition when client <-> server connection has been terminated
case <-srv.Context().Done():
// happens when connection drops, e.g. client disconnects
log.Debugf("stream of peer %s has been closed", peerKey.String())
s.cancelPeerRoutines(peer)
return srv.Context().Err()
}
}
}
func (s *GRPCServer) cancelPeerRoutines(peer *Peer) {
s.peersUpdateManager.CloseChannel(peer.ID)
s.turnCredentialsManager.CancelRefresh(peer.ID)
_ = s.accountManager.MarkPeerConnected(peer.Key, false)
s.ephemeralManager.OnPeerDisconnected(peer)
}
func (s *GRPCServer) validateToken(jwtToken string) (string, error) {
if s.jwtValidator == nil {
return "", status.Error(codes.Internal, "no jwt validator set")
}
token, err := s.jwtValidator.ValidateAndParse(jwtToken)
if err != nil {
return "", status.Errorf(codes.InvalidArgument, "invalid jwt token, err: %v", err)
}
claims := s.jwtClaimsExtractor.FromToken(token)
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
_, _, err = s.accountManager.GetAccountFromToken(claims)
if err != nil {
return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err)
}
return claims.UserId, nil
}
// maps internal internalStatus.Error to gRPC status.Error
func mapError(err error) error {
if e, ok := internalStatus.FromError(err); ok {
switch e.Type() {
case internalStatus.PermissionDenied:
return status.Errorf(codes.PermissionDenied, e.Message)
case internalStatus.Unauthorized:
return status.Errorf(codes.PermissionDenied, e.Message)
case internalStatus.Unauthenticated:
return status.Errorf(codes.PermissionDenied, e.Message)
case internalStatus.PreconditionFailed:
return status.Errorf(codes.FailedPrecondition, e.Message)
case internalStatus.NotFound:
return status.Errorf(codes.NotFound, e.Message)
default:
}
}
log.Errorf("got an unhandled error: %s", err)
return status.Errorf(codes.Internal, "failed handling request")
}
func extractPeerMeta(loginReq *proto.LoginRequest) PeerSystemMeta {
return PeerSystemMeta{
Hostname: loginReq.GetMeta().GetHostname(),
GoOS: loginReq.GetMeta().GetGoOS(),
Kernel: loginReq.GetMeta().GetKernel(),
Core: loginReq.GetMeta().GetCore(),
Platform: loginReq.GetMeta().GetPlatform(),
OS: loginReq.GetMeta().GetOS(),
WtVersion: loginReq.GetMeta().GetWiretrusteeVersion(),
UIVersion: loginReq.GetMeta().GetUiVersion(),
}
}
func (s *GRPCServer) parseRequest(req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) {
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil {
log.Warnf("error while parsing peer's WireGuard public key %s.", req.WgPubKey)
return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", req.WgPubKey)
}
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, parsed)
if err != nil {
return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "invalid request message")
}
return peerKey, nil
}
// Login endpoint first checks whether peer is registered under any account
// In case it is, the login is successful
// In case it isn't, the endpoint checks whether setup key is provided within the request and tries to register a peer.
// In case of the successful registration login is also successful
func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
reqStart := time.Now()
defer func() {
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart))
}
}()
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountLoginRequest()
}
p, ok := gRPCPeer.FromContext(ctx)
if ok {
log.Debugf("Login request from peer [%s] [%s]", req.WgPubKey, p.Addr.String())
}
loginReq := &proto.LoginRequest{}
peerKey, err := s.parseRequest(req, loginReq)
if err != nil {
return nil, err
}
if loginReq.GetMeta() == nil {
msg := status.Errorf(codes.FailedPrecondition,
"peer system meta has to be provided to log in. Peer %s, remote addr %s", peerKey.String(),
p.Addr.String())
log.Warn(msg)
return nil, msg
}
userID := ""
// JWT token is not always provided, it is fine for userID to be empty cuz it might be that peer is already registered,
// or it uses a setup key to register.
if loginReq.GetJwtToken() != "" {
userID, err = s.validateToken(loginReq.GetJwtToken())
if err != nil {
log.Warnf("failed validating JWT token sent from peer %s", peerKey)
return nil, mapError(err)
}
}
var sshKey []byte
if loginReq.GetPeerKeys() != nil {
sshKey = loginReq.GetPeerKeys().GetSshPubKey()
}
peer, netMap, err := s.accountManager.LoginPeer(PeerLogin{
WireGuardPubKey: peerKey.String(),
SSHKey: string(sshKey),
Meta: extractPeerMeta(loginReq),
UserID: userID,
SetupKey: loginReq.GetSetupKey(),
})
if err != nil {
log.Warnf("failed logging in peer %s", peerKey)
return nil, mapError(err)
}
// if the login request contains setup key then it is a registration request
if loginReq.GetSetupKey() != "" {
s.ephemeralManager.OnPeerDisconnected(peer)
}
// if peer has reached this point then it has logged in
loginResp := &proto.LoginResponse{
WiretrusteeConfig: toWiretrusteeConfig(s.config, nil),
PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain()),
}
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp)
if err != nil {
log.Warnf("failed encrypting peer %s message", peer.ID)
return nil, status.Errorf(codes.Internal, "failed logging in peer")
}
return &proto.EncryptedMessage{
WgPubKey: s.wgKey.PublicKey().String(),
Body: encryptedResp,
}, nil
}
func ToResponseProto(configProto Protocol) proto.HostConfig_Protocol {
switch configProto {
case UDP:
return proto.HostConfig_UDP
case DTLS:
return proto.HostConfig_DTLS
case HTTP:
return proto.HostConfig_HTTP
case HTTPS:
return proto.HostConfig_HTTPS
case TCP:
return proto.HostConfig_TCP
default:
panic(fmt.Errorf("unexpected config protocol type %v", configProto))
}
}
func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *proto.WiretrusteeConfig {
if config == nil {
return nil
}
var stuns []*proto.HostConfig
for _, stun := range config.Stuns {
stuns = append(stuns, &proto.HostConfig{
Uri: stun.URI,
Protocol: ToResponseProto(stun.Proto),
})
}
var turns []*proto.ProtectedHostConfig
for _, turn := range config.TURNConfig.Turns {
var username string
var password string
if turnCredentials != nil {
username = turnCredentials.Username
password = turnCredentials.Password
} else {
username = turn.Username
password = turn.Password
}
turns = append(turns, &proto.ProtectedHostConfig{
HostConfig: &proto.HostConfig{
Uri: turn.URI,
Protocol: ToResponseProto(turn.Proto),
},
User: username,
Password: password,
})
}
return &proto.WiretrusteeConfig{
Stuns: stuns,
Turns: turns,
Signal: &proto.HostConfig{
Uri: config.Signal.URI,
Protocol: ToResponseProto(config.Signal.Proto),
},
}
}
func toPeerConfig(peer *Peer, network *Network, dnsName string) *proto.PeerConfig {
netmask, _ := network.Net.Mask.Size()
fqdn := peer.FQDN(dnsName)
return &proto.PeerConfig{
Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask), // take it from the network
SshConfig: &proto.SSHConfig{SshEnabled: peer.SSHEnabled},
Fqdn: fqdn,
}
}
func toRemotePeerConfig(peers []*Peer, dnsName string) []*proto.RemotePeerConfig {
remotePeers := []*proto.RemotePeerConfig{}
for _, rPeer := range peers {
fqdn := rPeer.FQDN(dnsName)
remotePeers = append(remotePeers, &proto.RemotePeerConfig{
WgPubKey: rPeer.Key,
AllowedIps: []string{fmt.Sprintf(AllowedIPsFormat, rPeer.IP)},
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
Fqdn: fqdn,
})
}
return remotePeers
}
func toSyncResponse(config *Config, peer *Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string) *proto.SyncResponse {
wtConfig := toWiretrusteeConfig(config, turnCredentials)
pConfig := toPeerConfig(peer, networkMap.Network, dnsName)
remotePeers := toRemotePeerConfig(networkMap.Peers, dnsName)
routesUpdate := toProtocolRoutes(networkMap.Routes)
dnsUpdate := toProtocolDNSConfig(networkMap.DNSConfig)
offlinePeers := toRemotePeerConfig(networkMap.OfflinePeers, dnsName)
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules)
return &proto.SyncResponse{
WiretrusteeConfig: wtConfig,
PeerConfig: pConfig,
RemotePeers: remotePeers,
RemotePeersIsEmpty: len(remotePeers) == 0,
NetworkMap: &proto.NetworkMap{
Serial: networkMap.Network.CurrentSerial(),
PeerConfig: pConfig,
RemotePeers: remotePeers,
OfflinePeers: offlinePeers,
RemotePeersIsEmpty: len(remotePeers) == 0,
Routes: routesUpdate,
DNSConfig: dnsUpdate,
FirewallRules: firewallRules,
FirewallRulesIsEmpty: len(firewallRules) == 0,
},
}
}
// IsHealthy indicates whether the service is healthy
func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Empty, error) {
return &proto.Empty{}, nil
}
// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization
func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *Peer, networkMap *NetworkMap, srv proto.ManagementService_SyncServer) error {
// make secret time based TURN credentials optional
var turnCredentials *TURNCredentials
if s.config.TURNConfig.TimeBasedCredentials {
creds := s.turnCredentialsManager.GenerateCredentials()
turnCredentials = &creds
} else {
turnCredentials = nil
}
plainResp := toSyncResponse(s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain())
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
if err != nil {
return status.Errorf(codes.Internal, "error handling request")
}
err = srv.Send(&proto.EncryptedMessage{
WgPubKey: s.wgKey.PublicKey().String(),
Body: encryptedResp,
})
if err != nil {
log.Errorf("failed sending SyncResponse %v", err)
return status.Errorf(codes.Internal, "error handling request")
}
return nil
}
// GetDeviceAuthorizationFlow returns a device authorization flow information
// This is used for initiating an Oauth 2 device authorization grant flow
// which will be used by our clients to Login
func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil {
errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetDeviceAuthorizationFlow request.", req.WgPubKey)
log.Warn(errMSG)
return nil, status.Error(codes.InvalidArgument, errMSG)
}
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.DeviceAuthorizationFlowRequest{})
if err != nil {
errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey)
log.Warn(errMSG)
return nil, status.Error(codes.InvalidArgument, errMSG)
}
if s.config.DeviceAuthorizationFlow == nil || s.config.DeviceAuthorizationFlow.Provider == string(NONE) {
return nil, status.Error(codes.NotFound, "no device authorization flow information available")
}
provider, ok := proto.DeviceAuthorizationFlowProvider_value[strings.ToUpper(s.config.DeviceAuthorizationFlow.Provider)]
if !ok {
return nil, status.Errorf(codes.InvalidArgument, "no provider found in the protocol for %s", s.config.DeviceAuthorizationFlow.Provider)
}
flowInfoResp := &proto.DeviceAuthorizationFlow{
Provider: proto.DeviceAuthorizationFlowProvider(provider),
ProviderConfig: &proto.ProviderConfig{
ClientID: s.config.DeviceAuthorizationFlow.ProviderConfig.ClientID,
ClientSecret: s.config.DeviceAuthorizationFlow.ProviderConfig.ClientSecret,
Domain: s.config.DeviceAuthorizationFlow.ProviderConfig.Domain,
Audience: s.config.DeviceAuthorizationFlow.ProviderConfig.Audience,
DeviceAuthEndpoint: s.config.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint,
TokenEndpoint: s.config.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint,
Scope: s.config.DeviceAuthorizationFlow.ProviderConfig.Scope,
UseIDToken: s.config.DeviceAuthorizationFlow.ProviderConfig.UseIDToken,
},
}
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp)
if err != nil {
return nil, status.Error(codes.Internal, "failed to encrypt no device authorization flow information")
}
return &proto.EncryptedMessage{
WgPubKey: s.wgKey.PublicKey().String(),
Body: encryptedResp,
}, nil
}
// GetPKCEAuthorizationFlow returns a pkce authorization flow information
// This is used for initiating an Oauth 2 pkce authorization grant flow
// which will be used by our clients to Login
func (s *GRPCServer) GetPKCEAuthorizationFlow(_ context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil {
errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetPKCEAuthorizationFlow request.", req.WgPubKey)
log.Warn(errMSG)
return nil, status.Error(codes.InvalidArgument, errMSG)
}
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.PKCEAuthorizationFlowRequest{})
if err != nil {
errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey)
log.Warn(errMSG)
return nil, status.Error(codes.InvalidArgument, errMSG)
}
if s.config.PKCEAuthorizationFlow == nil {
return nil, status.Error(codes.NotFound, "no pkce authorization flow information available")
}
flowInfoResp := &proto.PKCEAuthorizationFlow{
ProviderConfig: &proto.ProviderConfig{
Audience: s.config.PKCEAuthorizationFlow.ProviderConfig.Audience,
ClientID: s.config.PKCEAuthorizationFlow.ProviderConfig.ClientID,
ClientSecret: s.config.PKCEAuthorizationFlow.ProviderConfig.ClientSecret,
TokenEndpoint: s.config.PKCEAuthorizationFlow.ProviderConfig.TokenEndpoint,
AuthorizationEndpoint: s.config.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint,
Scope: s.config.PKCEAuthorizationFlow.ProviderConfig.Scope,
RedirectURLs: s.config.PKCEAuthorizationFlow.ProviderConfig.RedirectURLs,
UseIDToken: s.config.PKCEAuthorizationFlow.ProviderConfig.UseIDToken,
},
}
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp)
if err != nil {
return nil, status.Error(codes.Internal, "failed to encrypt no pkce authorization flow information")
}
return &proto.EncryptedMessage{
WgPubKey: s.wgKey.PublicKey().String(),
Body: encryptedResp,
}, nil
}

View File

@ -0,0 +1,286 @@
package nameservers
import (
"errors"
"regexp"
"unicode/utf8"
"github.com/miekg/dns"
"github.com/rs/xid"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/status"
)
const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$`
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, err
}
nsGroup, found := account.NameServerGroups[nsGroupID]
if found {
return nsGroup.Copy(), nil
}
return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID)
}
// CreateNameServerGroup creates and saves a new nameserver group
func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, err
}
newNSGroup := &nbdns.NameServerGroup{
ID: xid.New().String(),
Name: name,
Description: description,
NameServers: nameServerList,
Groups: groups,
Enabled: enabled,
Primary: primary,
Domains: domains,
SearchDomainsEnabled: searchDomainEnabled,
}
err = validateNameServerGroup(false, newNSGroup, account)
if err != nil {
return nil, err
}
if account.NameServerGroups == nil {
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup)
}
account.NameServerGroups[newNSGroup.ID] = newNSGroup
account.Network.IncSerial()
err = am.Store.SaveAccount(account)
if err != nil {
return nil, err
}
am.updateAccountPeers(account)
am.StoreEvent(userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
return newNSGroup.Copy(), nil
}
// SaveNameServerGroup saves nameserver group
func (am *DefaultAccountManager) SaveNameServerGroup(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
if nsGroupToSave == nil {
return status.Errorf(status.InvalidArgument, "nameserver group provided is nil")
}
account, err := am.Store.GetAccount(accountID)
if err != nil {
return err
}
err = validateNameServerGroup(true, nsGroupToSave, account)
if err != nil {
return err
}
account.NameServerGroups[nsGroupToSave.ID] = nsGroupToSave
account.Network.IncSerial()
err = am.Store.SaveAccount(account)
if err != nil {
return err
}
am.updateAccountPeers(account)
am.StoreEvent(userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
return nil
}
// DeleteNameServerGroup deletes nameserver group with nsGroupID
func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID, userID string) error {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return err
}
nsGroup := account.NameServerGroups[nsGroupID]
if nsGroup == nil {
return status.Errorf(status.NotFound, "nameserver group %s wasn't found", nsGroupID)
}
delete(account.NameServerGroups, nsGroupID)
account.Network.IncSerial()
err = am.Store.SaveAccount(account)
if err != nil {
return err
}
am.updateAccountPeers(account)
am.StoreEvent(userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
return nil
}
// ListNameServerGroups returns a list of nameserver groups from account
func (am *DefaultAccountManager) ListNameServerGroups(accountID string) ([]*nbdns.NameServerGroup, error) {
unlock := am.Store.AcquireAccountLock(accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, err
}
nsGroups := make([]*nbdns.NameServerGroup, 0, len(account.NameServerGroups))
for _, item := range account.NameServerGroups {
nsGroups = append(nsGroups, item.Copy())
}
return nsGroups, nil
}
func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error {
nsGroupID := ""
if existingGroup {
nsGroupID = nameserverGroup.ID
_, found := account.NameServerGroups[nsGroupID]
if !found {
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupID)
}
}
err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains, nameserverGroup.SearchDomainsEnabled)
if err != nil {
return err
}
err = validateNSGroupName(nameserverGroup.Name, nsGroupID, account.NameServerGroups)
if err != nil {
return err
}
err = validateNSList(nameserverGroup.NameServers)
if err != nil {
return err
}
err = validateGroups(nameserverGroup.Groups, account.Groups)
if err != nil {
return err
}
return nil
}
func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bool) error {
if !primary && len(domains) == 0 {
return status.Errorf(status.InvalidArgument, "nameserver group primary status is false and domains are empty,"+
" it should be primary or have at least one domain")
}
if primary && len(domains) != 0 {
return status.Errorf(status.InvalidArgument, "nameserver group primary status is true and domains are not empty,"+
" you should set either primary or domain")
}
if primary && searchDomainsEnabled {
return status.Errorf(status.InvalidArgument, "nameserver group primary status is true and search domains is enabled,"+
" you should not set search domains for primary nameservers")
}
for _, domain := range domains {
if err := validateDomain(domain); err != nil {
return status.Errorf(status.InvalidArgument, "nameserver group got an invalid domain: %s %q", domain, err)
}
}
return nil
}
func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.NameServerGroup) error {
if utf8.RuneCountInString(name) > nbdns.MaxGroupNameChar || name == "" {
return status.Errorf(status.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar)
}
for _, nsGroup := range nsGroupMap {
if name == nsGroup.Name && nsGroup.ID != nsGroupID {
return status.Errorf(status.InvalidArgument, "a nameserver group with name %s already exist", name)
}
}
return nil
}
func validateNSList(list []nbdns.NameServer) error {
nsListLenght := len(list)
if nsListLenght == 0 || nsListLenght > 2 {
return status.Errorf(status.InvalidArgument, "the list of nameservers should be 1 or 2, got %d", len(list))
}
return nil
}
func validateGroups(list []string, groups map[string]*Group) error {
if len(list) == 0 {
return status.Errorf(status.InvalidArgument, "the list of group IDs should not be empty")
}
for _, id := range list {
if id == "" {
return status.Errorf(status.InvalidArgument, "group ID should not be empty string")
}
found := false
for groupID := range groups {
if id == groupID {
found = true
break
}
}
if !found {
return status.Errorf(status.InvalidArgument, "group id %s not found", id)
}
}
return nil
}
func validateDomain(domain string) error {
domainMatcher := regexp.MustCompile(domainPattern)
if !domainMatcher.MatchString(domain) {
return errors.New("domain should consists of only letters, numbers, and hyphens with no leading, trailing hyphens, or spaces")
}
labels, valid := dns.IsDomainName(domain)
if !valid {
return errors.New("invalid domain name")
}
if labels < 2 {
return errors.New("domain should consists of a minimum of two labels")
}
return nil
}

View File

@ -0,0 +1,148 @@
package network
import (
"math/rand"
"net"
"sync"
"time"
"github.com/c-robinson/iplib"
"github.com/rs/xid"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/management_refactor/server/peers"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/route"
)
const (
// SubnetSize is a size of the subnet of the global network, e.g. 100.77.0.0/16
SubnetSize = 16
// NetSize is a global network size 100.64.0.0/10
NetSize = 10
// AllowedIPsFormat generates Wireguard AllowedIPs format (e.g. 100.64.30.1/32)
AllowedIPsFormat = "%s/32"
)
type NetworkMap struct {
Peers []*peers.Peer
Network *Network
Routes []*route.Route
DNSConfig nbdns.Config
OfflinePeers []*peers.Peer
FirewallRules []*FirewallRule
}
type Network struct {
Identifier string `json:"id"`
Net net.IPNet `gorm:"serializer:gob"`
Dns string
// Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added).
// Used to synchronize state to the client apps.
Serial uint64
mu sync.Mutex `json:"-" gorm:"-"`
}
// NewNetwork creates a new Network initializing it with a Serial=0
// It takes a random /16 subnet from 100.64.0.0/10 (64 different subnets)
func NewNetwork() *Network {
n := iplib.NewNet4(net.ParseIP("100.64.0.0"), NetSize)
sub, _ := n.Subnet(SubnetSize)
s := rand.NewSource(time.Now().Unix())
r := rand.New(s)
intn := r.Intn(len(sub))
return &Network{
Identifier: xid.New().String(),
Net: sub[intn].IPNet,
Dns: "",
Serial: 0}
}
// IncSerial increments Serial by 1 reflecting that the network state has been changed
func (n *Network) IncSerial() {
n.mu.Lock()
defer n.mu.Unlock()
n.Serial = n.Serial + 1
}
// CurrentSerial returns the Network.Serial of the network (latest state id)
func (n *Network) CurrentSerial() uint64 {
n.mu.Lock()
defer n.mu.Unlock()
return n.Serial
}
func (n *Network) Copy() *Network {
return &Network{
Identifier: n.Identifier,
Net: n.Net,
Dns: n.Dns,
Serial: n.Serial,
}
}
// AllocatePeerIP pics an available IP from an net.IPNet.
// This method considers already taken IPs and reuses IPs if there are gaps in takenIps
// E.g. if ipNet=100.30.0.0/16 and takenIps=[100.30.0.1, 100.30.0.4] then the result would be 100.30.0.2 or 100.30.0.3
func AllocatePeerIP(ipNet net.IPNet, takenIps []net.IP) (net.IP, error) {
takenIPMap := make(map[string]struct{})
takenIPMap[ipNet.IP.String()] = struct{}{}
for _, ip := range takenIps {
takenIPMap[ip.String()] = struct{}{}
}
ips, _ := generateIPs(&ipNet, takenIPMap)
if len(ips) == 0 {
return nil, status.Errorf(status.PreconditionFailed, "failed allocating new IP for the ipNet %s - network is out of IPs", ipNet.String())
}
// pick a random IP
s := rand.NewSource(time.Now().Unix())
r := rand.New(s)
intn := r.Intn(len(ips))
return ips[intn], nil
}
// generateIPs generates a list of all possible IPs of the given network excluding IPs specified in the exclusion list
func generateIPs(ipNet *net.IPNet, exclusions map[string]struct{}) ([]net.IP, int) {
var ips []net.IP
for ip := ipNet.IP.Mask(ipNet.Mask); ipNet.Contains(ip); incIP(ip) {
if _, ok := exclusions[ip.String()]; !ok && ip[3] != 0 {
ips = append(ips, copyIP(ip))
}
}
// remove network address, broadcast and Fake DNS resolver address
lenIPs := len(ips)
switch {
case lenIPs < 2:
return ips, lenIPs
case lenIPs < 3:
return ips[1 : len(ips)-1], lenIPs - 2
default:
return ips[1 : len(ips)-2], lenIPs - 3
}
}
func copyIP(ip net.IP) net.IP {
dup := make(net.IP, len(ip))
copy(dup, ip)
return dup
}
func incIP(ip net.IP) {
for j := len(ip) - 1; j >= 0; j-- {
ip[j]++
if ip[j] > 0 {
break
}
}
}

View File

@ -0,0 +1,161 @@
package network
import (
"strconv"
"strings"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/management_refactor/server/access_control"
"github.com/netbirdio/netbird/management/server/management_refactor/server/peers"
)
type NetworkManager interface {
GetPeerNetworkMap(peerID, dnsDomain string) *NetworkMap
}
type DefaultNetworkManager struct {
accessControlManager access_control.AccessControlManager
}
func (nm *DefaultNetworkManager) GetPeerNetworkMap(peerID, dnsDomain string) *NetworkMap {
aclPeers, firewallRules := getPeerConnectionResources(peerID)
// exclude expired peers
var peersToConnect []*peers.Peer
var expiredPeers []*peers.Peer
for _, p := range aclPeers {
expired, _ := p.LoginExpired(a.Settings.PeerLoginExpiration)
if a.Settings.PeerLoginExpirationEnabled && expired {
expiredPeers = append(expiredPeers, p)
continue
}
peersToConnect = append(peersToConnect, p)
}
routesUpdate := a.getRoutesToSync(peerID, peersToConnect)
dnsManagementStatus := a.getPeerDNSManagementStatus(peerID)
dnsUpdate := nbdns.Config{
ServiceEnable: dnsManagementStatus,
}
if dnsManagementStatus {
var zones []nbdns.CustomZone
peersCustomZone := getPeersCustomZone(a, dnsDomain)
if peersCustomZone.Domain != "" {
zones = append(zones, peersCustomZone)
}
dnsUpdate.CustomZones = zones
dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
}
return &NetworkMap{
Peers: peersToConnect,
Network: a.Network.Copy(),
Routes: routesUpdate,
DNSConfig: dnsUpdate,
OfflinePeers: expiredPeers,
FirewallRules: firewallRules,
}
}
// getPeerConnectionResources for a given peer
//
// This function returns the list of peers and firewall rules that are applicable to a given peer.
func (nm *DefaultNetworkManager) getPeerConnectionResources(peerID string) ([]*Peer, []*FirewallRule) {
generateResources, getAccumulatedResources := a.connResourcesGenerator()
for _, policy := range nm.accessControlManager.Policies {
if !policy.Enabled {
continue
}
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}
sourcePeers, peerInSources := getAllPeersFromGroups(a, rule.Sources, peerID)
destinationPeers, peerInDestinations := getAllPeersFromGroups(a, rule.Destinations, peerID)
if rule.Bidirectional {
if peerInSources {
generateResources(rule, destinationPeers, firewallRuleDirectionIN)
}
if peerInDestinations {
generateResources(rule, sourcePeers, firewallRuleDirectionOUT)
}
}
if peerInSources {
generateResources(rule, destinationPeers, firewallRuleDirectionOUT)
}
if peerInDestinations {
generateResources(rule, sourcePeers, firewallRuleDirectionIN)
}
}
}
return getAccumulatedResources()
}
// connResourcesGenerator returns generator and accumulator function which returns the result of generator calls
//
// The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer.
// It safe to call the generator function multiple times for same peer and different rules no duplicates will be
// generated. The accumulator function returns the result of all the generator calls.
func (nm *DefaultNetworkManager) connResourcesGenerator() (func(*access_control.PolicyRule, []*peers.Peer, int), func() ([]*peers.Peer, []*access_control.FirewallRule)) {
rulesExists := make(map[string]struct{})
peersExists := make(map[string]struct{})
rules := make([]*FirewallRule, 0)
peers := make([]*peers.Peer, 0)
all, err := a.GetGroupAll()
if err != nil {
log.Errorf("failed to get group all: %v", err)
all = &Group{}
}
return func(rule *PolicyRule, groupPeers []*Peer, direction int) {
isAll := (len(all.Peers) - 1) == len(groupPeers)
for _, peer := range groupPeers {
if peer == nil {
continue
}
if _, ok := peersExists[peer.ID]; !ok {
peers = append(peers, peer)
peersExists[peer.ID] = struct{}{}
}
fr := FirewallRule{
PeerIP: peer.IP.String(),
Direction: direction,
Action: string(rule.Action),
Protocol: string(rule.Protocol),
}
if isAll {
fr.PeerIP = "0.0.0.0"
}
ruleID := (rule.ID + fr.PeerIP + strconv.Itoa(direction) +
fr.Protocol + fr.Action + strings.Join(rule.Ports, ","))
if _, ok := rulesExists[ruleID]; ok {
continue
}
rulesExists[ruleID] = struct{}{}
if len(rule.Ports) == 0 {
rules = append(rules, &fr)
continue
}
for _, port := range rule.Ports {
pr := fr // clone rule and add set new port
pr.Port = port
rules = append(rules, &pr)
}
}
}, func() ([]*Peer, []*FirewallRule) {
return peers, rules
}
}

View File

@ -0,0 +1,225 @@
package peers
import (
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/management_refactor/accounts"
)
const (
ephemeralLifeTime = 10 * time.Minute
)
var (
timeNow = time.Now
)
type ephemeralPeer struct {
id string
account *accounts.Account
deadline time.Time
next *ephemeralPeer
}
// todo: consider to remove peer from ephemeral list when the peer has been deleted via API. If we do not do it
// in worst case we will get invalid error message in this manager.
// EphemeralManager keep a list of ephemeral peers. After ephemeralLifeTime inactivity the peer will be deleted
// automatically. Inactivity means the peer disconnected from the Management server.
type EphemeralManager struct {
store Store
accountManager AccountManager
headPeer *ephemeralPeer
tailPeer *ephemeralPeer
peersLock sync.Mutex
timer *time.Timer
}
// NewEphemeralManager instantiate new EphemeralManager
func NewEphemeralManager(store Store, accountManager AccountManager) *EphemeralManager {
return &EphemeralManager{
store: store,
accountManager: accountManager,
}
}
// LoadInitialPeers load from the database the ephemeral type of peers and schedule a cleanup procedure to the head
// of the linked list (to the most deprecated peer). At the end of cleanup it schedules the next cleanup to the new
// head.
func (e *EphemeralManager) LoadInitialPeers() {
e.peersLock.Lock()
defer e.peersLock.Unlock()
e.loadEphemeralPeers()
if e.headPeer != nil {
e.timer = time.AfterFunc(ephemeralLifeTime, e.cleanup)
}
}
// Stop timer
func (e *EphemeralManager) Stop() {
e.peersLock.Lock()
defer e.peersLock.Unlock()
if e.timer != nil {
e.timer.Stop()
}
}
// OnPeerConnected remove the peer from the linked list of ephemeral peers. Because it has been called when the peer
// is active the manager will not delete it while it is active.
func (e *EphemeralManager) OnPeerConnected(peer *Peer) {
if !peer.Ephemeral {
return
}
log.Tracef("remove peer from ephemeral list: %s", peer.ID)
e.peersLock.Lock()
defer e.peersLock.Unlock()
e.removePeer(peer.ID)
// stop the unnecessary timer
if e.headPeer == nil && e.timer != nil {
e.timer.Stop()
e.timer = nil
}
}
// OnPeerDisconnected add the peer to the linked list of ephemeral peers. Because of the peer
// is inactive it will be deleted after the ephemeralLifeTime period.
func (e *EphemeralManager) OnPeerDisconnected(peer *Peer) {
if !peer.Ephemeral {
return
}
log.Tracef("add peer to ephemeral list: %s", peer.ID)
a, err := e.store.GetAccountByPeerID(peer.ID)
if err != nil {
log.Errorf("failed to add peer to ephemeral list: %s", err)
return
}
e.peersLock.Lock()
defer e.peersLock.Unlock()
if e.isPeerOnList(peer.ID) {
return
}
e.addPeer(peer.ID, a, newDeadLine())
if e.timer == nil {
e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), e.cleanup)
}
}
func (e *EphemeralManager) loadEphemeralPeers() {
accounts := e.store.GetAllAccounts()
t := newDeadLine()
count := 0
for _, a := range accounts {
for id, p := range a.Peers {
if p.Ephemeral {
count++
e.addPeer(id, a, t)
}
}
}
log.Debugf("loaded ephemeral peer(s): %d", count)
}
func (e *EphemeralManager) cleanup() {
log.Tracef("on ephemeral cleanup")
deletePeers := make(map[string]*ephemeralPeer)
e.peersLock.Lock()
now := timeNow()
for p := e.headPeer; p != nil; p = p.next {
if now.Before(p.deadline) {
break
}
deletePeers[p.id] = p
e.headPeer = p.next
if p.next == nil {
e.tailPeer = nil
}
}
if e.headPeer != nil {
e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), e.cleanup)
} else {
e.timer = nil
}
e.peersLock.Unlock()
for id, p := range deletePeers {
log.Debugf("delete ephemeral peer: %s", id)
err := e.accountManager.DeletePeer(p.account.Id, id, activity.SystemInitiator)
if err != nil {
log.Tracef("failed to delete ephemeral peer: %s", err)
}
}
}
func (e *EphemeralManager) addPeer(id string, account *Account, deadline time.Time) {
ep := &ephemeralPeer{
id: id,
account: account,
deadline: deadline,
}
if e.headPeer == nil {
e.headPeer = ep
}
if e.tailPeer != nil {
e.tailPeer.next = ep
}
e.tailPeer = ep
}
func (e *EphemeralManager) removePeer(id string) {
if e.headPeer == nil {
return
}
if e.headPeer.id == id {
e.headPeer = e.headPeer.next
if e.tailPeer.id == id {
e.tailPeer = nil
}
return
}
for p := e.headPeer; p.next != nil; p = p.next {
if p.next.id == id {
// if we remove the last element from the chain then set the last-1 as tail
if e.tailPeer.id == id {
e.tailPeer = p
}
p.next = p.next.next
return
}
}
}
func (e *EphemeralManager) isPeerOnList(id string) bool {
for p := e.headPeer; p != nil; p = p.next {
if p.id == id {
return true
}
}
return false
}
func newDeadLine() time.Time {
return timeNow().Add(ephemeralLifeTime)
}

View File

@ -0,0 +1,186 @@
package peers
import (
"fmt"
"net"
"time"
log "github.com/sirupsen/logrus"
)
// Peer represents a machine connected to the network.
// The Peer is a WireGuard peer identified by a public key
type Peer struct {
// ID is an internal ID of the peer
ID string `gorm:"primaryKey"`
// AccountID is a reference to Account that this object belongs
AccountID string `json:"-" gorm:"index;uniqueIndex:idx_peers_account_id_ip"`
// WireGuard public key
Key string `gorm:"index"`
// A setup key this peer was registered with
SetupKey string
// IP address of the Peer
IP net.IP `gorm:"uniqueIndex:idx_peers_account_id_ip"`
// Meta is a Peer system meta data
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"`
// Name is peer's name (machine name)
Name string
// DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's
// domain to the peer label. e.g. peer-dns-label.netbird.cloud
DNSLabel string
// Status peer's management connection status
Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"`
// The user ID that registered the peer
UserID string
// SSHKey is a public SSH key of the peer
SSHKey string
// SSHEnabled indicates whether SSH server is enabled on the peer
SSHEnabled bool
// LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login.
// Works with LastLogin
LoginExpirationEnabled bool
// LastLogin the time when peer performed last login operation
LastLogin time.Time
// Indicate ephemeral peer attribute
Ephemeral bool
}
// PeerSystemMeta is a metadata of a Peer machine system
type PeerSystemMeta struct {
Hostname string
GoOS string
Kernel string
Core string
Platform string
OS string
WtVersion string
UIVersion string
}
func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool {
return p.Hostname == other.Hostname &&
p.GoOS == other.GoOS &&
p.Kernel == other.Kernel &&
p.Core == other.Core &&
p.Platform == other.Platform &&
p.OS == other.OS &&
p.WtVersion == other.WtVersion &&
p.UIVersion == other.UIVersion
}
// AddedWithSSOLogin indicates whether this peer has been added with an SSO login by a user.
func (p *Peer) AddedWithSSOLogin() bool {
return p.UserID != ""
}
// Copy copies Peer object
func (p *Peer) Copy() *Peer {
peerStatus := p.Status
if peerStatus != nil {
peerStatus = p.Status.Copy()
}
return &Peer{
ID: p.ID,
AccountID: p.AccountID,
Key: p.Key,
SetupKey: p.SetupKey,
IP: p.IP,
Meta: p.Meta,
Name: p.Name,
DNSLabel: p.DNSLabel,
Status: peerStatus,
UserID: p.UserID,
SSHKey: p.SSHKey,
SSHEnabled: p.SSHEnabled,
LoginExpirationEnabled: p.LoginExpirationEnabled,
LastLogin: p.LastLogin,
Ephemeral: p.Ephemeral,
}
}
// UpdateMetaIfNew updates peer's system metadata if new information is provided
// returns true if meta was updated, false otherwise
func (p *Peer) UpdateMetaIfNew(meta PeerSystemMeta) bool {
// Avoid overwriting UIVersion if the update was triggered sole by the CLI client
if meta.UIVersion == "" {
meta.UIVersion = p.Meta.UIVersion
}
if p.Meta.isEqual(meta) {
return false
}
p.Meta = meta
return true
}
// MarkLoginExpired marks peer's status expired or not
func (p *Peer) MarkLoginExpired(expired bool) {
newStatus := p.Status.Copy()
newStatus.LoginExpired = expired
if expired {
newStatus.Connected = false
}
p.Status = newStatus
}
// LoginExpired indicates whether the peer's login has expired or not.
// If Peer.LastLogin plus the expiresIn duration has happened already; then login has expired.
// Return true if a login has expired, false otherwise, and time left to expiration (negative when expired).
// Login expiration can be disabled/enabled on a Peer level via Peer.LoginExpirationEnabled property.
// Login expiration can also be disabled/enabled globally on the Account level via Settings.PeerLoginExpirationEnabled.
// Only peers added by interactive SSO login can be expired.
func (p *Peer) LoginExpired(expiresIn time.Duration) (bool, time.Duration) {
if !p.AddedWithSSOLogin() || !p.LoginExpirationEnabled {
return false, 0
}
expiresAt := p.LastLogin.Add(expiresIn)
now := time.Now()
timeLeft := expiresAt.Sub(now)
return timeLeft <= 0, timeLeft
}
// FQDN returns peers FQDN combined of the peer's DNS label and the system's DNS domain
func (p *Peer) FQDN(dnsDomain string) string {
if dnsDomain == "" {
return ""
}
return fmt.Sprintf("%s.%s", p.DNSLabel, dnsDomain)
}
// EventMeta returns activity event meta related to the peer
func (p *Peer) EventMeta(dnsDomain string) map[string]any {
return map[string]any{"name": p.Name, "fqdn": p.FQDN(dnsDomain), "ip": p.IP}
}
// Copy PeerStatus
func (p *PeerStatus) Copy() *PeerStatus {
return &PeerStatus{
LastSeen: p.LastSeen,
Connected: p.Connected,
LoginExpired: p.LoginExpired,
}
}
// UpdateLastLogin and set login expired false
func (p *Peer) UpdateLastLogin() {
p.LastLogin = time.Now().UTC()
newStatus := p.Status.Copy()
newStatus.LoginExpired = false
p.Status = newStatus
}
func (p *Peer) CheckAndUpdatePeerSSHKey(newSSHKey string) bool {
if len(newSSHKey) == 0 {
log.Debugf("no new SSH key provided for peer %s, skipping update", p.ID)
return false
}
if p.SSHKey == newSSHKey {
log.Debugf("same SSH key provided for peer %s, skipping update", p.ID)
return false
}
p.SSHKey = newSSHKey
return true
}

View File

@ -0,0 +1,6 @@
package peers
type PeerRepository interface {
findPeerByPubKey(pubKey string) (Peer, error)
updatePeer(peer Peer) error
}

View File

@ -0,0 +1,214 @@
package peers
import (
"time"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/management_refactor/server/accounts"
"github.com/netbirdio/netbird/management/server/management_refactor/server/events"
"github.com/netbirdio/netbird/management/server/management_refactor/server/users"
"github.com/netbirdio/netbird/management/server/status"
)
// PeerLogin used as a data object between the gRPC API and AccountManager on Login request.
type PeerLogin struct {
// WireGuardPubKey is a peers WireGuard public key
WireGuardPubKey string
// SSHKey is a peer's ssh key. Can be empty (e.g., old version do not provide it, or this feature is disabled)
SSHKey string
// Meta is the system information passed by peer, must be always present.
Meta PeerSystemMeta
// UserID indicates that JWT was used to log in, and it was valid. Can be empty when SetupKey is used or auth is not required.
UserID string
// AccountID indicates that JWT was used to log in, and it was valid. Can be empty when SetupKey is used or auth is not required.
AccountID string
// SetupKey references to a server.SetupKey to log in. Can be empty when UserID is used or auth is not required.
SetupKey string
}
type PeerStatus struct {
// LastSeen is the last time peer was connected to the management service
LastSeen time.Time
// Connected indicates whether peer is connected to the management service or not
Connected bool
// LoginExpired
LoginExpired bool
}
// PeerSync used as a data object between the gRPC API and AccountManager on Sync request.
type PeerSync struct {
// WireGuardPubKey is a peers WireGuard public key
WireGuardPubKey string
}
type PeersManager interface {
LoginPeer(login PeerLogin) (*Peer, *accounts.NetworkMap, error)
}
type DefaultPeersManager struct {
repository PeerRepository
userManager users.UserManager
accountManager accounts.AccountManager
eventsManager events.EventsManager
}
// LoginPeer logs in or registers a peer.
// If peer doesn't exist the function checks whether a setup key or a user is present and registers a new peer if so.
func (pm *DefaultPeersManager) LoginPeer(login PeerLogin) (*Peer, *accounts.NetworkMap, error) {
peer, err := pm.repository.findPeerByPubKey(login.WireGuardPubKey)
if err != nil {
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
}
if peer.AddedWithSSOLogin() {
user, err := pm.userManager.GetUser(peer.UserID)
if err != nil {
return nil, nil, err
}
if user.IsBlocked() {
return nil, nil, status.Errorf(status.PermissionDenied, "user is blocked")
}
}
account, err := pm.accountManager.GetAccount(peer.AccountID)
if err != nil {
return nil, nil, err
}
// this flag prevents unnecessary calls to the persistent store.
shouldStorePeer := false
updateRemotePeers := false
if peerLoginExpired(peer, account) {
err = checkAuth(login.UserID, peer)
if err != nil {
return nil, nil, err
}
// If peer was expired before and if it reached this point, it is re-authenticated.
// UserID is present, meaning that JWT validation passed successfully in the API layer.
peer.UpdateLastLogin()
updateRemotePeers = true
shouldStorePeer = true
pm.eventsManager.StoreEvent(login.UserID, peer.ID, account.Id, activity.UserLoggedInPeer, peer.EventMeta(pm.accountManager.GetDNSDomain()))
}
if peer.UpdateMetaIfNew(login.Meta) {
shouldStorePeer = true
}
if peer.CheckAndUpdatePeerSSHKey(login.SSHKey) {
shouldStorePeer = true
}
if shouldStorePeer {
err := pm.repository.updatePeer(peer)
if err != nil {
return nil, nil, err
}
}
if updateRemotePeers {
am.updateAccountPeers(account)
}
return peer, account.GetPeerNetworkMap(peer.ID, pm.accountManager.GetDNSDomain()), nil
}
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
func (pm *DefaultPeersManager) SyncPeer(sync PeerSync) (*Peer, *accounts.NetworkMap, error) {
// we found the peer, and we follow a normal login flow
// unlock := am.Store.AcquireAccountLock(account.Id)
// defer unlock()
peer, err := pm.repository.findPeerByPubKey(sync.WireGuardPubKey)
if err != nil {
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
}
if peer.AddedWithSSOLogin() {
user, err := pm.userManager.GetUser(peer.UserID)
if err != nil {
return nil, nil, err
}
if user.IsBlocked() {
return nil, nil, status.Errorf(status.PermissionDenied, "user is blocked")
}
}
account, err := pm.accountManager.GetAccount(peer.AccountID)
if err != nil {
return nil, nil, err
}
if peerLoginExpired(peer, account) {
return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
}
return &peer, account.GetPeerNetworkMap(peer.ID, pm.accountManager.GetDNSDomain()), nil
}
func (pm *DefaultPeersManager) GetNetworkMap(peerID string, dnsDomain string) (*accounts.NetworkMap, error) {
aclPeers, firewallRules := a.getPeerConnectionResources(peerID)
// exclude expired peers
var peersToConnect []*Peer
var expiredPeers []*Peer
for _, p := range aclPeers {
expired, _ := p.LoginExpired(a.Settings.PeerLoginExpiration)
if a.Settings.PeerLoginExpirationEnabled && expired {
expiredPeers = append(expiredPeers, p)
continue
}
peersToConnect = append(peersToConnect, p)
}
routesUpdate := a.getRoutesToSync(peerID, peersToConnect)
dnsManagementStatus := a.getPeerDNSManagementStatus(peerID)
dnsUpdate := nbdns.Config{
ServiceEnable: dnsManagementStatus,
}
if dnsManagementStatus {
var zones []nbdns.CustomZone
peersCustomZone := getPeersCustomZone(a, dnsDomain)
if peersCustomZone.Domain != "" {
zones = append(zones, peersCustomZone)
}
dnsUpdate.CustomZones = zones
dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
}
return &NetworkMap{
Peers: peersToConnect,
Network: a.Network.Copy(),
Routes: routesUpdate,
DNSConfig: dnsUpdate,
OfflinePeers: expiredPeers,
FirewallRules: firewallRules,
}
}
func peerLoginExpired(peer Peer, account accounts.Account) bool {
expired, expiresIn := peer.LoginExpired(account.Settings.PeerLoginExpiration)
expired = account.Settings.PeerLoginExpirationEnabled && expired
if expired || peer.Status.LoginExpired {
log.Debugf("peer's %s login expired %v ago", peer.ID, expiresIn)
return true
}
return false
}
func checkAuth(loginUserID string, peer Peer) error {
if loginUserID == "" {
// absence of a user ID indicates that JWT wasn't provided.
return status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
}
if peer.UserID != loginUserID {
log.Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, loginUserID)
return status.Errorf(status.Unauthenticated, "can't login")
}
return nil
}

View File

@ -0,0 +1,153 @@
package peers
import (
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/telemetry"
)
const channelBufferSize = 100
type UpdateMessage struct {
Update *proto.SyncResponse
}
type PeersUpdateManager struct {
// PeerChannels is an update channel indexed by Peer.ID
PeerChannels map[string]chan *UpdateMessage
// channelsMux keeps the mutex to access PeerChannels
channelsMux *sync.Mutex
// metrics provides method to collect application metrics
metrics telemetry.AppMetrics
}
// NewPeersUpdateManager returns a new instance of PeersUpdateManager
func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager {
return &PeersUpdateManager{
PeerChannels: make(map[string]chan *UpdateMessage),
channelsMux: &sync.Mutex{},
metrics: metrics,
}
}
// SendUpdate sends update message to the peer's channel
func (p *PeersUpdateManager) SendUpdate(peerID string, update *UpdateMessage) {
start := time.Now()
var found, dropped bool
p.channelsMux.Lock()
defer func() {
p.channelsMux.Unlock()
if p.metrics != nil {
p.metrics.UpdateChannelMetrics().CountSendUpdateDuration(time.Since(start), found, dropped)
}
}()
if channel, ok := p.PeerChannels[peerID]; ok {
found = true
select {
case channel <- update:
log.Debugf("update was sent to channel for peer %s", peerID)
default:
dropped = true
log.Warnf("channel for peer %s is %d full", peerID, len(channel))
}
} else {
log.Debugf("peer %s has no channel", peerID)
}
}
// CreateChannel creates a go channel for a given peer used to deliver updates relevant to the peer.
func (p *PeersUpdateManager) CreateChannel(peerID string) chan *UpdateMessage {
start := time.Now()
closed := false
p.channelsMux.Lock()
defer func() {
p.channelsMux.Unlock()
if p.metrics != nil {
p.metrics.UpdateChannelMetrics().CountCreateChannelDuration(time.Since(start), closed)
}
}()
if channel, ok := p.PeerChannels[peerID]; ok {
closed = true
delete(p.PeerChannels, peerID)
close(channel)
}
// mbragin: todo shouldn't it be more? or configurable?
channel := make(chan *UpdateMessage, channelBufferSize)
p.PeerChannels[peerID] = channel
log.Debugf("opened updates channel for a peer %s", peerID)
return channel
}
func (p *PeersUpdateManager) closeChannel(peerID string) {
if channel, ok := p.PeerChannels[peerID]; ok {
delete(p.PeerChannels, peerID)
close(channel)
}
log.Debugf("closed updates channel of a peer %s", peerID)
}
// CloseChannels closes updates channel for each given peer
func (p *PeersUpdateManager) CloseChannels(peerIDs []string) {
start := time.Now()
p.channelsMux.Lock()
defer func() {
p.channelsMux.Unlock()
if p.metrics != nil {
p.metrics.UpdateChannelMetrics().CountCloseChannelsDuration(time.Since(start), len(peerIDs))
}
}()
for _, id := range peerIDs {
p.closeChannel(id)
}
}
// CloseChannel closes updates channel of a given peer
func (p *PeersUpdateManager) CloseChannel(peerID string) {
start := time.Now()
p.channelsMux.Lock()
defer func() {
p.channelsMux.Unlock()
if p.metrics != nil {
p.metrics.UpdateChannelMetrics().CountCloseChannelDuration(time.Since(start))
}
}()
p.closeChannel(peerID)
}
// GetAllConnectedPeers returns a copy of the connected peers map
func (p *PeersUpdateManager) GetAllConnectedPeers() map[string]struct{} {
start := time.Now()
p.channelsMux.Lock()
m := make(map[string]struct{})
defer func() {
p.channelsMux.Unlock()
if p.metrics != nil {
p.metrics.UpdateChannelMetrics().CountGetAllConnectedPeersDuration(time.Since(start), len(m))
}
}()
for ID := range p.PeerChannels {
m[ID] = struct{}{}
}
return m
}

View File

@ -0,0 +1 @@
package routes

View File

@ -0,0 +1,115 @@
package server
import (
"sync"
"time"
log "github.com/sirupsen/logrus"
)
// Scheduler is an interface which implementations can schedule and cancel jobs
type Scheduler interface {
Cancel(IDs []string)
Schedule(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool))
}
// MockScheduler is a mock implementation of Scheduler
type MockScheduler struct {
CancelFunc func(IDs []string)
ScheduleFunc func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool))
}
// Cancel mocks the Cancel function of the Scheduler interface
func (mock *MockScheduler) Cancel(IDs []string) {
if mock.CancelFunc != nil {
mock.CancelFunc(IDs)
return
}
log.Errorf("MockScheduler doesn't have Cancel function defined ")
}
// Schedule mocks the Schedule function of the Scheduler interface
func (mock *MockScheduler) Schedule(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
if mock.ScheduleFunc != nil {
mock.ScheduleFunc(in, ID, job)
return
}
log.Errorf("MockScheduler doesn't have Schedule function defined")
}
// DefaultScheduler is a generic structure that allows to schedule jobs (functions) to run in the future and cancel them.
type DefaultScheduler struct {
// jobs map holds cancellation channels indexed by the job ID
jobs map[string]chan struct{}
mu *sync.Mutex
}
// NewDefaultScheduler creates an instance of a DefaultScheduler
func NewDefaultScheduler() *DefaultScheduler {
return &DefaultScheduler{
jobs: make(map[string]chan struct{}),
mu: &sync.Mutex{},
}
}
func (wm *DefaultScheduler) cancel(ID string) bool {
cancel, ok := wm.jobs[ID]
if ok {
delete(wm.jobs, ID)
select {
case cancel <- struct{}{}:
log.Debugf("cancelled scheduled job %s", ID)
default:
log.Warnf("couldn't cancel job %s because there was no routine listening on the cancel event", ID)
return false
}
}
return ok
}
// Cancel cancels the scheduled job by ID if present.
// If job wasn't found the function returns false.
func (wm *DefaultScheduler) Cancel(IDs []string) {
wm.mu.Lock()
defer wm.mu.Unlock()
for _, id := range IDs {
wm.cancel(id)
}
}
// Schedule a job to run in some time in the future. If job returns true then it will be scheduled one more time.
// If job with the provided ID already exists, a new one won't be scheduled.
func (wm *DefaultScheduler) Schedule(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
wm.mu.Lock()
defer wm.mu.Unlock()
cancel := make(chan struct{})
if _, ok := wm.jobs[ID]; ok {
log.Debugf("couldn't schedule a job %s because it already exists. There are %d total jobs scheduled.",
ID, len(wm.jobs))
return
}
wm.jobs[ID] = cancel
log.Debugf("scheduled a job %s to run in %s. There are %d total jobs scheduled.", ID, in.String(), len(wm.jobs))
go func() {
select {
case <-time.After(in):
log.Debugf("time to do a scheduled job %s", ID)
runIn, reschedule := job()
wm.mu.Lock()
defer wm.mu.Unlock()
delete(wm.jobs, ID)
if reschedule {
go wm.Schedule(runIn, ID, job)
}
case <-cancel:
log.Debugf("stopped scheduled job %s ", ID)
wm.mu.Lock()
defer wm.mu.Unlock()
delete(wm.jobs, ID)
return
}
}()
}

View File

@ -0,0 +1 @@
package setupkey

View File

@ -0,0 +1,459 @@
package store
import (
"path/filepath"
"runtime"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route"
)
// SqliteStore represents an account storage backed by a Sqlite DB persisted to disk
type SqliteStore struct {
db *gorm.DB
storeFile string
accountLocks sync.Map
globalAccountLock sync.Mutex
metrics telemetry.AppMetrics
installationPK int
}
type installation struct {
ID uint `gorm:"primaryKey"`
InstallationIDValue string
}
// NewSqliteStore restores a store from the file located in the datadir
func NewSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*SqliteStore, error) {
storeStr := "store.db?cache=shared"
if runtime.GOOS == "windows" {
// Vo avoid `The process cannot access the file because it is being used by another process` on Windows
storeStr = "store.db"
}
file := filepath.Join(dataDir, storeStr)
db, err := gorm.Open(sqlite.Open(file), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
PrepareStmt: true,
})
if err != nil {
return nil, err
}
sql, err := db.DB()
if err != nil {
return nil, err
}
conns := runtime.NumCPU()
sql.SetMaxOpenConns(conns) // TODO: make it configurable
err = db.AutoMigrate(
&SetupKey{}, &Peer{}, &User{}, &PersonalAccessToken{}, &Group{}, &Rule{},
&Account{}, &Policy{}, &PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
&installation{},
)
if err != nil {
return nil, err
}
return &SqliteStore{db: db, storeFile: file, metrics: metrics, installationPK: 1}, nil
}
// NewSqliteStoreFromFileStore restores a store from FileStore and stores SQLite DB in the file located in datadir
func NewSqliteStoreFromFileStore(filestore *FileStore, dataDir string, metrics telemetry.AppMetrics) (*SqliteStore, error) {
store, err := NewSqliteStore(dataDir, metrics)
if err != nil {
return nil, err
}
err = store.SaveInstallationID(filestore.InstallationID)
if err != nil {
return nil, err
}
for _, account := range filestore.GetAllAccounts() {
err := store.SaveAccount(account)
if err != nil {
return nil, err
}
}
return store, nil
}
// AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock
func (s *SqliteStore) AcquireGlobalLock() (unlock func()) {
log.Debugf("acquiring global lock")
start := time.Now()
s.globalAccountLock.Lock()
unlock = func() {
s.globalAccountLock.Unlock()
log.Debugf("released global lock in %v", time.Since(start))
}
took := time.Since(start)
log.Debugf("took %v to acquire global lock", took)
if s.metrics != nil {
s.metrics.StoreMetrics().CountGlobalLockAcquisitionDuration(took)
}
return unlock
}
func (s *SqliteStore) AcquireAccountLock(accountID string) (unlock func()) {
log.Debugf("acquiring lock for account %s", accountID)
start := time.Now()
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{})
mtx := value.(*sync.Mutex)
mtx.Lock()
unlock = func() {
mtx.Unlock()
log.Debugf("released lock for account %s in %v", accountID, time.Since(start))
}
return unlock
}
func (s *SqliteStore) SaveAccount(account *Account) error {
start := time.Now()
for _, key := range account.SetupKeys {
account.SetupKeysG = append(account.SetupKeysG, *key)
}
for id, peer := range account.Peers {
peer.ID = id
account.PeersG = append(account.PeersG, *peer)
}
for id, user := range account.Users {
user.Id = id
for id, pat := range user.PATs {
pat.ID = id
user.PATsG = append(user.PATsG, *pat)
}
account.UsersG = append(account.UsersG, *user)
}
for id, group := range account.Groups {
group.ID = id
account.GroupsG = append(account.GroupsG, *group)
}
for id, rule := range account.Rules {
rule.ID = id
account.RulesG = append(account.RulesG, *rule)
}
for id, route := range account.Routes {
route.ID = id
account.RoutesG = append(account.RoutesG, *route)
}
for id, ns := range account.NameServerGroups {
ns.ID = id
account.NameServerGroupsG = append(account.NameServerGroupsG, *ns)
}
err := s.db.Transaction(func(tx *gorm.DB) error {
result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id)
if result.Error != nil {
return result.Error
}
result = tx.Select(clause.Associations).Delete(account.UsersG, "account_id = ?", account.Id)
if result.Error != nil {
return result.Error
}
result = tx.Select(clause.Associations).Delete(account)
if result.Error != nil {
return result.Error
}
result = tx.
Session(&gorm.Session{FullSaveAssociations: true}).
Clauses(clause.OnConflict{UpdateAll: true}).Create(account)
if result.Error != nil {
return result.Error
}
return nil
})
took := time.Since(start)
if s.metrics != nil {
s.metrics.StoreMetrics().CountPersistenceDuration(took)
}
log.Debugf("took %d ms to persist an account to the SQLite", took.Milliseconds())
return err
}
func (s *SqliteStore) SaveInstallationID(ID string) error {
installation := installation{InstallationIDValue: ID}
installation.ID = uint(s.installationPK)
return s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&installation).Error
}
func (s *SqliteStore) GetInstallationID() string {
var installation installation
if result := s.db.First(&installation, "id = ?", s.installationPK); result.Error != nil {
return ""
}
return installation.InstallationIDValue
}
func (s *SqliteStore) SavePeerStatus(accountID, peerID string, peerStatus PeerStatus) error {
var peer Peer
result := s.db.First(&peer, "account_id = ? and id = ?", accountID, peerID)
if result.Error != nil {
return status.Errorf(status.NotFound, "peer %s not found", peerID)
}
peer.Status = &peerStatus
return s.db.Save(peer).Error
}
// DeleteHashedPAT2TokenIDIndex is noop in Sqlite
func (s *SqliteStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error {
return nil
}
// DeleteTokenID2UserIDIndex is noop in Sqlite
func (s *SqliteStore) DeleteTokenID2UserIDIndex(tokenID string) error {
return nil
}
func (s *SqliteStore) GetAccountByPrivateDomain(domain string) (*Account, error) {
var account Account
result := s.db.First(&account, "domain = ? and is_domain_primary_account = ? and domain_category = ?",
strings.ToLower(domain), true, PrivateCategory)
if result.Error != nil {
return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
}
// TODO: rework to not call GetAccount
return s.GetAccount(account.Id)
}
func (s *SqliteStore) GetAccountBySetupKey(setupKey string) (*Account, error) {
var key SetupKey
result := s.db.Select("account_id").First(&key, "key = ?", strings.ToUpper(setupKey))
if result.Error != nil {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
if key.AccountID == "" {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
return s.GetAccount(key.AccountID)
}
func (s *SqliteStore) GetTokenIDByHashedToken(hashedToken string) (string, error) {
var token PersonalAccessToken
result := s.db.First(&token, "hashed_token = ?", hashedToken)
if result.Error != nil {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
return token.ID, nil
}
func (s *SqliteStore) GetUserByTokenID(tokenID string) (*User, error) {
var token PersonalAccessToken
result := s.db.First(&token, "id = ?", tokenID)
if result.Error != nil {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
if token.UserID == "" {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
var user User
result = s.db.Preload("PATsG").First(&user, "id = ?", token.UserID)
if result.Error != nil {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
user.PATs = make(map[string]*PersonalAccessToken, len(user.PATsG))
for _, pat := range user.PATsG {
user.PATs[pat.ID] = pat.Copy()
}
return &user, nil
}
func (s *SqliteStore) GetAllAccounts() (all []*Account) {
var accounts []Account
result := s.db.Find(&accounts)
if result.Error != nil {
return all
}
for _, account := range accounts {
if acc, err := s.GetAccount(account.Id); err == nil {
all = append(all, acc)
}
}
return all
}
func (s *SqliteStore) GetAccount(accountID string) (*Account, error) {
var account Account
result := s.db.Model(&account).
Preload("UsersG.PATsG"). // have to be specifies as this is nester reference
Preload(clause.Associations).
First(&account, "id = ?", accountID)
if result.Error != nil {
return nil, status.Errorf(status.NotFound, "account not found")
}
// we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us
for i, policy := range account.Policies {
var rules []*PolicyRule
err := s.db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
if err != nil {
return nil, status.Errorf(status.NotFound, "account not found")
}
account.Policies[i].Rules = rules
}
account.SetupKeys = make(map[string]*SetupKey, len(account.SetupKeysG))
for _, key := range account.SetupKeysG {
account.SetupKeys[key.Key] = key.Copy()
}
account.SetupKeysG = nil
account.Peers = make(map[string]*Peer, len(account.PeersG))
for _, peer := range account.PeersG {
account.Peers[peer.ID] = peer.Copy()
}
account.PeersG = nil
account.Users = make(map[string]*User, len(account.UsersG))
for _, user := range account.UsersG {
user.PATs = make(map[string]*PersonalAccessToken, len(user.PATs))
for _, pat := range user.PATsG {
user.PATs[pat.ID] = pat.Copy()
}
account.Users[user.Id] = user.Copy()
}
account.UsersG = nil
account.Groups = make(map[string]*Group, len(account.GroupsG))
for _, group := range account.GroupsG {
account.Groups[group.ID] = group.Copy()
}
account.GroupsG = nil
account.Rules = make(map[string]*Rule, len(account.RulesG))
for _, rule := range account.RulesG {
account.Rules[rule.ID] = rule.Copy()
}
account.RulesG = nil
account.Routes = make(map[string]*route.Route, len(account.RoutesG))
for _, route := range account.RoutesG {
account.Routes[route.ID] = route.Copy()
}
account.RoutesG = nil
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG))
for _, ns := range account.NameServerGroupsG {
account.NameServerGroups[ns.ID] = ns.Copy()
}
account.NameServerGroupsG = nil
return &account, nil
}
func (s *SqliteStore) GetAccountByUser(userID string) (*Account, error) {
var user User
result := s.db.Select("account_id").First(&user, "id = ?", userID)
if result.Error != nil {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
if user.AccountID == "" {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
return s.GetAccount(user.AccountID)
}
func (s *SqliteStore) GetAccountByPeerID(peerID string) (*Account, error) {
var peer Peer
result := s.db.Select("account_id").First(&peer, "id = ?", peerID)
if result.Error != nil {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
if peer.AccountID == "" {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
return s.GetAccount(peer.AccountID)
}
func (s *SqliteStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
var peer Peer
result := s.db.Select("account_id").First(&peer, "key = ?", peerKey)
if result.Error != nil {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
if peer.AccountID == "" {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
return s.GetAccount(peer.AccountID)
}
// SaveUserLastLogin stores the last login time for a user in DB.
func (s *SqliteStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
var user User
result := s.db.First(&user, "account_id = ? and id = ?", accountID, userID)
if result.Error != nil {
return status.Errorf(status.NotFound, "user %s not found", userID)
}
user.LastLogin = lastLogin
return s.db.Save(user).Error
}
// Close is noop in Sqlite
func (s *SqliteStore) Close() error {
return nil
}
// GetStoreEngine returns SqliteStoreEngine
func (s *SqliteStore) GetStoreEngine() StoreEngine {
return SqliteStoreEngine
}

View File

@ -0,0 +1,98 @@
package store
import (
"fmt"
"os"
"strings"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/telemetry"
)
type Store interface {
GetAllAccounts() []*Account
GetAccount(accountID string) (*Account, error)
GetAccountByUser(userID string) (*Account, error)
GetAccountByPeerPubKey(peerKey string) (*Account, error)
GetAccountByPeerID(peerID string) (*Account, error)
GetAccountBySetupKey(setupKey string) (*Account, error) // todo use key hash later
GetAccountByPrivateDomain(domain string) (*Account, error)
GetTokenIDByHashedToken(secret string) (string, error)
GetUserByTokenID(tokenID string) (*User, error)
SaveAccount(account *Account) error
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
DeleteTokenID2UserIDIndex(tokenID string) error
GetInstallationID() string
SaveInstallationID(ID string) error
// AcquireAccountLock should attempt to acquire account lock and return a function that releases the lock
AcquireAccountLock(accountID string) func()
// AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock
AcquireGlobalLock() func()
SavePeerStatus(accountID, peerID string, status PeerStatus) error
SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error
// Close should close the store persisting all unsaved data.
Close() error
// GetStoreEngine should return StoreEngine of the current store implementation.
// This is also a method of metrics.DataSource interface.
GetStoreEngine() StoreEngine
}
type StoreEngine string
const (
FileStoreEngine StoreEngine = "jsonfile"
SqliteStoreEngine StoreEngine = "sqlite"
)
func getStoreEngineFromEnv() StoreEngine {
// NETBIRD_STORE_ENGINE supposed to be used in tests. Otherwise rely on the config file.
kind, ok := os.LookupEnv("NETBIRD_STORE_ENGINE")
if !ok {
return FileStoreEngine
}
value := StoreEngine(strings.ToLower(kind))
if value == FileStoreEngine || value == SqliteStoreEngine {
return value
}
return FileStoreEngine
}
func NewStore(kind StoreEngine, dataDir string, metrics telemetry.AppMetrics) (Store, error) {
if kind == "" {
// fallback to env. Normally this only should be used from tests
kind = getStoreEngineFromEnv()
}
switch kind {
case FileStoreEngine:
log.Info("using JSON file store engine")
return NewFileStore(dataDir, metrics)
case SqliteStoreEngine:
log.Info("using SQLite store engine")
return NewSqliteStore(dataDir, metrics)
default:
return nil, fmt.Errorf("unsupported kind of store %s", kind)
}
}
func NewStoreFromJson(dataDir string, metrics telemetry.AppMetrics) (Store, error) {
fstore, err := NewFileStore(dataDir, nil)
if err != nil {
return nil, err
}
kind := getStoreEngineFromEnv()
switch kind {
case FileStoreEngine:
return fstore, nil
case SqliteStoreEngine:
return NewSqliteStoreFromFileStore(fstore, dataDir, metrics)
default:
return nil, fmt.Errorf("unsupported store engine %s", kind)
}
}

View File

@ -0,0 +1,125 @@
package server
import (
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"fmt"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/proto"
)
// TURNCredentialsManager used to manage TURN credentials
type TURNCredentialsManager interface {
GenerateCredentials() TURNCredentials
SetupRefresh(peerKey string)
CancelRefresh(peerKey string)
}
// TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server
type TimeBasedAuthSecretsManager struct {
mux sync.Mutex
config *TURNConfig
updateManager *PeersUpdateManager
cancelMap map[string]chan struct{}
}
type TURNCredentials struct {
Username string
Password string
}
func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, config *TURNConfig) *TimeBasedAuthSecretsManager {
return &TimeBasedAuthSecretsManager{
mux: sync.Mutex{},
config: config,
updateManager: updateManager,
cancelMap: make(map[string]chan struct{}),
}
}
// GenerateCredentials generates new time-based secret credentials - basically username is a unix timestamp and password is a HMAC hash of a timestamp with a preshared TURN secret
func (m *TimeBasedAuthSecretsManager) GenerateCredentials() TURNCredentials {
mac := hmac.New(sha1.New, []byte(m.config.Secret))
timeAuth := time.Now().Add(m.config.CredentialsTTL.Duration).Unix()
username := fmt.Sprint(timeAuth)
_, err := mac.Write([]byte(username))
if err != nil {
log.Errorln("Generating turn password failed with error: ", err)
}
bytePassword := mac.Sum(nil)
password := base64.StdEncoding.EncodeToString(bytePassword)
return TURNCredentials{
Username: username,
Password: password,
}
}
func (m *TimeBasedAuthSecretsManager) cancel(peerID string) {
if channel, ok := m.cancelMap[peerID]; ok {
close(channel)
delete(m.cancelMap, peerID)
}
}
// CancelRefresh cancels scheduled peer credentials refresh
func (m *TimeBasedAuthSecretsManager) CancelRefresh(peerID string) {
m.mux.Lock()
defer m.mux.Unlock()
m.cancel(peerID)
}
// SetupRefresh starts peer credentials refresh. Since credentials are expiring (TTL) it is necessary to always generate them and send to the peer.
// A goroutine is created and put into TimeBasedAuthSecretsManager.cancelMap. This routine should be cancelled if peer is gone.
func (m *TimeBasedAuthSecretsManager) SetupRefresh(peerID string) {
m.mux.Lock()
defer m.mux.Unlock()
m.cancel(peerID)
cancel := make(chan struct{}, 1)
m.cancelMap[peerID] = cancel
log.Debugf("starting turn refresh for %s", peerID)
go func() {
// we don't want to regenerate credentials right on expiration, so we do it slightly before (at 3/4 of TTL)
ticker := time.NewTicker(m.config.CredentialsTTL.Duration / 4 * 3)
for {
select {
case <-cancel:
log.Debugf("stopping turn refresh for %s", peerID)
return
case <-ticker.C:
c := m.GenerateCredentials()
var turns []*proto.ProtectedHostConfig
for _, host := range m.config.Turns {
turns = append(turns, &proto.ProtectedHostConfig{
HostConfig: &proto.HostConfig{
Uri: host.URI,
Protocol: ToResponseProto(host.Proto),
},
User: c.Username,
Password: c.Password,
})
}
update := &proto.SyncResponse{
WiretrusteeConfig: &proto.WiretrusteeConfig{
Turns: turns,
},
}
log.Debugf("sending new TURN credentials to peer %s", peerID)
m.updateManager.SendUpdate(peerID, &UpdateMessage{Update: update})
}
}
}()
}

View File

@ -0,0 +1,95 @@
package users
import (
"crypto/sha256"
b64 "encoding/base64"
"fmt"
"hash/crc32"
"time"
b "github.com/hashicorp/go-secure-stdlib/base62"
"github.com/rs/xid"
"github.com/netbirdio/netbird/base62"
)
const (
// PATPrefix is the globally used, 4 char prefix for personal access tokens
PATPrefix = "nbp_"
// PATSecretLength number of characters used for the secret inside the token
PATSecretLength = 30
// PATChecksumLength number of characters used for the encoded checksum of the secret inside the token
PATChecksumLength = 6
// PATLength total number of characters used for the token
PATLength = 40
)
// PersonalAccessToken holds all information about a PAT including a hashed version of it for verification
type PersonalAccessToken struct {
ID string `gorm:"primaryKey"`
// User is a reference to Account that this object belongs
UserID string `gorm:"index"`
Name string
HashedToken string
ExpirationDate time.Time
// scope could be added in future
CreatedBy string
CreatedAt time.Time
LastUsed time.Time
}
func (t *PersonalAccessToken) Copy() *PersonalAccessToken {
return &PersonalAccessToken{
ID: t.ID,
Name: t.Name,
HashedToken: t.HashedToken,
ExpirationDate: t.ExpirationDate,
CreatedBy: t.CreatedBy,
CreatedAt: t.CreatedAt,
LastUsed: t.LastUsed,
}
}
// PersonalAccessTokenGenerated holds the new PersonalAccessToken and the plain text version of it
type PersonalAccessTokenGenerated struct {
PlainToken string
PersonalAccessToken
}
// CreateNewPAT will generate a new PersonalAccessToken that can be assigned to a User.
// Additionally, it will return the token in plain text once, to give to the user and only save a hashed version
func CreateNewPAT(name string, expirationInDays int, createdBy string) (*PersonalAccessTokenGenerated, error) {
hashedToken, plainToken, err := generateNewToken()
if err != nil {
return nil, err
}
currentTime := time.Now()
return &PersonalAccessTokenGenerated{
PersonalAccessToken: PersonalAccessToken{
ID: xid.New().String(),
Name: name,
HashedToken: hashedToken,
ExpirationDate: currentTime.AddDate(0, 0, expirationInDays),
CreatedBy: createdBy,
CreatedAt: currentTime,
LastUsed: time.Time{},
},
PlainToken: plainToken,
}, nil
}
func generateNewToken() (string, string, error) {
secret, err := b.Random(PATSecretLength)
if err != nil {
return "", "", err
}
checksum := crc32.ChecksumIEEE([]byte(secret))
encodedChecksum := base62.Encode(checksum)
paddedChecksum := fmt.Sprintf("%06s", encodedChecksum)
plainToken := PATPrefix + secret + paddedChecksum
hashedToken := sha256.Sum256([]byte(plainToken))
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
return encodedHashedToken, plainToken, nil
}

View File

@ -0,0 +1,203 @@
package users
import (
"fmt"
"strings"
"time"
"github.com/netbirdio/netbird/management/server/idp"
)
const (
UserRoleAdmin UserRole = "admin"
UserRoleUser UserRole = "user"
UserRoleUnknown UserRole = "unknown"
UserStatusActive UserStatus = "active"
UserStatusDisabled UserStatus = "disabled"
UserStatusInvited UserStatus = "invited"
UserIssuedAPI = "api"
UserIssuedIntegration = "integration"
)
type UserInfo struct {
ID string `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
Role string `json:"role"`
AutoGroups []string `json:"auto_groups"`
Status string `json:"-"`
IsServiceUser bool `json:"is_service_user"`
IsBlocked bool `json:"is_blocked"`
NonDeletable bool `json:"non_deletable"`
LastLogin time.Time `json:"last_login"`
Issued string `json:"issued"`
IntegrationReference IntegrationReference `json:"-"`
}
// StrRoleToUserRole returns UserRole for a given strRole or UserRoleUnknown if the specified role is unknown
func StrRoleToUserRole(strRole string) UserRole {
switch strings.ToLower(strRole) {
case "admin":
return UserRoleAdmin
case "user":
return UserRoleUser
default:
return UserRoleUnknown
}
}
// UserStatus is the status of a User
type UserStatus string
// UserRole is the role of a User
type UserRole string
// IntegrationReference holds the reference to a particular integration
type IntegrationReference struct {
ID int
IntegrationType string
}
func (ir IntegrationReference) String() string {
return fmt.Sprintf("%s:%d", ir.IntegrationType, ir.ID)
}
func (ir IntegrationReference) CacheKey(path ...string) string {
if len(path) == 0 {
return ir.String()
}
return fmt.Sprintf("%s:%s", ir.String(), strings.Join(path, ":"))
}
// User represents a user of the system
type User struct {
Id string `gorm:"primaryKey"`
// AccountID is a reference to Account that this object belongs
AccountID string `json:"-" gorm:"index"`
Role UserRole
IsServiceUser bool
// NonDeletable indicates whether the service user can be deleted
NonDeletable bool
// ServiceUserName is only set if IsServiceUser is true
ServiceUserName string
// AutoGroups is a list of Group IDs to auto-assign to peers registered by this user
AutoGroups []string `gorm:"serializer:json"`
PATs map[string]*PersonalAccessToken `gorm:"-"`
PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id"`
// Blocked indicates whether the user is blocked. Blocked users can't use the system.
Blocked bool
// LastLogin is the last time the user logged in to IdP
LastLogin time.Time
// Issued of the user
Issued string `gorm:"default:api"`
IntegrationReference IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"`
}
// IsBlocked returns true if the user is blocked, false otherwise
func (u *User) IsBlocked() bool {
return u.Blocked
}
func (u *User) LastDashboardLoginChanged(LastLogin time.Time) bool {
return LastLogin.After(u.LastLogin) && !u.LastLogin.IsZero()
}
// IsAdmin returns true if the user is an admin, false otherwise
func (u *User) IsAdmin() bool {
return u.Role == UserRoleAdmin
}
// ToUserInfo converts a User object to a UserInfo object.
func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) {
autoGroups := u.AutoGroups
if autoGroups == nil {
autoGroups = []string{}
}
if userData == nil {
return &UserInfo{
ID: u.Id,
Email: "",
Name: u.ServiceUserName,
Role: string(u.Role),
AutoGroups: u.AutoGroups,
Status: string(UserStatusActive),
IsServiceUser: u.IsServiceUser,
IsBlocked: u.Blocked,
LastLogin: u.LastLogin,
Issued: u.Issued,
}, nil
}
if userData.ID != u.Id {
return nil, fmt.Errorf("wrong UserData provided for user %s", u.Id)
}
userStatus := UserStatusActive
if userData.AppMetadata.WTPendingInvite != nil && *userData.AppMetadata.WTPendingInvite {
userStatus = UserStatusInvited
}
return &UserInfo{
ID: u.Id,
Email: userData.Email,
Name: userData.Name,
Role: string(u.Role),
AutoGroups: autoGroups,
Status: string(userStatus),
IsServiceUser: u.IsServiceUser,
IsBlocked: u.Blocked,
LastLogin: u.LastLogin,
Issued: u.Issued,
}, nil
}
// Copy the user
func (u *User) Copy() *User {
autoGroups := make([]string, len(u.AutoGroups))
copy(autoGroups, u.AutoGroups)
pats := make(map[string]*PersonalAccessToken, len(u.PATs))
for k, v := range u.PATs {
pats[k] = v.Copy()
}
return &User{
Id: u.Id,
AccountID: u.AccountID,
Role: u.Role,
AutoGroups: autoGroups,
IsServiceUser: u.IsServiceUser,
NonDeletable: u.NonDeletable,
ServiceUserName: u.ServiceUserName,
PATs: pats,
Blocked: u.Blocked,
LastLogin: u.LastLogin,
Issued: u.Issued,
IntegrationReference: u.IntegrationReference,
}
}
// NewUser creates a new user
func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string) *User {
return &User{
Id: id,
Role: role,
IsServiceUser: isServiceUser,
NonDeletable: nonDeletable,
ServiceUserName: serviceUserName,
AutoGroups: autoGroups,
Issued: issued,
}
}
// NewRegularUser creates a new user with role UserRoleUser
func NewRegularUser(id string) *User {
return NewUser(id, UserRoleUser, false, false, "", []string{}, UserIssuedAPI)
}
// NewAdminUser creates a new user with role UserRoleAdmin
func NewAdminUser(id string) *User {
return NewUser(id, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI)
}

View File

@ -0,0 +1,19 @@
package users
type UserManager interface {
GetUser(userID string) (User, error)
}
type DefaultUserManager struct {
repository UserRepository
}
func NewUserManager(repository UserRepository) *DefaultUserManager {
return &DefaultUserManager{
repository: repository,
}
}
func (um *DefaultUserManager) GetUser(userID string) (User, error) {
return um.repository.findUserByID(userID)
}

View File

@ -0,0 +1,5 @@
package users
type UserRepository interface {
findUserByID(userID string) (User, error)
}