diff --git a/management/refactor/api/grpc/grpcserver.go b/management/refactor/api/grpc/grpcserver.go new file mode 100644 index 000000000..341d202b6 --- /dev/null +++ b/management/refactor/api/grpc/grpcserver.go @@ -0,0 +1,641 @@ +package server + +import ( + "context" + "fmt" + "net" + "net/netip" + "strings" + "time" + + pb "github.com/golang/protobuf/proto" // nolint + "github.com/golang/protobuf/ptypes/timestamp" + "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip" + log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/netbirdio/netbird/encryption" + "github.com/netbirdio/netbird/management/proto" + "github.com/netbirdio/netbird/management/server/jwtclaims" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + internalStatus "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/telemetry" +) + +// GRPCServer an instance of a Management gRPC API server +type GRPCServer struct { + accountManager AccountManager + wgKey wgtypes.Key + proto.UnimplementedManagementServiceServer + peersUpdateManager *PeersUpdateManager + config *Config + turnCredentialsManager TURNCredentialsManager + jwtValidator *jwtclaims.JWTValidator + jwtClaimsExtractor *jwtclaims.ClaimsExtractor + appMetrics telemetry.AppMetrics + ephemeralManager *EphemeralManager +} + +// NewServer creates a new Management server +func NewServer(config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager, turnCredentialsManager TURNCredentialsManager, appMetrics telemetry.AppMetrics, ephemeralManager *EphemeralManager) (*GRPCServer, error) { + key, err := wgtypes.GeneratePrivateKey() + if err != nil { + return nil, err + } + + var jwtValidator *jwtclaims.JWTValidator + + if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) { + jwtValidator, err = jwtclaims.NewJWTValidator( + config.HttpConfig.AuthIssuer, + config.GetAuthAudiences(), + config.HttpConfig.AuthKeysLocation, + config.HttpConfig.IdpSignKeyRefreshEnabled, + ) + if err != nil { + return nil, status.Errorf(codes.Internal, "unable to create new jwt middleware, err: %v", err) + } + } else { + log.Debug("unable to use http config to create new jwt middleware") + } + + if appMetrics != nil { + // update gauge based on number of connected peers which is equal to open gRPC streams + err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 { + return int64(len(peersUpdateManager.peerChannels)) + }) + if err != nil { + return nil, err + } + } + + var audience, userIDClaim string + if config.HttpConfig != nil { + audience = config.HttpConfig.AuthAudience + userIDClaim = config.HttpConfig.AuthUserIDClaim + } + jwtClaimsExtractor := jwtclaims.NewClaimsExtractor( + jwtclaims.WithAudience(audience), + jwtclaims.WithUserIDClaim(userIDClaim), + ) + + return &GRPCServer{ + wgKey: key, + // peerKey -> event channel + peersUpdateManager: peersUpdateManager, + accountManager: accountManager, + config: config, + turnCredentialsManager: turnCredentialsManager, + jwtValidator: jwtValidator, + jwtClaimsExtractor: jwtClaimsExtractor, + appMetrics: appMetrics, + ephemeralManager: ephemeralManager, + }, nil +} + +func (s *GRPCServer) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) { + // todo introduce something more meaningful with the key expiration/rotation + if s.appMetrics != nil { + s.appMetrics.GRPCMetrics().CountGetKeyRequest() + } + now := time.Now().Add(24 * time.Hour) + secs := int64(now.Second()) + nanos := int32(now.Nanosecond()) + expiresAt := ×tamp.Timestamp{Seconds: secs, Nanos: nanos} + + return &proto.ServerKeyResponse{ + Key: s.wgKey.PublicKey().String(), + ExpiresAt: expiresAt, + }, nil +} + +func getRealIP(ctx context.Context) net.IP { + if addr, ok := realip.FromContext(ctx); ok { + return net.IP(addr.AsSlice()) + } + return nil +} + +// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and +// notifies the connected peer of any updates (e.g. new peers under the same account) +func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error { + reqStart := time.Now() + if s.appMetrics != nil { + s.appMetrics.GRPCMetrics().CountSyncRequest() + } + realIP := getRealIP(srv.Context()) + log.Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, realIP.String()) + + syncReq := &proto.SyncRequest{} + peerKey, err := s.parseRequest(req, syncReq) + if err != nil { + return err + } + + peer, netMap, err := s.accountManager.SyncPeer(PeerSync{WireGuardPubKey: peerKey.String()}) + if err != nil { + return mapError(err) + } + + err = s.sendInitialSync(peerKey, peer, netMap, srv) + if err != nil { + log.Debugf("error while sending initial sync for %s: %v", peerKey.String(), err) + return err + } + + updates := s.peersUpdateManager.CreateChannel(peer.ID) + + s.ephemeralManager.OnPeerConnected(peer) + + err = s.accountManager.MarkPeerConnected(peerKey.String(), true, realIP) + if err != nil { + log.Warnf("failed marking peer as connected %s %v", peerKey, err) + } + + if s.config.TURNConfig.TimeBasedCredentials { + s.turnCredentialsManager.SetupRefresh(peer.ID) + } + + if s.appMetrics != nil { + s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart)) + } + + // keep a connection to the peer and send updates when available + for { + select { + // condition when there are some updates + case update, open := <-updates: + + if s.appMetrics != nil { + s.appMetrics.GRPCMetrics().UpdateChannelQueueLength(len(updates) + 1) + } + + if !open { + log.Debugf("updates channel for peer %s was closed", peerKey.String()) + s.cancelPeerRoutines(peer) + return nil + } + log.Debugf("received an update for peer %s", peerKey.String()) + + encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update) + if err != nil { + s.cancelPeerRoutines(peer) + return status.Errorf(codes.Internal, "failed processing update message") + } + + err = srv.SendMsg(&proto.EncryptedMessage{ + WgPubKey: s.wgKey.PublicKey().String(), + Body: encryptedResp, + }) + if err != nil { + s.cancelPeerRoutines(peer) + return status.Errorf(codes.Internal, "failed sending update message") + } + log.Debugf("sent an update to peer %s", peerKey.String()) + // condition when client <-> server connection has been terminated + case <-srv.Context().Done(): + // happens when connection drops, e.g. client disconnects + log.Debugf("stream of peer %s has been closed", peerKey.String()) + s.cancelPeerRoutines(peer) + return srv.Context().Err() + } + } +} + +func (s *GRPCServer) cancelPeerRoutines(peer *nbpeer.Peer) { + s.peersUpdateManager.CloseChannel(peer.ID) + s.turnCredentialsManager.CancelRefresh(peer.ID) + _ = s.accountManager.MarkPeerConnected(peer.Key, false, nil) + s.ephemeralManager.OnPeerDisconnected(peer) +} + +func (s *GRPCServer) validateToken(jwtToken string) (string, error) { + if s.jwtValidator == nil { + return "", status.Error(codes.Internal, "no jwt validator set") + } + + token, err := s.jwtValidator.ValidateAndParse(jwtToken) + if err != nil { + return "", status.Errorf(codes.InvalidArgument, "invalid jwt token, err: %v", err) + } + claims := s.jwtClaimsExtractor.FromToken(token) + // we need to call this method because if user is new, we will automatically add it to existing or create a new account + _, _, err = s.accountManager.GetAccountFromToken(claims) + if err != nil { + return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err) + } + + if err := s.accountManager.CheckUserAccessByJWTGroups(claims); err != nil { + return "", status.Errorf(codes.PermissionDenied, err.Error()) + } + + return claims.UserId, nil +} + +// maps internal internalStatus.Error to gRPC status.Error +func mapError(err error) error { + if e, ok := internalStatus.FromError(err); ok { + switch e.Type() { + case internalStatus.PermissionDenied: + return status.Errorf(codes.PermissionDenied, e.Message) + case internalStatus.Unauthorized: + return status.Errorf(codes.PermissionDenied, e.Message) + case internalStatus.Unauthenticated: + return status.Errorf(codes.PermissionDenied, e.Message) + case internalStatus.PreconditionFailed: + return status.Errorf(codes.FailedPrecondition, e.Message) + case internalStatus.NotFound: + return status.Errorf(codes.NotFound, e.Message) + default: + } + } + log.Errorf("got an unhandled error: %s", err) + return status.Errorf(codes.Internal, "failed handling request") +} + +func extractPeerMeta(loginReq *proto.LoginRequest) nbpeer.PeerSystemMeta { + osVersion := loginReq.GetMeta().GetOSVersion() + if osVersion == "" { + osVersion = loginReq.GetMeta().GetCore() + } + + networkAddresses := make([]nbpeer.NetworkAddress, 0, len(loginReq.GetMeta().GetNetworkAddresses())) + for _, addr := range loginReq.GetMeta().GetNetworkAddresses() { + netAddr, err := netip.ParsePrefix(addr.GetNetIP()) + if err != nil { + log.Warnf("failed to parse netip address, %s: %v", addr.GetNetIP(), err) + continue + } + networkAddresses = append(networkAddresses, nbpeer.NetworkAddress{ + NetIP: netAddr, + Mac: addr.GetMac(), + }) + } + + return nbpeer.PeerSystemMeta{ + Hostname: loginReq.GetMeta().GetHostname(), + GoOS: loginReq.GetMeta().GetGoOS(), + Kernel: loginReq.GetMeta().GetKernel(), + Platform: loginReq.GetMeta().GetPlatform(), + OS: loginReq.GetMeta().GetOS(), + OSVersion: osVersion, + WtVersion: loginReq.GetMeta().GetWiretrusteeVersion(), + UIVersion: loginReq.GetMeta().GetUiVersion(), + KernelVersion: loginReq.GetMeta().GetKernelVersion(), + NetworkAddresses: networkAddresses, + SystemSerialNumber: loginReq.GetMeta().GetSysSerialNumber(), + SystemProductName: loginReq.GetMeta().GetSysProductName(), + SystemManufacturer: loginReq.GetMeta().GetSysManufacturer(), + Environment: nbpeer.Environment{ + Cloud: loginReq.GetMeta().GetEnvironment().GetCloud(), + Platform: loginReq.GetMeta().GetEnvironment().GetPlatform(), + }, + } +} + +func (s *GRPCServer) parseRequest(req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) { + peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) + if err != nil { + log.Warnf("error while parsing peer's WireGuard public key %s.", req.WgPubKey) + return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", req.WgPubKey) + } + + err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, parsed) + if err != nil { + return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "invalid request message") + } + + return peerKey, nil +} + +// Login endpoint first checks whether peer is registered under any account +// In case it is, the login is successful +// In case it isn't, the endpoint checks whether setup key is provided within the request and tries to register a peer. +// In case of the successful registration login is also successful +func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { + reqStart := time.Now() + defer func() { + if s.appMetrics != nil { + s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart)) + } + }() + if s.appMetrics != nil { + s.appMetrics.GRPCMetrics().CountLoginRequest() + } + realIP := getRealIP(ctx) + log.Debugf("Login request from peer [%s] [%s]", req.WgPubKey, realIP.String()) + + loginReq := &proto.LoginRequest{} + peerKey, err := s.parseRequest(req, loginReq) + if err != nil { + return nil, err + } + + if loginReq.GetMeta() == nil { + msg := status.Errorf(codes.FailedPrecondition, + "peer system meta has to be provided to log in. Peer %s, remote addr %s", peerKey.String(), realIP) + log.Warn(msg) + return nil, msg + } + + userID := "" + // JWT token is not always provided, it is fine for userID to be empty cuz it might be that peer is already registered, + // or it uses a setup key to register. + if loginReq.GetJwtToken() != "" { + userID, err = s.validateToken(loginReq.GetJwtToken()) + if err != nil { + log.Warnf("failed validating JWT token sent from peer %s", peerKey) + return nil, err + } + } + var sshKey []byte + if loginReq.GetPeerKeys() != nil { + sshKey = loginReq.GetPeerKeys().GetSshPubKey() + } + + peer, netMap, err := s.accountManager.LoginPeer(PeerLogin{ + WireGuardPubKey: peerKey.String(), + SSHKey: string(sshKey), + Meta: extractPeerMeta(loginReq), + UserID: userID, + SetupKey: loginReq.GetSetupKey(), + }) + + if err != nil { + log.Warnf("failed logging in peer %s", peerKey) + return nil, mapError(err) + } + + // if the login request contains setup key then it is a registration request + if loginReq.GetSetupKey() != "" { + s.ephemeralManager.OnPeerDisconnected(peer) + } + + // if peer has reached this point then it has logged in + loginResp := &proto.LoginResponse{ + WiretrusteeConfig: toWiretrusteeConfig(s.config, nil), + PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain()), + } + encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp) + if err != nil { + log.Warnf("failed encrypting peer %s message", peer.ID) + return nil, status.Errorf(codes.Internal, "failed logging in peer") + } + + return &proto.EncryptedMessage{ + WgPubKey: s.wgKey.PublicKey().String(), + Body: encryptedResp, + }, nil +} + +func ToResponseProto(configProto Protocol) proto.HostConfig_Protocol { + switch configProto { + case UDP: + return proto.HostConfig_UDP + case DTLS: + return proto.HostConfig_DTLS + case HTTP: + return proto.HostConfig_HTTP + case HTTPS: + return proto.HostConfig_HTTPS + case TCP: + return proto.HostConfig_TCP + default: + panic(fmt.Errorf("unexpected config protocol type %v", configProto)) + } +} + +func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *proto.WiretrusteeConfig { + if config == nil { + return nil + } + var stuns []*proto.HostConfig + for _, stun := range config.Stuns { + stuns = append(stuns, &proto.HostConfig{ + Uri: stun.URI, + Protocol: ToResponseProto(stun.Proto), + }) + } + var turns []*proto.ProtectedHostConfig + for _, turn := range config.TURNConfig.Turns { + var username string + var password string + if turnCredentials != nil { + username = turnCredentials.Username + password = turnCredentials.Password + } else { + username = turn.Username + password = turn.Password + } + turns = append(turns, &proto.ProtectedHostConfig{ + HostConfig: &proto.HostConfig{ + Uri: turn.URI, + Protocol: ToResponseProto(turn.Proto), + }, + User: username, + Password: password, + }) + } + + return &proto.WiretrusteeConfig{ + Stuns: stuns, + Turns: turns, + Signal: &proto.HostConfig{ + Uri: config.Signal.URI, + Protocol: ToResponseProto(config.Signal.Proto), + }, + } +} + +func toPeerConfig(peer *nbpeer.Peer, network *Network, dnsName string) *proto.PeerConfig { + netmask, _ := network.Net.Mask.Size() + fqdn := peer.FQDN(dnsName) + return &proto.PeerConfig{ + Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask), // take it from the network + SshConfig: &proto.SSHConfig{SshEnabled: peer.SSHEnabled}, + Fqdn: fqdn, + } +} + +func toRemotePeerConfig(peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig { + remotePeers := []*proto.RemotePeerConfig{} + for _, rPeer := range peers { + fqdn := rPeer.FQDN(dnsName) + remotePeers = append(remotePeers, &proto.RemotePeerConfig{ + WgPubKey: rPeer.Key, + AllowedIps: []string{fmt.Sprintf(AllowedIPsFormat, rPeer.IP)}, + SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)}, + Fqdn: fqdn, + }) + } + return remotePeers +} + +func toSyncResponse(config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string) *proto.SyncResponse { + wtConfig := toWiretrusteeConfig(config, turnCredentials) + + pConfig := toPeerConfig(peer, networkMap.Network, dnsName) + + remotePeers := toRemotePeerConfig(networkMap.Peers, dnsName) + + routesUpdate := toProtocolRoutes(networkMap.Routes) + + dnsUpdate := toProtocolDNSConfig(networkMap.DNSConfig) + + offlinePeers := toRemotePeerConfig(networkMap.OfflinePeers, dnsName) + + firewallRules := toProtocolFirewallRules(networkMap.FirewallRules) + + return &proto.SyncResponse{ + WiretrusteeConfig: wtConfig, + PeerConfig: pConfig, + RemotePeers: remotePeers, + RemotePeersIsEmpty: len(remotePeers) == 0, + NetworkMap: &proto.NetworkMap{ + Serial: networkMap.Network.CurrentSerial(), + PeerConfig: pConfig, + RemotePeers: remotePeers, + OfflinePeers: offlinePeers, + RemotePeersIsEmpty: len(remotePeers) == 0, + Routes: routesUpdate, + DNSConfig: dnsUpdate, + FirewallRules: firewallRules, + FirewallRulesIsEmpty: len(firewallRules) == 0, + }, + } +} + +// IsHealthy indicates whether the service is healthy +func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Empty, error) { + return &proto.Empty{}, nil +} + +// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization +func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *NetworkMap, srv proto.ManagementService_SyncServer) error { + // make secret time based TURN credentials optional + var turnCredentials *TURNCredentials + if s.config.TURNConfig.TimeBasedCredentials { + creds := s.turnCredentialsManager.GenerateCredentials() + turnCredentials = &creds + } else { + turnCredentials = nil + } + plainResp := toSyncResponse(s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain()) + + encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp) + if err != nil { + return status.Errorf(codes.Internal, "error handling request") + } + + err = srv.Send(&proto.EncryptedMessage{ + WgPubKey: s.wgKey.PublicKey().String(), + Body: encryptedResp, + }) + + if err != nil { + log.Errorf("failed sending SyncResponse %v", err) + return status.Errorf(codes.Internal, "error handling request") + } + + return nil +} + +// GetDeviceAuthorizationFlow returns a device authorization flow information +// This is used for initiating an Oauth 2 device authorization grant flow +// which will be used by our clients to Login +func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { + peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) + if err != nil { + errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetDeviceAuthorizationFlow request.", req.WgPubKey) + log.Warn(errMSG) + return nil, status.Error(codes.InvalidArgument, errMSG) + } + + err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.DeviceAuthorizationFlowRequest{}) + if err != nil { + errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey) + log.Warn(errMSG) + return nil, status.Error(codes.InvalidArgument, errMSG) + } + + if s.config.DeviceAuthorizationFlow == nil || s.config.DeviceAuthorizationFlow.Provider == string(NONE) { + return nil, status.Error(codes.NotFound, "no device authorization flow information available") + } + + provider, ok := proto.DeviceAuthorizationFlowProvider_value[strings.ToUpper(s.config.DeviceAuthorizationFlow.Provider)] + if !ok { + return nil, status.Errorf(codes.InvalidArgument, "no provider found in the protocol for %s", s.config.DeviceAuthorizationFlow.Provider) + } + + flowInfoResp := &proto.DeviceAuthorizationFlow{ + Provider: proto.DeviceAuthorizationFlowProvider(provider), + ProviderConfig: &proto.ProviderConfig{ + ClientID: s.config.DeviceAuthorizationFlow.ProviderConfig.ClientID, + ClientSecret: s.config.DeviceAuthorizationFlow.ProviderConfig.ClientSecret, + Domain: s.config.DeviceAuthorizationFlow.ProviderConfig.Domain, + Audience: s.config.DeviceAuthorizationFlow.ProviderConfig.Audience, + DeviceAuthEndpoint: s.config.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint, + TokenEndpoint: s.config.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint, + Scope: s.config.DeviceAuthorizationFlow.ProviderConfig.Scope, + UseIDToken: s.config.DeviceAuthorizationFlow.ProviderConfig.UseIDToken, + }, + } + + encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp) + if err != nil { + return nil, status.Error(codes.Internal, "failed to encrypt no device authorization flow information") + } + + return &proto.EncryptedMessage{ + WgPubKey: s.wgKey.PublicKey().String(), + Body: encryptedResp, + }, nil +} + +// GetPKCEAuthorizationFlow returns a pkce authorization flow information +// This is used for initiating an Oauth 2 pkce authorization grant flow +// which will be used by our clients to Login +func (s *GRPCServer) GetPKCEAuthorizationFlow(_ context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { + peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) + if err != nil { + errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetPKCEAuthorizationFlow request.", req.WgPubKey) + log.Warn(errMSG) + return nil, status.Error(codes.InvalidArgument, errMSG) + } + + err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.PKCEAuthorizationFlowRequest{}) + if err != nil { + errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey) + log.Warn(errMSG) + return nil, status.Error(codes.InvalidArgument, errMSG) + } + + if s.config.PKCEAuthorizationFlow == nil { + return nil, status.Error(codes.NotFound, "no pkce authorization flow information available") + } + + flowInfoResp := &proto.PKCEAuthorizationFlow{ + ProviderConfig: &proto.ProviderConfig{ + Audience: s.config.PKCEAuthorizationFlow.ProviderConfig.Audience, + ClientID: s.config.PKCEAuthorizationFlow.ProviderConfig.ClientID, + ClientSecret: s.config.PKCEAuthorizationFlow.ProviderConfig.ClientSecret, + TokenEndpoint: s.config.PKCEAuthorizationFlow.ProviderConfig.TokenEndpoint, + AuthorizationEndpoint: s.config.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint, + Scope: s.config.PKCEAuthorizationFlow.ProviderConfig.Scope, + RedirectURLs: s.config.PKCEAuthorizationFlow.ProviderConfig.RedirectURLs, + UseIDToken: s.config.PKCEAuthorizationFlow.ProviderConfig.UseIDToken, + }, + } + + encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp) + if err != nil { + return nil, status.Error(codes.Internal, "failed to encrypt no pkce authorization flow information") + } + + return &proto.EncryptedMessage{ + WgPubKey: s.wgKey.PublicKey().String(), + Body: encryptedResp, + }, nil +} diff --git a/management/refactor/api/http/specs/api.yaml b/management/refactor/api/http/specs/api.yaml new file mode 100644 index 000000000..e69de29bb diff --git a/management/refactor/mesh/controller.go b/management/refactor/mesh/controller.go new file mode 100644 index 000000000..751097ca1 --- /dev/null +++ b/management/refactor/mesh/controller.go @@ -0,0 +1,157 @@ +package mesh + +import ( + "github.com/netbirdio/management-integrations/integrations" + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/refactor/peers" + "github.com/netbirdio/netbird/management/refactor/policies" + "github.com/netbirdio/netbird/management/refactor/settings" + "github.com/netbirdio/netbird/management/refactor/store" + "github.com/netbirdio/netbird/management/refactor/users" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/status" +) + +type Controller interface { + LoginPeer() + SyncPeer() +} + +type DefaultController struct { + store store.Store + peersManager peers.Manager + userManager users.Manager + policiesManager policies.Manager + settingsManager settings.Manager +} + +func NewDefaultController() *DefaultController { + storeStore, _ := store.NewDefaultStore(store.SqliteStoreEngine, "", nil) + peersManager := peers.NewDefaultManager(storeStore, nil) + settingsManager := settings.NewDefaultManager(storeStore) + usersManager := users.NewDefaultManager(storeStore, peersManager) + policiesManager := policies.NewDefaultManager(storeStore, peersManager) + + peersManager, settingsManager, usersManager, policiesManager, storeStore = integrations.InjectCloud(peersManager, policiesManager, settingsManager, usersManager, storeStore) + + return &DefaultController{ + store: storeStore, + peersManager: peersManager, + userManager: usersManager, + policiesManager: policiesManager, + settingsManager: settingsManager, + } +} + +func (c *DefaultController) LoginPeer(login peers.PeerLogin) { + + peer, err := c.peersManager.GetPeerByPubKey(login.WireGuardPubKey) + if err != nil { + return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered") + } + + if peer.AddedWithSSOLogin() { + user, err := c.userManager.GetUser(peer.GetUserID()) + if err != nil { + return nil, nil, err + } + if user.IsBlocked() { + return nil, nil, status.Errorf(status.PermissionDenied, "user is blocked") + } + } + + account, err := pm.accountManager.GetAccount(peer.AccountID) + if err != nil { + return nil, nil, err + } + + // this flag prevents unnecessary calls to the persistent store. + shouldStorePeer := false + updateRemotePeers := false + if peerLoginExpired(peer, account) { + err = checkAuth(login.UserID, peer) + if err != nil { + return nil, nil, err + } + // If peer was expired before and if it reached this point, it is re-authenticated. + // UserID is present, meaning that JWT validation passed successfully in the API layer. + peer.UpdateLastLogin() + updateRemotePeers = true + shouldStorePeer = true + + pm.eventsManager.StoreEvent(login.UserID, peer.ID, account.Id, activity.UserLoggedInPeer, peer.EventMeta(pm.accountManager.GetDNSDomain())) + } + + if peer.UpdateMetaIfNew(login.Meta) { + shouldStorePeer = true + } + + if peer.CheckAndUpdatePeerSSHKey(login.SSHKey) { + shouldStorePeer = true + } + + if shouldStorePeer { + err := pm.repository.updatePeer(peer) + if err != nil { + return nil, nil, err + } + } + + if updateRemotePeers { + am.updateAccountPeers(account) + } + return peer, account.GetPeerNetworkMap(peer.ID, pm.accountManager.GetDNSDomain()), nil +} + +func (c *DefaultController) SyncPeer() { + +} + +func (c *DefaultController) GetPeerNetworkMap(peerID, dnsDomain string) *NetworkMap { + peer, err := c.peersManager.GetNetworkPeerByID(peerID) + if err != nil { + return &NetworkMap{ + Network: a.Network.Copy(), + } + } + + aclPeers, firewallRules := c.policiesManager.GetAccessiblePeersAndFirewallRules(peerID) + // exclude expired peers + var peersToConnect []*peers.Peer + var expiredPeers []*peers.Peer + accSettings, _ := c.settingsManager.GetSettings(peer.GetAccountID()) + for _, p := range aclPeers { + expired, _ := p.LoginExpired(accSettings.GetPeerLoginExpiration()) + if accSettings.GetPeerLoginExpirationEnabled() && expired { + expiredPeers = append(expiredPeers, p) + continue + } + peersToConnect = append(peersToConnect, p) + } + + routesUpdate := a.getRoutesToSync(peerID, peersToConnect) + + dnsManagementStatus := a.getPeerDNSManagementStatus(peerID) + dnsUpdate := nbdns.Config{ + ServiceEnable: dnsManagementStatus, + } + + if dnsManagementStatus { + var zones []nbdns.CustomZone + peersCustomZone := getPeersCustomZone(a, dnsDomain) + if peersCustomZone.Domain != "" { + zones = append(zones, peersCustomZone) + } + dnsUpdate.CustomZones = zones + dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID) + } + + return &NetworkMap{ + Peers: peersToConnect, + Network: a.Network.Copy(), + Routes: routesUpdate, + DNSConfig: dnsUpdate, + OfflinePeers: expiredPeers, + FirewallRules: firewallRules, + } +} diff --git a/management/refactor/mesh/network.go b/management/refactor/mesh/network.go new file mode 100644 index 000000000..a7fe69806 --- /dev/null +++ b/management/refactor/mesh/network.go @@ -0,0 +1,15 @@ +package mesh + +import ( + nbdns "github.com/netbirdio/netbird/dns" + nbpeer "github.com/netbirdio/netbird/management/server/peer" +) + +type NetworkMap struct { + Peers []*nbpeer.Peer + Network *Network + Routes []*route.Route + DNSConfig nbdns.Config + OfflinePeers []*nbpeer.Peer + FirewallRules []*FirewallRule +} diff --git a/management/refactor/peers/manager.go b/management/refactor/peers/manager.go new file mode 100644 index 000000000..2a641d122 --- /dev/null +++ b/management/refactor/peers/manager.go @@ -0,0 +1,50 @@ +package peers + +import ( + "github.com/netbirdio/netbird/management/refactor/settings" +) + +type Manager interface { + GetPeerByPubKey(pubKey string) (Peer, error) + GetPeerByID(id string) (Peer, error) + GetNetworkPeerByID(id string) (Peer, error) + GetNetworkPeersInAccount(id string) ([]Peer, error) +} + +type DefaultManager struct { + repository repository + settingsManager settings.Manager +} + +func NewDefaultManager(repository repository, settingsManager settings.Manager) *DefaultManager { + return &DefaultManager{ + repository: repository, + settingsManager: settingsManager, + } +} + +func (dm *DefaultManager) GetNetworkPeerByID(id string) (Peer, error) { + return dm.repository.FindPeerByID(id) +} + +func (dm *DefaultManager) GetNetworkPeersInAccount(id string) ([]Peer, error) { + defaultPeers, err := dm.repository.FindAllPeersInAccount(id) + if err != nil { + return nil, err + } + + peers := make([]Peer, len(defaultPeers)) + for _, dp := range defaultPeers { + peers = append(peers, dp) + } + + return peers, nil +} + +func (dm *DefaultManager) GetPeerByPubKey(pubKey string) (Peer, error) { + return dm.repository.FindPeerByPubKey(pubKey) +} + +func (dm *DefaultManager) GetPeerByID(id string) (Peer, error) { + return dm.repository.FindPeerByID(id) +} diff --git a/management/refactor/peers/peer.go b/management/refactor/peers/peer.go new file mode 100644 index 000000000..82b6fa15e --- /dev/null +++ b/management/refactor/peers/peer.go @@ -0,0 +1,244 @@ +package peers + +import ( + "fmt" + "net" + "time" +) + +type Peer interface { + GetID() string + SetID(string) + GetAccountID() string + SetAccountID(string) + GetKey() string + SetKey(string) + GetSetupKey() string + SetSetupKey(string) + GetIP() net.IP + SetIP(net.IP) + GetName() string + SetName(string) + GetDNSLabel() string + SetDNSLabel(string) + GetUserID() string + SetUserID(string) + GetSSHKey() string + SetSSHKey(string) + GetSSHEnabled() bool + SetSSHEnabled(bool) + AddedWithSSOLogin() bool + UpdateMetaIfNew(meta PeerSystemMeta) bool + MarkLoginExpired(expired bool) + FQDN(dnsDomain string) string + EventMeta(dnsDomain string) map[string]any + LoginExpired(expiresIn time.Duration) (bool, time.Duration) +} + +// Peer represents a machine connected to the network. +// The Peer is a WireGuard peer identified by a public key +type DefaultPeer struct { + // ID is an internal ID of the peer + ID string `gorm:"primaryKey"` + // AccountID is a reference to Account that this object belongs + AccountID string `json:"-" gorm:"index;uniqueIndex:idx_peers_account_id_ip"` + // WireGuard public key + Key string `gorm:"index"` + // A setup key this peer was registered with + SetupKey string + // IP address of the Peer + IP net.IP `gorm:"uniqueIndex:idx_peers_account_id_ip"` + // Name is peer's name (machine name) + Name string + // DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's + // domain to the peer label. e.g. peer-dns-label.netbird.cloud + DNSLabel string + // The user ID that registered the peer + UserID string + // SSHKey is a public SSH key of the peer + SSHKey string + // SSHEnabled indicates whether SSH server is enabled on the peer + SSHEnabled bool + // LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login. + // Works with LastLogin + LoginExpirationEnabled bool + // LastLogin the time when peer performed last login operation + LastLogin time.Time + // Indicate ephemeral peer attribute + Ephemeral bool +} + +// PeerLogin used as a data object between the gRPC API and AccountManager on Login request. +type PeerLogin struct { + // WireGuardPubKey is a peers WireGuard public key + WireGuardPubKey string + // SSHKey is a peer's ssh key. Can be empty (e.g., old version do not provide it, or this feature is disabled) + SSHKey string + // Meta is the system information passed by peer, must be always present. + Meta PeerSystemMeta + // UserID indicates that JWT was used to log in, and it was valid. Can be empty when SetupKey is used or auth is not required. + UserID string + // AccountID indicates that JWT was used to log in, and it was valid. Can be empty when SetupKey is used or auth is not required. + AccountID string + // SetupKey references to a server.SetupKey to log in. Can be empty when UserID is used or auth is not required. + SetupKey string +} + +// AddedWithSSOLogin indicates whether this peer has been added with an SSO login by a user. +func (p *DefaultPeer) AddedWithSSOLogin() bool { + return p.UserID != "" +} + +// UpdateMetaIfNew updates peer's system metadata if new information is provided +// returns true if meta was updated, false otherwise +func (p *DefaultPeer) UpdateMetaIfNew(meta PeerSystemMeta) bool { + // Avoid overwriting UIVersion if the update was triggered sole by the CLI client + if meta.UIVersion == "" { + meta.UIVersion = p.Meta.UIVersion + } + + if p.Meta.isEqual(meta) { + return false + } + p.Meta = meta + return true +} + +// MarkLoginExpired marks peer's status expired or not +func (p *DefaultPeer) MarkLoginExpired(expired bool) { + newStatus := p.Status.Copy() + newStatus.LoginExpired = expired + if expired { + newStatus.Connected = false + } + p.Status = newStatus +} + +// LoginExpired indicates whether the peer's login has expired or not. +// If Peer.LastLogin plus the expiresIn duration has happened already; then login has expired. +// Return true if a login has expired, false otherwise, and time left to expiration (negative when expired). +// Login expiration can be disabled/enabled on a Peer level via Peer.LoginExpirationEnabled property. +// Login expiration can also be disabled/enabled globally on the Account level via Settings.PeerLoginExpirationEnabled. +// Only peers added by interactive SSO login can be expired. +func (p *DefaultPeer) LoginExpired(expiresIn time.Duration) (bool, time.Duration) { + if !p.AddedWithSSOLogin() || !p.LoginExpirationEnabled { + return false, 0 + } + expiresAt := p.LastLogin.Add(expiresIn) + now := time.Now() + timeLeft := expiresAt.Sub(now) + return timeLeft <= 0, timeLeft +} + +// FQDN returns peers FQDN combined of the peer's DNS label and the system's DNS domain +func (p *DefaultPeer) FQDN(dnsDomain string) string { + if dnsDomain == "" { + return "" + } + return fmt.Sprintf("%s.%s", p.DNSLabel, dnsDomain) +} + +// EventMeta returns activity event meta related to the peer +func (p *DefaultPeer) EventMeta(dnsDomain string) map[string]any { + return map[string]any{"name": p.Name, "fqdn": p.FQDN(dnsDomain), "ip": p.IP, "created_at": p.CreatedAt} +} + +func (p *DefaultPeer) GetID() string { + // TODO implement me + panic("implement me") +} + +func (p *DefaultPeer) SetID(s string) { + // TODO implement me + panic("implement me") +} + +func (p *DefaultPeer) GetAccountID() string { + // TODO implement me + panic("implement me") +} + +func (p *DefaultPeer) SetAccountID(s string) { + // TODO implement me + panic("implement me") +} + +func (p *DefaultPeer) GetKey() string { + // TODO implement me + panic("implement me") +} + +func (p *DefaultPeer) SetKey(s string) { + // TODO implement me + panic("implement me") +} + +func (p *DefaultPeer) GetSetupKey() string { + // TODO implement me + panic("implement me") +} + +func (p *DefaultPeer) SetSetupKey(s string) { + // TODO implement me + panic("implement me") +} + +func (p *DefaultPeer) GetIP() net.IP { + // TODO implement me + panic("implement me") +} + +func (p *DefaultPeer) SetIP(ip net.IP) { + // TODO implement me + panic("implement me") +} + +func (p *DefaultPeer) GetName() string { + // TODO implement me + panic("implement me") +} + +func (p *DefaultPeer) SetName(s string) { + // TODO implement me + panic("implement me") +} + +func (p *DefaultPeer) GetDNSLabel() string { + // TODO implement me + panic("implement me") +} + +func (p *DefaultPeer) SetDNSLabel(s string) { + // TODO implement me + panic("implement me") +} + +func (p *DefaultPeer) GetUserID() string { + // TODO implement me + panic("implement me") +} + +func (p *DefaultPeer) SetUserID(s string) { + // TODO implement me + panic("implement me") +} + +func (p *DefaultPeer) GetSSHKey() string { + // TODO implement me + panic("implement me") +} + +func (p *DefaultPeer) SetSSHKey(s string) { + // TODO implement me + panic("implement me") +} + +func (p *DefaultPeer) GetSSHEnabled() bool { + // TODO implement me + panic("implement me") +} + +func (p *DefaultPeer) SetSSHEnabled(b bool) { + // TODO implement me + panic("implement me") +} diff --git a/management/refactor/peers/repository.go b/management/refactor/peers/repository.go new file mode 100644 index 000000000..ac8d527f5 --- /dev/null +++ b/management/refactor/peers/repository.go @@ -0,0 +1,8 @@ +package peers + +type repository interface { + FindPeerByPubKey(pubKey string) (*Peer, error) + FindPeerByID(id string) (*Peer, error) + FindAllPeersInAccount(id string) ([]*Peer, error) + UpdatePeer(peer Peer) error +} diff --git a/management/refactor/policies/manager.go b/management/refactor/policies/manager.go new file mode 100644 index 000000000..8f9d4f12f --- /dev/null +++ b/management/refactor/policies/manager.go @@ -0,0 +1,30 @@ +package policies + +import "github.com/netbirdio/netbird/management/refactor/peers" + +type Manager interface { + GetAccessiblePeersAndFirewallRules(peerID string) (peers []peers.Peer, firewallRules []*FirewallRule) +} + +type DefaultManager struct { + repository repository + peerManager peers.Manager +} + +func NewDefaultManager(repository repository, peerManager peers.Manager) *DefaultManager { + return &DefaultManager{ + repository: repository, + peerManager: peerManager, + } +} + +func (dm *DefaultManager) GetAccessiblePeersAndFirewallRules(peerID string) (peers []peers.Peer, firewallRules []*FirewallRule) { + peer, err := dm.peerManager.GetPeerByID(peerID) + if err != nil { + return nil, nil + } + + peers, err = dm.peerManager.GetNetworkPeersInAccount(peer.GetAccountID()) + + return peers, nil +} diff --git a/management/refactor/policies/policy.go b/management/refactor/policies/policy.go new file mode 100644 index 000000000..e2dda9756 --- /dev/null +++ b/management/refactor/policies/policy.go @@ -0,0 +1,7 @@ +package policies + +type Policy interface { +} + +type DefaultPolicy struct { +} diff --git a/management/refactor/policies/repository.go b/management/refactor/policies/repository.go new file mode 100644 index 000000000..d244a84ad --- /dev/null +++ b/management/refactor/policies/repository.go @@ -0,0 +1,4 @@ +package policies + +type repository interface { +} diff --git a/management/refactor/settings/manager.go b/management/refactor/settings/manager.go new file mode 100644 index 000000000..3ec3f42d7 --- /dev/null +++ b/management/refactor/settings/manager.go @@ -0,0 +1,19 @@ +package settings + +type Manager interface { + GetSettings(accountID string) (Settings, error) +} + +type DefaultManager struct { + repository repository +} + +func NewDefaultManager(repository repository) *DefaultManager { + return &DefaultManager{ + repository: repository, + } +} + +func (dm *DefaultManager) GetSettings(accountID string) (Settings, error) { + return dm.repository.FindSettings(accountID) +} diff --git a/management/refactor/settings/repository.go b/management/refactor/settings/repository.go new file mode 100644 index 000000000..978eb435a --- /dev/null +++ b/management/refactor/settings/repository.go @@ -0,0 +1,5 @@ +package settings + +type repository interface { + FindSettings(accountID string) (Settings, error) +} diff --git a/management/refactor/settings/settings.go b/management/refactor/settings/settings.go new file mode 100644 index 000000000..d640614d1 --- /dev/null +++ b/management/refactor/settings/settings.go @@ -0,0 +1,34 @@ +package settings + +import "time" + +type Settings interface { + GetLicense() string + GetPeerLoginExpiration() time.Duration + SetPeerLoginExpiration(duration time.Duration) + GetPeerLoginExpirationEnabled() bool + SetPeerLoginExpirationEnabled(bool) +} + +type DefaultSettings struct { +} + +func (s *DefaultSettings) GetLicense() string { + return "selfhosted" +} + +func (s *DefaultSettings) GetPeerLoginExpiration() time.Duration { + return 0 +} + +func (s *DefaultSettings) SetPeerLoginExpiration(duration time.Duration) { + +} + +func (s *DefaultSettings) GetPeerLoginExpirationEnabled() bool { + return false +} + +func (s *DefaultSettings) SetPeerLoginExpirationEnabled(bool) { + +} diff --git a/management/refactor/store/postgres_store.go b/management/refactor/store/postgres_store.go new file mode 100644 index 000000000..78721e6eb --- /dev/null +++ b/management/refactor/store/postgres_store.go @@ -0,0 +1,51 @@ +package store + +import ( + "github.com/netbirdio/netbird/management/refactor/peers" + "github.com/netbirdio/netbird/management/refactor/settings" +) + +const ( + PostgresStoreEngine StoreEngine = "postgres" +) + +type DefaultPostgresStore struct { +} + +func (s *DefaultPostgresStore) FindSettings(accountID string) (*settings.Settings, error) { + // TODO implement me + panic("implement me") +} + +func (s *DefaultPostgresStore) FindPeerByPubKey(pubKey string) (*peers.Peer, error) { + // TODO implement me + panic("implement me") +} + +func (s *DefaultPostgresStore) FindPeerByID(id string) (*peers.Peer, error) { + // TODO implement me + panic("implement me") +} + +func (s *DefaultPostgresStore) FindAllPeersInAccount(id string) ([]*peers.Peer, error) { + // TODO implement me + panic("implement me") +} + +func (s *DefaultPostgresStore) UpdatePeer(peer peers.Peer) error { + // TODO implement me + panic("implement me") +} + +func (s *DefaultPostgresStore) GetLicense() string { + // TODO implement me + panic("implement me") +} + +func NewDefaultPostgresStore() *DefaultPostgresStore { + return &DefaultPostgresStore{} +} + +func (s *DefaultPostgresStore) GetEngine() StoreEngine { + return PostgresStoreEngine +} diff --git a/management/refactor/store/sqlite_store.go b/management/refactor/store/sqlite_store.go new file mode 100644 index 000000000..241090535 --- /dev/null +++ b/management/refactor/store/sqlite_store.go @@ -0,0 +1,166 @@ +package store + +import ( + "path/filepath" + "runtime" + "sync" + "time" + + log "github.com/sirupsen/logrus" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" + + "github.com/netbirdio/netbird/management/refactor/peers" + "github.com/netbirdio/netbird/management/refactor/settings" + "github.com/netbirdio/netbird/management/server/telemetry" +) + +const ( + SqliteStoreEngine StoreEngine = "sqlite" +) + +// SqliteStore represents an account storage backed by a Sqlite DB persisted to disk +type DefaultSqliteStore struct { + db *gorm.DB + storeFile string + accountLocks sync.Map + globalAccountLock sync.Mutex + metrics telemetry.AppMetrics + installationPK int +} + +func (s *DefaultSqliteStore) FindSettings(accountID string) (*settings.Settings, error) { + // TODO implement me + panic("implement me") +} + +func (s *DefaultSqliteStore) FindPeerByPubKey(pubKey string) (*peers.Peer, error) { + // TODO implement me + panic("implement me") +} + +func (s *DefaultSqliteStore) FindPeerByID(id string) (*peers.Peer, error) { + // TODO implement me + panic("implement me") +} + +func (s *DefaultSqliteStore) FindAllPeersInAccount(id string) ([]*peers.Peer, error) { + // TODO implement me + panic("implement me") +} + +func (s *DefaultSqliteStore) UpdatePeer(peer peers.Peer) error { + // TODO implement me + panic("implement me") +} + +type installation struct { + ID uint `gorm:"primaryKey"` + InstallationIDValue string +} + +// NewSqliteStore restores a store from the file located in the datadir +func NewDefaultSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*DefaultSqliteStore, error) { + storeStr := "store.db?cache=shared" + if runtime.GOOS == "windows" { + // Vo avoid `The process cannot access the file because it is being used by another process` on Windows + storeStr = "store.db" + } + + file := filepath.Join(dataDir, storeStr) + db, err := gorm.Open(sqlite.Open(file), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + PrepareStmt: true, + }) + if err != nil { + return nil, err + } + + sql, err := db.DB() + if err != nil { + return nil, err + } + conns := runtime.NumCPU() + sql.SetMaxOpenConns(conns) // TODO: make it configurable + + // err = db.AutoMigrate( + // &SetupKey{}, &Peer{}, &User{}, &PersonalAccessToken{}, &Group{}, &Rule{}, + // &Account{}, &Policy{}, &PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, + // &installation{}, + // ) + // if err != nil { + // return nil, err + // } + + return &DefaultSqliteStore{db: db, storeFile: file, metrics: metrics, installationPK: 1}, nil +} + +func (s *DefaultSqliteStore) GetLicense() string { + // TODO implement me + panic("implement me") +} + +// AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock +func (s *DefaultSqliteStore) AcquireGlobalLock() (unlock func()) { + log.Debugf("acquiring global lock") + start := time.Now() + s.globalAccountLock.Lock() + + unlock = func() { + s.globalAccountLock.Unlock() + log.Debugf("released global lock in %v", time.Since(start)) + } + + took := time.Since(start) + log.Debugf("took %v to acquire global lock", took) + if s.metrics != nil { + s.metrics.StoreMetrics().CountGlobalLockAcquisitionDuration(took) + } + + return unlock +} + +func (s *DefaultSqliteStore) AcquireAccountLock(accountID string) (unlock func()) { + log.Debugf("acquiring lock for account %s", accountID) + + start := time.Now() + value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{}) + mtx := value.(*sync.Mutex) + mtx.Lock() + + unlock = func() { + mtx.Unlock() + log.Debugf("released lock for account %s in %v", accountID, time.Since(start)) + } + + return unlock +} + +func (s *DefaultSqliteStore) SaveInstallationID(ID string) error { + installation := installation{InstallationIDValue: ID} + installation.ID = uint(s.installationPK) + + return s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&installation).Error +} + +func (s *DefaultSqliteStore) GetInstallationID() string { + var installation installation + + if result := s.db.First(&installation, "id = ?", s.installationPK); result.Error != nil { + return "" + } + + return installation.InstallationIDValue +} + +// Close is noop in Sqlite +func (s *DefaultSqliteStore) Close() error { + return nil +} + +// GetStoreEngine returns SqliteStoreEngine +func (s *DefaultSqliteStore) GetStoreEngine() StoreEngine { + return SqliteStoreEngine +} diff --git a/management/refactor/store/store.go b/management/refactor/store/store.go new file mode 100644 index 000000000..c7f10508e --- /dev/null +++ b/management/refactor/store/store.go @@ -0,0 +1,66 @@ +package store + +import ( + "fmt" + "os" + "strings" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/refactor/peers" + "github.com/netbirdio/netbird/management/refactor/settings" + "github.com/netbirdio/netbird/management/server/telemetry" +) + +type Store interface { + GetLicense() string + FindPeerByPubKey(pubKey string) (*peers.Peer, error) + FindPeerByID(id string) (*peers.Peer, error) + FindAllPeersInAccount(id string) ([]*peers.Peer, error) + UpdatePeer(peer peers.Peer) error + FindSettings(accountID string) (settings.Settings, error) +} + +type DefaultStore interface { + GetLicense() string + FindPeerByPubKey(pubKey string) (*peers.Peer, error) + FindPeerByID(id string) (*peers.Peer, error) + FindAllPeersInAccount(id string) ([]*peers.Peer, error) + UpdatePeer(peer peers.Peer) error + FindSettings(accountID string) (settings.Settings, error) +} + +type StoreEngine string + +func getStoreEngineFromEnv() StoreEngine { + // NETBIRD_STORE_ENGINE supposed to be used in tests. Otherwise rely on the config file. + kind, ok := os.LookupEnv("NETBIRD_STORE_ENGINE") + if !ok { + return SqliteStoreEngine + } + + value := StoreEngine(strings.ToLower(kind)) + + if value == PostgresStoreEngine || value == SqliteStoreEngine { + return value + } + + return SqliteStoreEngine +} + +func NewDefaultStore(kind StoreEngine, dataDir string, metrics telemetry.AppMetrics) (DefaultStore, error) { + if kind == "" { + // fallback to env. Normally this only should be used from tests + kind = getStoreEngineFromEnv() + } + switch kind { + case PostgresStoreEngine: + log.Info("using JSON file store engine") + return NewDefaultPostgresStore(), nil + case SqliteStoreEngine: + log.Info("using SQLite store engine") + return NewDefaultSqliteStore(dataDir, metrics) + default: + return nil, fmt.Errorf("unsupported kind of store %s", kind) + } +} diff --git a/management/refactor/users/manager.go b/management/refactor/users/manager.go new file mode 100644 index 000000000..cce038aca --- /dev/null +++ b/management/refactor/users/manager.go @@ -0,0 +1,24 @@ +package users + +import "github.com/netbirdio/netbird/management/refactor/peers" + +type Manager interface { + GetUser(id string) (User, error) +} + +type DefaultManager struct { + repository repository + peerManager peers.Manager +} + +func NewDefaultManager(repository repository, peerManager peers.Manager) *DefaultManager { + return &DefaultManager{ + repository: repository, + peerManager: peerManager, + } +} + +func (d DefaultManager) GetUser(id string) (User, error) { + // TODO implement me + panic("implement me") +} diff --git a/management/refactor/users/repository.go b/management/refactor/users/repository.go new file mode 100644 index 000000000..bb737d0c2 --- /dev/null +++ b/management/refactor/users/repository.go @@ -0,0 +1,4 @@ +package users + +type repository interface { +} diff --git a/management/refactor/users/user.go b/management/refactor/users/user.go new file mode 100644 index 000000000..bd12c65df --- /dev/null +++ b/management/refactor/users/user.go @@ -0,0 +1,35 @@ +package users + +import "time" + +// UserRole is the role of a User +type UserRole string + +type User interface { + IsBlocked() bool +} + +// User represents a user of the system +type DefaultUser struct { + Id string `gorm:"primaryKey"` + // AccountID is a reference to Account that this object belongs + AccountID string `json:"-" gorm:"index"` + Role UserRole + IsServiceUser bool + // NonDeletable indicates whether the service user can be deleted + NonDeletable bool + // ServiceUserName is only set if IsServiceUser is true + ServiceUserName string + // AutoGroups is a list of Group IDs to auto-assign to peers registered by this user + AutoGroups []string `gorm:"serializer:json"` + // Blocked indicates whether the user is blocked. Blocked users can't use the system. + Blocked bool + // LastLogin is the last time the user logged in to IdP + LastLogin time.Time + // Issued of the user + Issued string `gorm:"default:api"` +} + +func (u *DefaultUser) IsBlocked() bool { + return u.Blocked +}