diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index 430a7da44..b4c2791d8 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -76,7 +76,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste t.Fatal(err) } s := grpc.NewServer() - store, cleanUp, err := mgmt.NewTestStoreFromJson(config.Datadir) + store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir) if err != nil { t.Fatal(err) } @@ -87,13 +87,13 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste if err != nil { return nil, nil } - iv, _ := integrations.NewIntegratedValidator(eventStore) - accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv) + iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) + accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv) if err != nil { t.Fatal(err) } turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) - mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil) + mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil) if err != nil { t.Fatal(err) } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 0db0ab74c..6c6f79d07 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -174,7 +174,7 @@ func TestEngine_SSH(t *testing.T) { t.Fatal(err) } - //time.Sleep(250 * time.Millisecond) + // time.Sleep(250 * time.Millisecond) assert.NotNil(t, engine.sshServer) assert.Contains(t, sshPeersRemoved, "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=") @@ -1057,7 +1057,7 @@ func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error) } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, cleanUp, err := server.NewTestStoreFromJson(config.Datadir) + store, cleanUp, err := server.NewTestStoreFromJson(context.Background(), config.Datadir) if err != nil { return nil, "", err } @@ -1068,13 +1068,13 @@ func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error) if err != nil { return nil, "", err } - ia, _ := integrations.NewIntegratedValidator(eventStore) - accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia) + ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) + accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia) if err != nil { return nil, "", err } turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) - mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil) + mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil) if err != nil { return nil, "", err } diff --git a/client/server/server_test.go b/client/server/server_test.go index 46fc9fa8e..b19e4615f 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -108,7 +108,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve return nil, "", err } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, cleanUp, err := server.NewTestStoreFromJson(config.Datadir) + store, cleanUp, err := server.NewTestStoreFromJson(context.Background(), config.Datadir) if err != nil { return nil, "", err } @@ -119,13 +119,13 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve if err != nil { return nil, "", err } - ia, _ := integrations.NewIntegratedValidator(eventStore) - accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia) + ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) + accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia) if err != nil { return nil, "", err } turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) - mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil) + mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil) if err != nil { return nil, "", err } diff --git a/formatter/hook.go b/formatter/hook.go index c3aa77fb3..12f27e67d 100644 --- a/formatter/hook.go +++ b/formatter/hook.go @@ -7,6 +7,18 @@ import ( "strings" "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/server/context" +) + +type ExecutionContext string + +const ( + ExecutionContextKey = "executionContext" + + HTTPSource ExecutionContext = "HTTP" + GRPCSource ExecutionContext = "GRPC" + SystemSource ExecutionContext = "SYSTEM" ) // ContextHook is a custom hook for add the source information for the entry @@ -30,6 +42,27 @@ func (hook ContextHook) Levels() []logrus.Level { func (hook ContextHook) Fire(entry *logrus.Entry) error { src := hook.parseSrc(entry.Caller.File) entry.Data["source"] = fmt.Sprintf("%s:%v", src, entry.Caller.Line) + + if entry.Context == nil { + return nil + } + + source, ok := entry.Context.Value(ExecutionContextKey).(ExecutionContext) + if !ok { + return nil + } + + entry.Data["context"] = source + + switch source { + case HTTPSource: + addHTTPFields(entry) + case GRPCSource: + addGRPCFields(entry) + case SystemSource: + addSystemFields(entry) + } + return nil } @@ -59,3 +92,42 @@ func (hook ContextHook) parseSrc(filePath string) string { file := path.Base(filePath) return fmt.Sprintf("%s/%s", pkg, file) } + +func addHTTPFields(entry *logrus.Entry) { + if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok { + entry.Data[context.RequestIDKey] = ctxReqID + } + if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok { + entry.Data[context.AccountIDKey] = ctxAccountID + } + if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok { + entry.Data[context.UserIDKey] = ctxInitiatorID + } +} + +func addGRPCFields(entry *logrus.Entry) { + if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok { + entry.Data[context.RequestIDKey] = ctxReqID + } + if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok { + entry.Data[context.AccountIDKey] = ctxAccountID + } + if ctxDeviceID, ok := entry.Context.Value(context.PeerIDKey).(string); ok { + entry.Data[context.PeerIDKey] = ctxDeviceID + } +} + +func addSystemFields(entry *logrus.Entry) { + if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok { + entry.Data[context.RequestIDKey] = ctxReqID + } + if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok { + entry.Data[context.UserIDKey] = ctxInitiatorID + } + if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok { + entry.Data[context.AccountIDKey] = ctxAccountID + } + if ctxDeviceID, ok := entry.Context.Value(context.PeerIDKey).(string); ok { + entry.Data[context.PeerIDKey] = ctxDeviceID + } +} diff --git a/formatter/set.go b/formatter/set.go index cceeef860..f9ccef601 100644 --- a/formatter/set.go +++ b/formatter/set.go @@ -1,6 +1,8 @@ package formatter -import "github.com/sirupsen/logrus" +import ( + "github.com/sirupsen/logrus" +) // SetTextFormatter set the text formatter for given logger. func SetTextFormatter(logger *logrus.Logger) { @@ -9,6 +11,13 @@ func SetTextFormatter(logger *logrus.Logger) { logger.AddHook(NewContextHook()) } +// SetJSONFormatter set the JSON formatter for given logger. +func SetJSONFormatter(logger *logrus.Logger) { + logger.Formatter = &logrus.JSONFormatter{} + logger.ReportCaller = true + logger.AddHook(NewContextHook()) +} + // SetLogcatFormatter set the logcat formatter for given logger. func SetLogcatFormatter(logger *logrus.Logger) { logger.Formatter = NewLogcatFormatter() diff --git a/go.mod b/go.mod index bf05fb0c9..78390042d 100644 --- a/go.mod +++ b/go.mod @@ -44,7 +44,6 @@ require ( github.com/golang/mock v1.6.0 github.com/google/go-cmp v0.6.0 github.com/google/gopacket v1.1.19 - github.com/google/martian/v3 v3.0.0 github.com/google/nftables v0.0.0-20220808154552-2eca00135732 github.com/gopacket/gopacket v1.1.1 github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 @@ -58,7 +57,7 @@ require ( github.com/miekg/dns v1.1.43 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20240524104853-69c6d89826cd + github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 github.com/patrickmn/go-cache v2.1.0+incompatible diff --git a/go.sum b/go.sum index 7e67a9d99..1a74b2664 100644 --- a/go.sum +++ b/go.sum @@ -209,8 +209,6 @@ github.com/google/gofuzz v0.0.0-20161122191042-44d81051d367/go.mod h1:HP5RmnzzSN github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= -github.com/google/martian/v3 v3.0.0 h1:pMen7vLs8nvgEYhywH3KDWJIJTeEr2ULsVWHWYHQyBs= -github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/nftables v0.0.0-20220808154552-2eca00135732 h1:csc7dT82JiSLvq4aMyQMIQDL7986NH6Wxf/QrvOj55A= github.com/google/nftables v0.0.0-20220808154552-2eca00135732/go.mod h1:b97ulCCFipUC+kSin+zygkvUVpx0vyIAwxXFdY3PlNc= github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o= @@ -335,8 +333,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= -github.com/netbirdio/management-integrations/integrations v0.0.0-20240524104853-69c6d89826cd h1:IzGGIJMpz07aPs3R6/4sxZv63JoCMddftLpVodUK+Ec= -github.com/netbirdio/management-integrations/integrations v0.0.0-20240524104853-69c6d89826cd/go.mod h1:kxks50DrZnhW+oRTdHOkVOJbcTcyo766am8RBugo+Yc= +github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e h1:LYxhAmiEzSldLELHSMVoUnRPq3ztTNQImrD27frrGsI= +github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y= github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g= github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM= @@ -565,7 +563,6 @@ golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20190628185345-da137c7871d7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20191004110552-13f9640d40b9/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20191105084925-a882066a44e0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= diff --git a/management/client/client_test.go b/management/client/client_test.go index 001a89b73..2774f2b59 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -62,7 +62,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { t.Fatal(err) } s := grpc.NewServer() - store, cleanUp, err := mgmt.NewTestStoreFromJson(config.Datadir) + store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir) if err != nil { t.Fatal(err) } @@ -70,13 +70,13 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { peersUpdateManager := mgmt.NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} - ia, _ := integrations.NewIntegratedValidator(eventStore) - accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia) + ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) + accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia) if err != nil { t.Fatal(err) } turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) - mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil) + mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil) if err != nil { t.Fatal(err) } diff --git a/management/cmd/management.go b/management/cmd/management.go index 366935802..b87c386c6 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -20,6 +20,7 @@ import ( "time" "github.com/google/uuid" + grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware/v2" "github.com/miekg/dns" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" @@ -35,8 +36,10 @@ import ( "github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/netbird/encryption" + "github.com/netbirdio/netbird/formatter" mgmtProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server" + nbContext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" httpapi "github.com/netbirdio/netbird/management/server/http" "github.com/netbirdio/netbird/management/server/idp" @@ -77,6 +80,10 @@ var ( Short: "start NetBird Management Server", PreRunE: func(cmd *cobra.Command, args []string) error { flag.Parse() + + //nolint + ctx := context.WithValue(cmd.Context(), formatter.ExecutionContextKey, formatter.SystemSource) + err := util.InitLog(logLevel, logFile) if err != nil { return fmt.Errorf("failed initializing log %v", err) @@ -85,7 +92,7 @@ var ( // detect whether user specified a port userPort := cmd.Flag("port").Changed - config, err = loadMgmtConfig(mgmtConfig) + config, err = loadMgmtConfig(ctx, mgmtConfig) if err != nil { return fmt.Errorf("failed reading provided config file: %s: %v", mgmtConfig, err) } @@ -116,6 +123,11 @@ var ( return nil }, RunE: func(cmd *cobra.Command, args []string) error { + ctx, cancel := context.WithCancel(cmd.Context()) + defer cancel() + //nolint + ctx = context.WithValue(ctx, formatter.ExecutionContextKey, formatter.SystemSource) + err := handleRebrand(cmd) if err != nil { return fmt.Errorf("failed to migrate files %v", err) @@ -131,11 +143,11 @@ var ( if err != nil { return err } - err = appMetrics.Expose(mgmtMetricsPort, "/metrics") + err = appMetrics.Expose(ctx, mgmtMetricsPort, "/metrics") if err != nil { return err } - store, err := server.NewStore(config.StoreConfig.Engine, config.Datadir, appMetrics) + store, err := server.NewStore(ctx, config.StoreConfig.Engine, config.Datadir, appMetrics) if err != nil { return fmt.Errorf("failed creating Store: %s: %v", config.Datadir, err) } @@ -143,7 +155,7 @@ var ( var idpManager idp.Manager if config.IdpManagerConfig != nil { - idpManager, err = idp.NewManager(*config.IdpManagerConfig, appMetrics) + idpManager, err = idp.NewManager(ctx, *config.IdpManagerConfig, appMetrics) if err != nil { return fmt.Errorf("failed retrieving a new idp manager with err: %v", err) } @@ -152,32 +164,32 @@ var ( if disableSingleAccMode { mgmtSingleAccModeDomain = "" } - eventStore, key, err := integrations.InitEventStore(config.Datadir, config.DataStoreEncryptionKey) + eventStore, key, err := integrations.InitEventStore(ctx, config.Datadir, config.DataStoreEncryptionKey) if err != nil { return fmt.Errorf("failed to initialize database: %s", err) } if config.DataStoreEncryptionKey != key { - log.Infof("update config with activity store key") + log.WithContext(ctx).Infof("update config with activity store key") config.DataStoreEncryptionKey = key - err := updateMgmtConfig(mgmtConfig, config) + err := updateMgmtConfig(ctx, mgmtConfig, config) if err != nil { return fmt.Errorf("failed to write out store encryption key: %s", err) } } - geo, err := geolocation.NewGeolocation(config.Datadir) + geo, err := geolocation.NewGeolocation(ctx, config.Datadir) if err != nil { - log.Warnf("could not initialize geo location service: %v, we proceed without geo support", err) + log.WithContext(ctx).Warnf("could not initialize geo location service: %v, we proceed without geo support", err) } else { - log.Infof("geo location service has been initialized from %s", config.Datadir) + log.WithContext(ctx).Infof("geo location service has been initialized from %s", config.Datadir) } - integratedPeerValidator, err := integrations.NewIntegratedValidator(eventStore) + integratedPeerValidator, err := integrations.NewIntegratedValidator(ctx, eventStore) if err != nil { return fmt.Errorf("failed to initialize integrated peer validator: %v", err) } - accountManager, err := server.BuildManager(store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain, + accountManager, err := server.BuildManager(ctx, store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain, dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator) if err != nil { return fmt.Errorf("failed to build default manager: %v", err) @@ -188,13 +200,13 @@ var ( trustedPeers := config.ReverseProxy.TrustedPeers defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")} if len(trustedPeers) == 0 || slices.Equal[[]netip.Prefix](trustedPeers, defaultTrustedPeers) { - log.Warn("TrustedPeers are configured to default value '0.0.0.0/0', '::/0'. This allows connection IP spoofing.") + log.WithContext(ctx).Warn("TrustedPeers are configured to default value '0.0.0.0/0', '::/0'. This allows connection IP spoofing.") trustedPeers = defaultTrustedPeers } trustedHTTPProxies := config.ReverseProxy.TrustedHTTPProxies trustedProxiesCount := config.ReverseProxy.TrustedHTTPProxiesCount if len(trustedHTTPProxies) > 0 && trustedProxiesCount > 0 { - log.Warn("TrustedHTTPProxies and TrustedHTTPProxiesCount both are configured. " + + log.WithContext(ctx).Warn("TrustedHTTPProxies and TrustedHTTPProxiesCount both are configured. " + "This is not recommended way to extract X-Forwarded-For. Consider using one of these options.") } realipOpts := []realip.Option{ @@ -206,8 +218,8 @@ var ( gRPCOpts := []grpc.ServerOption{ grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp), - grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...)), - grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...)), + grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...), unaryInterceptor), + grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor), } var certManager *autocert.Manager @@ -224,7 +236,7 @@ var ( } else if config.HttpConfig.CertFile != "" && config.HttpConfig.CertKey != "" { tlsConfig, err = loadTLSConfig(config.HttpConfig.CertFile, config.HttpConfig.CertKey) if err != nil { - log.Errorf("cannot load TLS credentials: %v", err) + log.WithContext(ctx).Errorf("cannot load TLS credentials: %v", err) return err } transportCredentials := credentials.NewTLS(tlsConfig) @@ -233,6 +245,7 @@ var ( } jwtValidator, err := jwtclaims.NewJWTValidator( + ctx, config.HttpConfig.AuthIssuer, config.GetAuthAudiences(), config.HttpConfig.AuthKeysLocation, @@ -249,26 +262,24 @@ var ( KeysLocation: config.HttpConfig.AuthKeysLocation, } - ctx, cancel := context.WithCancel(cmd.Context()) - defer cancel() httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator) if err != nil { return fmt.Errorf("failed creating HTTP API handler: %v", err) } ephemeralManager := server.NewEphemeralManager(store, accountManager) - ephemeralManager.LoadInitialPeers() + ephemeralManager.LoadInitialPeers(ctx) gRPCAPIHandler := grpc.NewServer(gRPCOpts...) - srv, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, appMetrics, ephemeralManager) + srv, err := server.NewServer(ctx, config, accountManager, peersUpdateManager, turnManager, appMetrics, ephemeralManager) if err != nil { return fmt.Errorf("failed creating gRPC API handler: %v", err) } mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv) - installationID, err := getInstallationID(store) + installationID, err := getInstallationID(ctx, store) if err != nil { - log.Errorf("cannot load TLS credentials: %v", err) + log.WithContext(ctx).Errorf("cannot load TLS credentials: %v", err) return err } @@ -278,18 +289,18 @@ var ( idpManager = config.IdpManagerConfig.ManagerType } metricsWorker := metrics.NewWorker(ctx, installationID, store, peersUpdateManager, idpManager) - go metricsWorker.Run() + go metricsWorker.Run(ctx) } var compatListener net.Listener if mgmtPort != ManagementLegacyPort { // The Management gRPC server was running on port 33073 previously. Old agents that are already connected to it // are using port 33073. For compatibility purposes we keep running a 2nd gRPC server on port 33073. - compatListener, err = serveGRPC(gRPCAPIHandler, ManagementLegacyPort) + compatListener, err = serveGRPC(ctx, gRPCAPIHandler, ManagementLegacyPort) if err != nil { return err } - log.Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String()) + log.WithContext(ctx).Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String()) } rootHandler := handlerFunc(gRPCAPIHandler, httpAPIHandler) @@ -306,8 +317,8 @@ var ( if err != nil { return fmt.Errorf("failed creating TLS listener on port %d: %v", mgmtPort, err) } - log.Infof("running HTTP server (LetsEncrypt challenge handler): %s", cml.Addr().String()) - serveHTTP(cml, certManager.HTTPHandler(nil)) + log.WithContext(ctx).Infof("running HTTP server (LetsEncrypt challenge handler): %s", cml.Addr().String()) + serveHTTP(ctx, cml, certManager.HTTPHandler(nil)) } } else if tlsConfig != nil { listener, err = tls.Listen("tcp", fmt.Sprintf(":%d", mgmtPort), tlsConfig) @@ -321,14 +332,14 @@ var ( } } - log.Infof("management server version %s", version.NetbirdVersion()) - log.Infof("running HTTP server and gRPC server on the same port: %s", listener.Addr().String()) - serveGRPCWithHTTP(listener, rootHandler, tlsEnabled) + log.WithContext(ctx).Infof("management server version %s", version.NetbirdVersion()) + log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", listener.Addr().String()) + serveGRPCWithHTTP(ctx, listener, rootHandler, tlsEnabled) SetupCloseHandler() <-stopCh - integratedPeerValidator.Stop() + integratedPeerValidator.Stop(ctx) if geo != nil { _ = geo.Stop() } @@ -339,39 +350,68 @@ var ( _ = certManager.Listener().Close() } gRPCAPIHandler.Stop() - _ = store.Close() - _ = eventStore.Close() - log.Infof("stopped Management Service") + _ = store.Close(ctx) + _ = eventStore.Close(ctx) + log.WithContext(ctx).Infof("stopped Management Service") return nil }, } ) -func notifyStop(msg string) { +func unaryInterceptor( + ctx context.Context, + req interface{}, + info *grpc.UnaryServerInfo, + handler grpc.UnaryHandler, +) (interface{}, error) { + reqID := uuid.New().String() + //nolint + ctx = context.WithValue(ctx, formatter.ExecutionContextKey, formatter.GRPCSource) + //nolint + ctx = context.WithValue(ctx, nbContext.RequestIDKey, reqID) + return handler(ctx, req) +} + +func streamInterceptor( + srv interface{}, + ss grpc.ServerStream, + info *grpc.StreamServerInfo, + handler grpc.StreamHandler, +) error { + reqID := uuid.New().String() + wrapped := grpcMiddleware.WrapServerStream(ss) + //nolint + ctx := context.WithValue(ss.Context(), formatter.ExecutionContextKey, formatter.GRPCSource) + //nolint + wrapped.WrappedContext = context.WithValue(ctx, nbContext.RequestIDKey, reqID) + return handler(srv, wrapped) +} + +func notifyStop(ctx context.Context, msg string) { select { case stopCh <- 1: - log.Error(msg) + log.WithContext(ctx).Error(msg) default: // stop has been already called, nothing to report } } -func getInstallationID(store server.Store) (string, error) { +func getInstallationID(ctx context.Context, store server.Store) (string, error) { installationID := store.GetInstallationID() if installationID != "" { return installationID, nil } installationID = strings.ToUpper(uuid.New().String()) - err := store.SaveInstallationID(installationID) + err := store.SaveInstallationID(ctx, installationID) if err != nil { return "", err } return installationID, nil } -func serveGRPC(grpcServer *grpc.Server, port int) (net.Listener, error) { +func serveGRPC(ctx context.Context, grpcServer *grpc.Server, port int) (net.Listener, error) { listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) if err != nil { return nil, err @@ -379,22 +419,22 @@ func serveGRPC(grpcServer *grpc.Server, port int) (net.Listener, error) { go func() { err := grpcServer.Serve(listener) if err != nil { - notifyStop(fmt.Sprintf("failed running gRPC server on port %d: %v", port, err)) + notifyStop(ctx, fmt.Sprintf("failed running gRPC server on port %d: %v", port, err)) } }() return listener, nil } -func serveHTTP(httpListener net.Listener, handler http.Handler) { +func serveHTTP(ctx context.Context, httpListener net.Listener, handler http.Handler) { go func() { err := http.Serve(httpListener, handler) if err != nil { - notifyStop(fmt.Sprintf("failed running HTTP server: %v", err)) + notifyStop(ctx, fmt.Sprintf("failed running HTTP server: %v", err)) } }() } -func serveGRPCWithHTTP(listener net.Listener, handler http.Handler, tlsEnabled bool) { +func serveGRPCWithHTTP(ctx context.Context, listener net.Listener, handler http.Handler, tlsEnabled bool) { go func() { var err error if tlsEnabled { @@ -411,7 +451,7 @@ func serveGRPCWithHTTP(listener net.Listener, handler http.Handler, tlsEnabled b if err != nil { select { case stopCh <- 1: - log.Errorf("failed to serve HTTP and gRPC server: %v", err) + log.WithContext(ctx).Errorf("failed to serve HTTP and gRPC server: %v", err) default: // stop has been already called, nothing to report } @@ -431,7 +471,7 @@ func handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler) http.Handle }) } -func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) { +func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config, error) { loadedConfig := &server.Config{} _, err := util.ReadJson(mgmtConfigPath, loadedConfig) if err != nil { @@ -452,26 +492,26 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) { oidcEndpoint := loadedConfig.HttpConfig.OIDCConfigEndpoint if oidcEndpoint != "" { // if OIDCConfigEndpoint is specified, we can load DeviceAuthEndpoint and TokenEndpoint automatically - log.Infof("loading OIDC configuration from the provided IDP configuration endpoint %s", oidcEndpoint) - oidcConfig, err := fetchOIDCConfig(oidcEndpoint) + log.WithContext(ctx).Infof("loading OIDC configuration from the provided IDP configuration endpoint %s", oidcEndpoint) + oidcConfig, err := fetchOIDCConfig(ctx, oidcEndpoint) if err != nil { return nil, err } - log.Infof("loaded OIDC configuration from the provided IDP configuration endpoint: %s", oidcEndpoint) + log.WithContext(ctx).Infof("loaded OIDC configuration from the provided IDP configuration endpoint: %s", oidcEndpoint) - log.Infof("overriding HttpConfig.AuthIssuer with a new value %s, previously configured value: %s", + log.WithContext(ctx).Infof("overriding HttpConfig.AuthIssuer with a new value %s, previously configured value: %s", oidcConfig.Issuer, loadedConfig.HttpConfig.AuthIssuer) loadedConfig.HttpConfig.AuthIssuer = oidcConfig.Issuer - log.Infof("overriding HttpConfig.AuthKeysLocation (JWT certs) with a new value %s, previously configured value: %s", + log.WithContext(ctx).Infof("overriding HttpConfig.AuthKeysLocation (JWT certs) with a new value %s, previously configured value: %s", oidcConfig.JwksURI, loadedConfig.HttpConfig.AuthKeysLocation) loadedConfig.HttpConfig.AuthKeysLocation = oidcConfig.JwksURI if !(loadedConfig.DeviceAuthorizationFlow == nil || strings.ToLower(loadedConfig.DeviceAuthorizationFlow.Provider) == string(server.NONE)) { - log.Infof("overriding DeviceAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s", + log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s", oidcConfig.TokenEndpoint, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint) loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint = oidcConfig.TokenEndpoint - log.Infof("overriding DeviceAuthorizationFlow.DeviceAuthEndpoint with a new value: %s, previously configured value: %s", + log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.DeviceAuthEndpoint with a new value: %s, previously configured value: %s", oidcConfig.DeviceAuthEndpoint, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint) loadedConfig.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint = oidcConfig.DeviceAuthEndpoint @@ -479,7 +519,7 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) { if err != nil { return nil, err } - log.Infof("overriding DeviceAuthorizationFlow.ProviderConfig.Domain with a new value: %s, previously configured value: %s", + log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.ProviderConfig.Domain with a new value: %s, previously configured value: %s", u.Host, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Domain) loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Domain = u.Host @@ -489,10 +529,10 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) { } if loadedConfig.PKCEAuthorizationFlow != nil { - log.Infof("overriding PKCEAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s", + log.WithContext(ctx).Infof("overriding PKCEAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s", oidcConfig.TokenEndpoint, loadedConfig.PKCEAuthorizationFlow.ProviderConfig.TokenEndpoint) loadedConfig.PKCEAuthorizationFlow.ProviderConfig.TokenEndpoint = oidcConfig.TokenEndpoint - log.Infof("overriding PKCEAuthorizationFlow.AuthorizationEndpoint with a new value: %s, previously configured value: %s", + log.WithContext(ctx).Infof("overriding PKCEAuthorizationFlow.AuthorizationEndpoint with a new value: %s, previously configured value: %s", oidcConfig.AuthorizationEndpoint, loadedConfig.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint) loadedConfig.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint = oidcConfig.AuthorizationEndpoint } @@ -501,8 +541,8 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) { return loadedConfig, err } -func updateMgmtConfig(path string, config *server.Config) error { - return util.DirectWriteJson(path, config) +func updateMgmtConfig(ctx context.Context, path string, config *server.Config) error { + return util.DirectWriteJson(ctx, path, config) } // OIDCConfigResponse used for parsing OIDC config response @@ -515,7 +555,7 @@ type OIDCConfigResponse struct { } // fetchOIDCConfig fetches OIDC configuration from the IDP -func fetchOIDCConfig(oidcEndpoint string) (OIDCConfigResponse, error) { +func fetchOIDCConfig(ctx context.Context, oidcEndpoint string) (OIDCConfigResponse, error) { res, err := http.Get(oidcEndpoint) if err != nil { return OIDCConfigResponse{}, fmt.Errorf("failed fetching OIDC configuration from endpoint %s %v", oidcEndpoint, err) @@ -524,7 +564,7 @@ func fetchOIDCConfig(oidcEndpoint string) (OIDCConfigResponse, error) { defer func() { err := res.Body.Close() if err != nil { - log.Debugf("failed closing response body %v", err) + log.WithContext(ctx).Debugf("failed closing response body %v", err) } }() diff --git a/management/cmd/migration_up.go b/management/cmd/migration_up.go index 89adfce55..7aa11f0c9 100644 --- a/management/cmd/migration_up.go +++ b/management/cmd/migration_up.go @@ -1,13 +1,16 @@ package cmd import ( + "context" "flag" "fmt" - "github.com/netbirdio/netbird/management/server" - "github.com/netbirdio/netbird/util" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" + + "github.com/netbirdio/netbird/formatter" + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/util" ) var shortUp = "Migrate JSON file store to SQLite store. Please make a backup of the JSON file before running this command." @@ -26,10 +29,13 @@ var upCmd = &cobra.Command{ return fmt.Errorf("failed initializing log %v", err) } - if err := server.MigrateFileStoreToSqlite(mgmtDataDir); err != nil { + //nolint + ctx := context.WithValue(cmd.Context(), formatter.ExecutionContextKey, formatter.SystemSource) + + if err := server.MigrateFileStoreToSqlite(ctx, mgmtDataDir); err != nil { return err } - log.Info("Migration finished successfully") + log.WithContext(ctx).Info("Migration finished successfully") return nil }, diff --git a/management/server/account.go b/management/server/account.go index 845325226..27c21e402 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -59,84 +59,85 @@ func cacheEntryExpiration() time.Duration { } type AccountManager interface { - GetOrCreateAccountByUser(userId, domain string) (*Account, error) - CreateSetupKey(accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration, + GetOrCreateAccountByUser(ctx context.Context, userId, domain string) (*Account, error) + CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) - SaveSetupKey(accountID string, key *SetupKey, userID string) (*SetupKey, error) - CreateUser(accountID, initiatorUserID string, key *UserInfo) (*UserInfo, error) - DeleteUser(accountID, initiatorUserID string, targetUserID string) error - InviteUser(accountID string, initiatorUserID string, targetUserID string) error - ListSetupKeys(accountID, userID string) ([]*SetupKey, error) - SaveUser(accountID, initiatorUserID string, update *User) (*UserInfo, error) - SaveOrAddUser(accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) - GetSetupKey(accountID, userID, keyID string) (*SetupKey, error) - GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error) - GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error) - CheckUserAccessByJWTGroups(claims jwtclaims.AuthorizationClaims) error - GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error) - DeleteAccount(accountID, userID string) error - MarkPATUsed(tokenID string) error - GetUser(claims jwtclaims.AuthorizationClaims) (*User, error) - ListUsers(accountID string) ([]*User, error) - GetPeers(accountID, userID string) ([]*nbpeer.Peer, error) - MarkPeerConnected(peerKey string, connected bool, realIP net.IP, account *Account) error - DeletePeer(accountID, peerID, userID string) error - UpdatePeer(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) - GetNetworkMap(peerID string) (*NetworkMap, error) - GetPeerNetwork(peerID string) (*Network, error) - AddPeer(setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) - CreatePAT(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) - DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error - GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) - GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) - UpdatePeerSSHKey(peerID string, sshKey string) error - GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) - GetGroup(accountId, groupID, userID string) (*nbgroup.Group, error) - GetAllGroups(accountID, userID string) ([]*nbgroup.Group, error) - GetGroupByName(groupName, accountID string) (*nbgroup.Group, error) - SaveGroup(accountID, userID string, group *nbgroup.Group) error - DeleteGroup(accountId, userId, groupID string) error - ListGroups(accountId string) ([]*nbgroup.Group, error) - GroupAddPeer(accountId, groupID, peerID string) error - GroupDeletePeer(accountId, groupID, peerID string) error - GetPolicy(accountID, policyID, userID string) (*Policy, error) - SavePolicy(accountID, userID string, policy *Policy) error - DeletePolicy(accountID, policyID, userID string) error - ListPolicies(accountID, userID string) ([]*Policy, error) - GetRoute(accountID string, routeID route.ID, userID string) (*route.Route, error) - CreateRoute(accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) - SaveRoute(accountID, userID string, route *route.Route) error - DeleteRoute(accountID string, routeID route.ID, userID string) error - ListRoutes(accountID, userID string) ([]*route.Route, error) - GetNameServerGroup(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) - CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) - SaveNameServerGroup(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error - DeleteNameServerGroup(accountID, nsGroupID, userID string) error - ListNameServerGroups(accountID string, userID string) ([]*nbdns.NameServerGroup, error) + SaveSetupKey(ctx context.Context, accountID string, key *SetupKey, userID string) (*SetupKey, error) + CreateUser(ctx context.Context, accountID, initiatorUserID string, key *UserInfo) (*UserInfo, error) + DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error + InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error + ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error) + SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error) + SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) + GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) + GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error) + GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error) + CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error + GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error) + DeleteAccount(ctx context.Context, accountID, userID string) error + MarkPATUsed(ctx context.Context, tokenID string) error + GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) + ListUsers(ctx context.Context, accountID string) ([]*User, error) + GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) + MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, account *Account) error + DeletePeer(ctx context.Context, accountID, peerID, userID string) error + UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) + GetNetworkMap(ctx context.Context, peerID string) (*NetworkMap, error) + GetPeerNetwork(ctx context.Context, peerID string) (*Network, error) + AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) + CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) + DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error + GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) + GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) + UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error + GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error) + GetGroup(ctx context.Context, accountId, groupID, userID string) (*nbgroup.Group, error) + GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) + GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) + SaveGroup(ctx context.Context, accountID, userID string, group *nbgroup.Group) error + DeleteGroup(ctx context.Context, accountId, userId, groupID string) error + ListGroups(ctx context.Context, accountId string) ([]*nbgroup.Group, error) + GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error + GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error + GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) + SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error + DeletePolicy(ctx context.Context, accountID, policyID, userID string) error + ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) + GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) + CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) + SaveRoute(ctx context.Context, accountID, userID string, route *route.Route) error + DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error + ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) + GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) + CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) + SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error + DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error + ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) GetDNSDomain() string - StoreEvent(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) - GetEvents(accountID, userID string) ([]*activity.Event, error) - GetDNSSettings(accountID string, userID string) (*DNSSettings, error) - SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error - GetPeer(accountID, peerID, userID string) (*nbpeer.Peer, error) - UpdateAccountSettings(accountID, userID string, newSettings *Settings) (*Account, error) - LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API - SyncPeer(sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API + StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) + GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) + GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) + SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error + GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) + UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Account, error) + LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API + SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API GetAllConnectedPeers() (map[string]struct{}, error) HasConnectedChannel(peerID string) bool GetExternalCacheManager() ExternalCacheManager - GetPostureChecks(accountID, postureChecksID, userID string) (*posture.Checks, error) - SavePostureChecks(accountID, userID string, postureChecks *posture.Checks) error - DeletePostureChecks(accountID, postureChecksID, userID string) error - ListPostureChecks(accountID, userID string) ([]*posture.Checks, error) + GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) + SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error + DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error + ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) GetIdpManager() idp.Manager - UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error - GroupValidation(accountId string, groups []string) (bool, error) + UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error + GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) GetValidatedPeers(account *Account) (map[string]struct{}, error) - SyncAndMarkPeer(peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) - CancelPeerRoutines(peer *nbpeer.Peer) error - SyncPeerMeta(peerPubKey string, meta nbpeer.PeerSystemMeta) error + SyncAndMarkPeer(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) + CancelPeerRoutines(ctx context.Context, peer *nbpeer.Peer) error + SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) + GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error) } type DefaultAccountManager struct { @@ -274,8 +275,8 @@ type UserInfo struct { // getRoutesToSync returns the enabled routes for the peer ID and the routes // from the ACL peers that have distribution groups associated with the peer ID. // Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. -func (a *Account) getRoutesToSync(peerID string, aclPeers []*nbpeer.Peer) []*route.Route { - routes, peerDisabledRoutes := a.getRoutingPeerRoutes(peerID) +func (a *Account) getRoutesToSync(ctx context.Context, peerID string, aclPeers []*nbpeer.Peer) []*route.Route { + routes, peerDisabledRoutes := a.getRoutingPeerRoutes(ctx, peerID) peerRoutesMembership := make(lookupMap) for _, r := range append(routes, peerDisabledRoutes...) { peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{} @@ -283,7 +284,7 @@ func (a *Account) getRoutesToSync(peerID string, aclPeers []*nbpeer.Peer) []*rou groupListMap := a.getPeerGroups(peerID) for _, peer := range aclPeers { - activeRoutes, _ := a.getRoutingPeerRoutes(peer.ID) + activeRoutes, _ := a.getRoutingPeerRoutes(ctx, peer.ID) groupFilteredRoutes := a.filterRoutesByGroups(activeRoutes, groupListMap) filteredRoutes := a.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership) routes = append(routes, filteredRoutes...) @@ -322,11 +323,11 @@ func (a *Account) filterRoutesByGroups(routes []*route.Route, groupListMap looku // getRoutingPeerRoutes returns the enabled and disabled lists of routes that the given routing peer serves // Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. // If the given is not a routing peer, then the lists are empty. -func (a *Account) getRoutingPeerRoutes(peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) { +func (a *Account) getRoutingPeerRoutes(ctx context.Context, peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) { peer := a.GetPeer(peerID) if peer == nil { - log.Errorf("peer %s that doesn't exist under account %s", peerID, a.Id) + log.WithContext(ctx).Errorf("peer %s that doesn't exist under account %s", peerID, a.Id) return enabledRoutes, disabledRoutes } @@ -355,7 +356,7 @@ func (a *Account) getRoutingPeerRoutes(peerID string) (enabledRoutes []*route.Ro for _, groupID := range r.PeerGroups { group := a.GetGroup(groupID) if group == nil { - log.Errorf("route %s has peers group %s that doesn't exist under account %s", r.ID, groupID, a.Id) + log.WithContext(ctx).Errorf("route %s has peers group %s that doesn't exist under account %s", r.ID, groupID, a.Id) continue } for _, id := range group.Peers { @@ -399,7 +400,7 @@ func (a *Account) GetGroup(groupID string) *nbgroup.Group { } // GetPeerNetworkMap returns a group by ID if exists, nil otherwise -func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string, validatedPeersMap map[string]struct{}) *NetworkMap { +func (a *Account) GetPeerNetworkMap(ctx context.Context, peerID, dnsDomain string, validatedPeersMap map[string]struct{}) *NetworkMap { peer := a.Peers[peerID] if peer == nil { return &NetworkMap{ @@ -413,7 +414,7 @@ func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string, validatedPeersMap } } - aclPeers, firewallRules := a.getPeerConnectionResources(peerID, validatedPeersMap) + aclPeers, firewallRules := a.getPeerConnectionResources(ctx, peerID, validatedPeersMap) // exclude expired peers var peersToConnect []*nbpeer.Peer var expiredPeers []*nbpeer.Peer @@ -426,7 +427,7 @@ func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string, validatedPeersMap peersToConnect = append(peersToConnect, p) } - routesUpdate := a.getRoutesToSync(peerID, peersToConnect) + routesUpdate := a.getRoutesToSync(ctx, peerID, peersToConnect) dnsManagementStatus := a.getPeerDNSManagementStatus(peerID) dnsUpdate := nbdns.Config{ @@ -435,7 +436,7 @@ func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string, validatedPeersMap if dnsManagementStatus { var zones []nbdns.CustomZone - peersCustomZone := getPeersCustomZone(a, dnsDomain) + peersCustomZone := getPeersCustomZone(ctx, a, dnsDomain) if peersCustomZone.Domain != "" { zones = append(zones, peersCustomZone) } @@ -872,7 +873,7 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) { } // BuildManager creates a new DefaultAccountManager with a provided Store -func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager, +func BuildManager(ctx context.Context, store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager, singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, geo *geolocation.Geolocation, userDeleteFromIDPEnabled bool, integratedPeerValidator integrated_validator.IntegratedValidator, @@ -891,7 +892,7 @@ func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManage userDeleteFromIDPEnabled: userDeleteFromIDPEnabled, integratedPeerValidator: integratedPeerValidator, } - allAccounts := store.GetAllAccounts() + allAccounts := store.GetAllAccounts(ctx) // enable single account mode only if configured by user and number of existing accounts is not grater than 1 am.singleAccountMode = singleAccountModeDomain != "" && len(allAccounts) <= 1 if am.singleAccountMode { @@ -899,9 +900,9 @@ func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManage return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for a single account mode. Please review your input for --single-account-mode-domain", singleAccountModeDomain) } am.singleAccountModeDomain = singleAccountModeDomain - log.Infof("single account mode enabled, accounts number %d", len(allAccounts)) + log.WithContext(ctx).Infof("single account mode enabled, accounts number %d", len(allAccounts)) } else { - log.Infof("single account mode disabled, accounts number %d", len(allAccounts)) + log.WithContext(ctx).Infof("single account mode disabled, accounts number %d", len(allAccounts)) } // if account doesn't have a default group @@ -919,7 +920,7 @@ func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManage } if shouldSave { - err = store.SaveAccount(account) + err = store.SaveAccount(ctx, account) if err != nil { return nil, err } @@ -937,16 +938,18 @@ func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManage if !isNil(am.idpManager) { go func() { - err := am.warmupIDPCache() + err := am.warmupIDPCache(ctx) if err != nil { - log.Warnf("failed warming up cache due to error: %v", err) + log.WithContext(ctx).Warnf("failed warming up cache due to error: %v", err) // todo retry? return } }() } - am.integratedPeerValidator.SetPeerInvalidationListener(am.onPeersInvalidated) + am.integratedPeerValidator.SetPeerInvalidationListener(func(accountID string) { + am.onPeersInvalidated(ctx, accountID) + }) return am, nil } @@ -963,7 +966,7 @@ func (am *DefaultAccountManager) GetIdpManager() idp.Manager { // Only users with role UserRoleAdmin can update the account. // User that performs the update has to belong to the account. // Returns an updated Account -func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string, newSettings *Settings) (*Account, error) { +func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Account, error) { halfYearLimit := 180 * 24 * time.Hour if newSettings.PeerLoginExpiration > halfYearLimit { return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days") @@ -973,10 +976,10 @@ func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string, return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour") } - unlock := am.Store.AcquireAccountWriteLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -990,7 +993,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string, return nil, status.Errorf(status.PermissionDenied, "user is not allowed to update account") } - err = am.integratedPeerValidator.ValidateExtraSettings(newSettings.Extra, account.Settings.Extra, account.Peers, userID, accountID) + err = am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, account.Settings.Extra, account.Peers, userID, accountID) if err != nil { return nil, err } @@ -1000,21 +1003,21 @@ func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string, event := activity.AccountPeerLoginExpirationEnabled if !newSettings.PeerLoginExpirationEnabled { event = activity.AccountPeerLoginExpirationDisabled - am.peerLoginExpiry.Cancel([]string{accountID}) + am.peerLoginExpiry.Cancel(ctx, []string{accountID}) } else { - am.checkAndSchedulePeerLoginExpiration(account) + am.checkAndSchedulePeerLoginExpiration(ctx, account) } - am.StoreEvent(userID, accountID, accountID, event, nil) + am.StoreEvent(ctx, userID, accountID, accountID, event, nil) } if oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration { - am.StoreEvent(userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil) - am.checkAndSchedulePeerLoginExpiration(account) + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil) + am.checkAndSchedulePeerLoginExpiration(ctx, account) } updatedAccount := account.UpdateSettings(newSettings) - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(ctx, account) if err != nil { return nil, err } @@ -1022,14 +1025,14 @@ func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string, return updatedAccount, nil } -func (am *DefaultAccountManager) peerLoginExpirationJob(accountID string) func() (time.Duration, bool) { +func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { return func() (time.Duration, bool) { - unlock := am.Store.AcquireAccountWriteLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { - log.Errorf("failed getting account %s expiring peers", accountID) + log.WithContext(ctx).Errorf("failed getting account %s expiring peers", accountID) return account.GetNextPeerExpiration() } @@ -1039,10 +1042,10 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(accountID string) func() peerIDs = append(peerIDs, peer.ID) } - log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id) + log.WithContext(ctx).Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id) - if err := am.expireAndUpdatePeers(account, expiredPeers); err != nil { - log.Errorf("failed updating account peers while expiring peers for account %s", account.Id) + if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil { + log.WithContext(ctx).Errorf("failed updating account peers while expiring peers for account %s", account.Id) return account.GetNextPeerExpiration() } @@ -1050,28 +1053,28 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(accountID string) func() } } -func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(account *Account) { - am.peerLoginExpiry.Cancel([]string{account.Id}) +func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, account *Account) { + am.peerLoginExpiry.Cancel(ctx, []string{account.Id}) if nextRun, ok := account.GetNextPeerExpiration(); ok { - go am.peerLoginExpiry.Schedule(nextRun, account.Id, am.peerLoginExpirationJob(account.Id)) + go am.peerLoginExpiry.Schedule(ctx, nextRun, account.Id, am.peerLoginExpirationJob(ctx, account.Id)) } } // newAccount creates a new Account with a generated ID and generated default setup keys. // If ID is already in use (due to collision) we try one more time before returning error -func (am *DefaultAccountManager) newAccount(userID, domain string) (*Account, error) { +func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain string) (*Account, error) { for i := 0; i < 2; i++ { accountId := xid.New().String() - _, err := am.Store.GetAccount(accountId) + _, err := am.Store.GetAccount(ctx, accountId) statusErr, _ := status.FromError(err) switch { case err == nil: - log.Warnf("an account with ID already exists, retrying...") + log.WithContext(ctx).Warnf("an account with ID already exists, retrying...") continue case statusErr.Type() == status.NotFound: - newAccount := newAccountWithId(accountId, userID, domain) - am.StoreEvent(userID, newAccount.Id, accountId, activity.AccountCreated, nil) + newAccount := newAccountWithId(ctx, accountId, userID, domain) + am.StoreEvent(ctx, userID, newAccount.Id, accountId, activity.AccountCreated, nil) return newAccount, nil default: return nil, err @@ -1081,12 +1084,12 @@ func (am *DefaultAccountManager) newAccount(userID, domain string) (*Account, er return nil, status.Errorf(status.Internal, "error while creating new account") } -func (am *DefaultAccountManager) warmupIDPCache() error { - userData, err := am.idpManager.GetAllAccounts() +func (am *DefaultAccountManager) warmupIDPCache(ctx context.Context) error { + userData, err := am.idpManager.GetAllAccounts(ctx) if err != nil { return err } - log.Infof("%d entries received from IdP management", len(userData)) + log.WithContext(ctx).Infof("%d entries received from IdP management", len(userData)) // If the Identity Provider does not support writing AppMetadata, // in cases like this, we expect it to return all users in an "unset" field. @@ -1094,7 +1097,7 @@ func (am *DefaultAccountManager) warmupIDPCache() error { // update their AppMetadata with the AccountID. if unsetData, ok := userData[idp.UnsetAccountID]; ok { for _, user := range unsetData { - accountID, err := am.Store.GetAccountByUser(user.ID) + accountID, err := am.Store.GetAccountByUser(ctx, user.ID) if err == nil { data := userData[accountID.Id] if data == nil { @@ -1117,15 +1120,15 @@ func (am *DefaultAccountManager) warmupIDPCache() error { return err } } - log.Infof("warmed up IDP cache with %d entries for %d accounts", rcvdUsers, len(userData)) + log.WithContext(ctx).Infof("warmed up IDP cache with %d entries for %d accounts", rcvdUsers, len(userData)) return nil } // DeleteAccount deletes an account and all its users from local store and from the remote IDP if the requester is an admin and account owner -func (am *DefaultAccountManager) DeleteAccount(accountID, userID string) error { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, userID string) error { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } @@ -1151,42 +1154,42 @@ func (am *DefaultAccountManager) DeleteAccount(accountID, userID string) error { continue } - deleteUserErr := am.deleteRegularUser(account, userID, otherUser.Id) + deleteUserErr := am.deleteRegularUser(ctx, account, userID, otherUser.Id) if deleteUserErr != nil { return deleteUserErr } } - err = am.deleteRegularUser(account, userID, userID) + err = am.deleteRegularUser(ctx, account, userID, userID) if err != nil { - log.Errorf("failed deleting user %s. error: %s", userID, err) + log.WithContext(ctx).Errorf("failed deleting user %s. error: %s", userID, err) return err } - err = am.Store.DeleteAccount(account) + err = am.Store.DeleteAccount(ctx, account) if err != nil { - log.Errorf("failed deleting account %s. error: %s", accountID, err) + log.WithContext(ctx).Errorf("failed deleting account %s. error: %s", accountID, err) return err } // cancel peer login expiry job - am.peerLoginExpiry.Cancel([]string{account.Id}) + am.peerLoginExpiry.Cancel(ctx, []string{account.Id}) - log.Debugf("account %s deleted", accountID) + log.WithContext(ctx).Debugf("account %s deleted", accountID) return nil } // GetAccountByUserOrAccountID looks for an account by user or accountID, if no account is provided and // userID doesn't have an account associated with it, one account is created // domain is used to create a new account if no account is found -func (am *DefaultAccountManager) GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error) { +func (am *DefaultAccountManager) GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error) { if accountID != "" { - return am.Store.GetAccount(accountID) + return am.Store.GetAccount(ctx, accountID) } else if userID != "" { - account, err := am.GetOrCreateAccountByUser(userID, domain) + account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) if err != nil { return nil, status.Errorf(status.NotFound, "account not found using user id: %s", userID) } - err = am.addAccountIDToIDPAppMeta(userID, account) + err = am.addAccountIDToIDPAppMeta(ctx, userID, account) if err != nil { return nil, err } @@ -1201,28 +1204,28 @@ func isNil(i idp.Manager) bool { } // addAccountIDToIDPAppMeta update user's app metadata in idp manager -func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(userID string, account *Account) error { +func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, account *Account) error { if !isNil(am.idpManager) { // user can be nil if it wasn't found (e.g., just created) - user, err := am.lookupUserInCache(userID, account) + user, err := am.lookupUserInCache(ctx, userID, account) if err != nil { return err } if user != nil && user.AppMetadata.WTAccountID == account.Id { // it was already set, so we skip the unnecessary update - log.Debugf("skipping IDP App Meta update because accountID %s has been already set for user %s", + log.WithContext(ctx).Debugf("skipping IDP App Meta update because accountID %s has been already set for user %s", account.Id, userID) return nil } - err = am.idpManager.UpdateUserAppMetadata(userID, idp.AppMetadata{WTAccountID: account.Id}) + err = am.idpManager.UpdateUserAppMetadata(ctx, userID, idp.AppMetadata{WTAccountID: account.Id}) if err != nil { return status.Errorf(status.Internal, "updating user's app metadata failed with: %v", err) } // refresh cache to reflect the update - _, err = am.refreshCache(account.Id) + _, err = am.refreshCache(ctx, account.Id) if err != nil { return err } @@ -1230,20 +1233,20 @@ func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(userID string, account return nil } -func (am *DefaultAccountManager) loadAccount(_ context.Context, accountID interface{}) ([]*idp.UserData, error) { - log.Debugf("account %s not found in cache, reloading", accountID) +func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID interface{}) ([]*idp.UserData, error) { + log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID) accountIDString := fmt.Sprintf("%v", accountID) - account, err := am.Store.GetAccount(accountIDString) + account, err := am.Store.GetAccount(ctx, accountIDString) if err != nil { return nil, err } - userData, err := am.idpManager.GetAccount(accountIDString) + userData, err := am.idpManager.GetAccount(ctx, accountIDString) if err != nil { return nil, err } - log.Debugf("%d entries received from IdP management", len(userData)) + log.WithContext(ctx).Debugf("%d entries received from IdP management", len(userData)) dataMap := make(map[string]*idp.UserData, len(userData)) for _, datum := range userData { @@ -1257,7 +1260,7 @@ func (am *DefaultAccountManager) loadAccount(_ context.Context, accountID interf } datum, ok := dataMap[user.Id] if !ok { - log.Warnf("user %s not found in IDP", user.Id) + log.WithContext(ctx).Warnf("user %s not found in IDP", user.Id) continue } matchedUserData = append(matchedUserData, datum) @@ -1265,8 +1268,8 @@ func (am *DefaultAccountManager) loadAccount(_ context.Context, accountID interf return matchedUserData, nil } -func (am *DefaultAccountManager) lookupUserInCacheByEmail(email string, accountID string) (*idp.UserData, error) { - data, err := am.getAccountFromCache(accountID, false) +func (am *DefaultAccountManager) lookupUserInCacheByEmail(ctx context.Context, email string, accountID string) (*idp.UserData, error) { + data, err := am.getAccountFromCache(ctx, accountID, false) if err != nil { return nil, err } @@ -1281,7 +1284,7 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(email string, accountI } // lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil -func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Account) (*idp.UserData, error) { +func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, account *Account) (*idp.UserData, error) { users := make(map[string]userLoggedInOnce, len(account.Users)) // ignore service users and users provisioned by integrations than are never logged in for _, user := range account.Users { @@ -1293,8 +1296,8 @@ func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Accou } users[user.Id] = userLoggedInOnce(!user.LastLogin.IsZero()) } - log.Debugf("looking up user %s of account %s in cache", userID, account.Id) - userData, err := am.lookupCache(users, account.Id) + log.WithContext(ctx).Debugf("looking up user %s of account %s in cache", userID, account.Id) + userData, err := am.lookupCache(ctx, users, account.Id) if err != nil { return nil, err } @@ -1309,25 +1312,25 @@ func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Accou // or it didn't have its metadata updated with am.addAccountIDToIDPAppMeta user, err := account.FindUser(userID) if err != nil { - log.Errorf("failed finding user %s in account %s", userID, account.Id) + log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, account.Id) return nil, err } key := user.IntegrationReference.CacheKey(account.Id, userID) ud, err := am.externalCacheManager.Get(am.ctx, key) if err != nil { - log.Debugf("failed to get externalCache for key: %s, error: %s", key, err) + log.WithContext(ctx).Debugf("failed to get externalCache for key: %s, error: %s", key, err) } return ud, nil } -func (am *DefaultAccountManager) refreshCache(accountID string) ([]*idp.UserData, error) { - return am.getAccountFromCache(accountID, true) +func (am *DefaultAccountManager) refreshCache(ctx context.Context, accountID string) ([]*idp.UserData, error) { + return am.getAccountFromCache(ctx, accountID, true) } // getAccountFromCache returns user data for a given account ensuring that cache load happens only once -func (am *DefaultAccountManager) getAccountFromCache(accountID string, forceReload bool) ([]*idp.UserData, error) { +func (am *DefaultAccountManager) getAccountFromCache(ctx context.Context, accountID string, forceReload bool) ([]*idp.UserData, error) { am.cacheMux.Lock() loadingChan := am.cacheLoading[accountID] if loadingChan == nil { @@ -1353,7 +1356,7 @@ func (am *DefaultAccountManager) getAccountFromCache(accountID string, forceRelo } am.cacheMux.Unlock() - log.Debugf("one request to get account %s is already running", accountID) + log.WithContext(ctx).Debugf("one request to get account %s is already running", accountID) select { case <-loadingChan: @@ -1364,19 +1367,19 @@ func (am *DefaultAccountManager) getAccountFromCache(accountID string, forceRelo } } -func (am *DefaultAccountManager) lookupCache(accountUsers map[string]userLoggedInOnce, accountID string) ([]*idp.UserData, error) { +func (am *DefaultAccountManager) lookupCache(ctx context.Context, accountUsers map[string]userLoggedInOnce, accountID string) ([]*idp.UserData, error) { var data []*idp.UserData var err error maxAttempts := 2 - data, err = am.getAccountFromCache(accountID, false) + data, err = am.getAccountFromCache(ctx, accountID, false) if err != nil { return nil, err } for attempt := 1; attempt <= maxAttempts; attempt++ { - if am.isCacheFresh(accountUsers, data) { + if am.isCacheFresh(ctx, accountUsers, data) { return data, nil } @@ -1384,14 +1387,14 @@ func (am *DefaultAccountManager) lookupCache(accountUsers map[string]userLoggedI time.Sleep(200 * time.Millisecond) } - log.Infof("refreshing cache for account %s", accountID) - data, err = am.refreshCache(accountID) + log.WithContext(ctx).Infof("refreshing cache for account %s", accountID) + data, err = am.refreshCache(ctx, accountID) if err != nil { return nil, err } if attempt == maxAttempts { - log.Warnf("cache for account %s reached maximum refresh attempts (%d)", accountID, maxAttempts) + log.WithContext(ctx).Warnf("cache for account %s reached maximum refresh attempts (%d)", accountID, maxAttempts) } } @@ -1399,7 +1402,7 @@ func (am *DefaultAccountManager) lookupCache(accountUsers map[string]userLoggedI } // isCacheFresh checks if the cache is refreshed already by comparing the accountUsers with the cache data by user count and user invite status -func (am *DefaultAccountManager) isCacheFresh(accountUsers map[string]userLoggedInOnce, data []*idp.UserData) bool { +func (am *DefaultAccountManager) isCacheFresh(ctx context.Context, accountUsers map[string]userLoggedInOnce, data []*idp.UserData) bool { userDataMap := make(map[string]*idp.UserData, len(data)) for _, datum := range data { userDataMap[datum.ID] = datum @@ -1412,26 +1415,26 @@ func (am *DefaultAccountManager) isCacheFresh(accountUsers map[string]userLogged if datum, ok := userDataMap[user]; ok { // check if the matching user data has a pending invite and if the user has logged in once, forcing the cache to be refreshed if datum.AppMetadata.WTPendingInvite != nil && *datum.AppMetadata.WTPendingInvite && loggedInOnce == true { //nolint:gosimple - log.Infof("user %s has a pending invite and has logged in once, cache invalid", user) + log.WithContext(ctx).Infof("user %s has a pending invite and has logged in once, cache invalid", user) return false } knownUsersCount-- continue } - log.Debugf("cache doesn't know about %s user", user) + log.WithContext(ctx).Debugf("cache doesn't know about %s user", user) } // if we know users that are not yet in cache more likely cache is outdated if knownUsersCount > 0 { - log.Infof("cache invalid. Users unknown to the cache: %d", knownUsersCount) + log.WithContext(ctx).Infof("cache invalid. Users unknown to the cache: %d", knownUsersCount) return false } return true } -func (am *DefaultAccountManager) removeUserFromCache(accountID, userID string) error { - data, err := am.getAccountFromCache(accountID, false) +func (am *DefaultAccountManager) removeUserFromCache(ctx context.Context, accountID, userID string) error { + data, err := am.getAccountFromCache(ctx, accountID, false) if err != nil { return err } @@ -1447,7 +1450,7 @@ func (am *DefaultAccountManager) removeUserFromCache(accountID, userID string) e } // updateAccountDomainAttributes updates the account domain attributes and then, saves the account -func (am *DefaultAccountManager) updateAccountDomainAttributes(account *Account, claims jwtclaims.AuthorizationClaims, +func (am *DefaultAccountManager) updateAccountDomainAttributes(ctx context.Context, account *Account, claims jwtclaims.AuthorizationClaims, primaryDomain bool, ) error { @@ -1464,10 +1467,10 @@ func (am *DefaultAccountManager) updateAccountDomainAttributes(account *Account, account.DomainCategory = claims.DomainCategory } } else { - log.Errorf("claims don't contain a valid domain, skipping domain attributes update. Received claims: %v", claims) + log.WithContext(ctx).Errorf("claims don't contain a valid domain, skipping domain attributes update. Received claims: %v", claims) } - err := am.Store.SaveAccount(account) + err := am.Store.SaveAccount(ctx, account) if err != nil { return err } @@ -1476,17 +1479,18 @@ func (am *DefaultAccountManager) updateAccountDomainAttributes(account *Account, // handleExistingUserAccount handles existing User accounts and update its domain attributes. func (am *DefaultAccountManager) handleExistingUserAccount( + ctx context.Context, existingAcc *Account, primaryDomain bool, claims jwtclaims.AuthorizationClaims, ) error { - err := am.updateAccountDomainAttributes(existingAcc, claims, primaryDomain) + err := am.updateAccountDomainAttributes(ctx, existingAcc, claims, primaryDomain) if err != nil { return err } // we should register the account ID to this user's metadata in our IDP manager - err = am.addAccountIDToIDPAppMeta(claims.UserId, existingAcc) + err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, existingAcc) if err != nil { return err } @@ -1496,7 +1500,7 @@ func (am *DefaultAccountManager) handleExistingUserAccount( // handleNewUserAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account, // otherwise it will create a new account and make it primary account for the domain. -func (am *DefaultAccountManager) handleNewUserAccount(domainAcc *Account, claims jwtclaims.AuthorizationClaims) (*Account, error) { +func (am *DefaultAccountManager) handleNewUserAccount(ctx context.Context, domainAcc *Account, claims jwtclaims.AuthorizationClaims) (*Account, error) { if claims.UserId == "" { return nil, fmt.Errorf("user ID is empty") } @@ -1509,40 +1513,40 @@ func (am *DefaultAccountManager) handleNewUserAccount(domainAcc *Account, claims if domainAcc != nil { account = domainAcc account.Users[claims.UserId] = NewRegularUser(claims.UserId) - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(ctx, account) if err != nil { return nil, err } } else { - account, err = am.newAccount(claims.UserId, lowerDomain) + account, err = am.newAccount(ctx, claims.UserId, lowerDomain) if err != nil { return nil, err } - err = am.updateAccountDomainAttributes(account, claims, true) + err = am.updateAccountDomainAttributes(ctx, account, claims, true) if err != nil { return nil, err } } - err = am.addAccountIDToIDPAppMeta(claims.UserId, account) + err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, account) if err != nil { return nil, err } - am.StoreEvent(claims.UserId, claims.UserId, account.Id, activity.UserJoined, nil) + am.StoreEvent(ctx, claims.UserId, claims.UserId, account.Id, activity.UserJoined, nil) return account, nil } // redeemInvite checks whether user has been invited and redeems the invite -func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) error { +func (am *DefaultAccountManager) redeemInvite(ctx context.Context, account *Account, userID string) error { // only possible with the enabled IdP manager if am.idpManager == nil { - log.Warnf("invites only work with enabled IdP manager") + log.WithContext(ctx).Warnf("invites only work with enabled IdP manager") return nil } - user, err := am.lookupUserInCache(userID, account) + user, err := am.lookupUserInCache(ctx, userID, account) if err != nil { return err } @@ -1552,17 +1556,17 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e } if user.AppMetadata.WTPendingInvite != nil && *user.AppMetadata.WTPendingInvite { - log.Infof("redeeming invite for user %s account %s", userID, account.Id) + log.WithContext(ctx).Infof("redeeming invite for user %s account %s", userID, account.Id) // User has already logged in, meaning that IdP should have set wt_pending_invite to false. // Our job is to just reload cache. go func() { - _, err = am.refreshCache(account.Id) + _, err = am.refreshCache(ctx, account.Id) if err != nil { - log.Warnf("failed reloading cache when redeeming user %s under account %s", userID, account.Id) + log.WithContext(ctx).Warnf("failed reloading cache when redeeming user %s under account %s", userID, account.Id) return } - log.Debugf("user %s of account %s redeemed invite", user.ID, account.Id) - am.StoreEvent(userID, userID, account.Id, activity.UserJoined, nil) + log.WithContext(ctx).Debugf("user %s of account %s redeemed invite", user.ID, account.Id) + am.StoreEvent(ctx, userID, userID, account.Id, activity.UserJoined, nil) }() } @@ -1570,22 +1574,22 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e } // MarkPATUsed marks a personal access token as used -func (am *DefaultAccountManager) MarkPATUsed(tokenID string) error { +func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string) error { - user, err := am.Store.GetUserByTokenID(tokenID) + user, err := am.Store.GetUserByTokenID(ctx, tokenID) if err != nil { return err } - account, err := am.Store.GetAccountByUser(user.Id) + account, err := am.Store.GetAccountByUser(ctx, user.Id) if err != nil { return err } - unlock := am.Store.AcquireAccountWriteLock(account.Id) + unlock := am.Store.AcquireAccountWriteLock(ctx, account.Id) defer unlock() - account, err = am.Store.GetAccountByUser(user.Id) + account, err = am.Store.GetAccountByUser(ctx, user.Id) if err != nil { return err } @@ -1597,11 +1601,11 @@ func (am *DefaultAccountManager) MarkPATUsed(tokenID string) error { pat.LastUsed = time.Now().UTC() - return am.Store.SaveAccount(account) + return am.Store.SaveAccount(ctx, account) } // GetAccountFromPAT returns Account and User associated with a personal access token -func (am *DefaultAccountManager) GetAccountFromPAT(token string) (*Account, *User, *PersonalAccessToken, error) { +func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token string) (*Account, *User, *PersonalAccessToken, error) { if len(token) != PATLength { return nil, nil, nil, fmt.Errorf("token has wrong length") } @@ -1625,17 +1629,17 @@ func (am *DefaultAccountManager) GetAccountFromPAT(token string) (*Account, *Use hashedToken := sha256.Sum256([]byte(token)) encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) - tokenID, err := am.Store.GetTokenIDByHashedToken(encodedHashedToken) + tokenID, err := am.Store.GetTokenIDByHashedToken(ctx, encodedHashedToken) if err != nil { return nil, nil, nil, err } - user, err := am.Store.GetUserByTokenID(tokenID) + user, err := am.Store.GetUserByTokenID(ctx, tokenID) if err != nil { return nil, nil, nil, err } - account, err := am.Store.GetAccountByUser(user.Id) + account, err := am.Store.GetAccountByUser(ctx, user.Id) if err != nil { return nil, nil, nil, err } @@ -1649,7 +1653,7 @@ func (am *DefaultAccountManager) GetAccountFromPAT(token string) (*Account, *Use } // GetAccountFromToken returns an account associated with this token -func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error) { +func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error) { if claims.UserId == "" { return nil, nil, fmt.Errorf("user ID is empty") } @@ -1658,14 +1662,14 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat // We override incoming domain claims to group users under a single account. claims.Domain = am.singleAccountModeDomain claims.DomainCategory = PrivateCategory - log.Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled") + log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled") } - newAcc, err := am.getAccountWithAuthorizationClaims(claims) + newAcc, err := am.getAccountWithAuthorizationClaims(ctx, claims) if err != nil { return nil, nil, err } - unlock := am.Store.AcquireAccountWriteLock(newAcc.Id) + unlock := am.Store.AcquireAccountWriteLock(ctx, newAcc.Id) alreadyUnlocked := false defer func() { if !alreadyUnlocked { @@ -1673,7 +1677,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat } }() - account, err := am.Store.GetAccount(newAcc.Id) + account, err := am.Store.GetAccount(ctx, newAcc.Id) if err != nil { return nil, nil, err } @@ -1685,7 +1689,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat } if !user.IsServiceUser && claims.Invited { - err = am.redeemInvite(account, claims.UserId) + err = am.redeemInvite(ctx, account, claims.UserId) if err != nil { return nil, nil, err } @@ -1693,7 +1697,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat if account.Settings.JWTGroupsEnabled { if account.Settings.JWTGroupsClaimName == "" { - log.Errorf("JWT groups are enabled but no claim name is set") + log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set") return account, user, nil } if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok { @@ -1703,7 +1707,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat if g, ok := item.(string); ok { groupsNames = append(groupsNames, g) } else { - log.Errorf("JWT claim %q is not a string: %v", account.Settings.JWTGroupsClaimName, item) + log.WithContext(ctx).Errorf("JWT claim %q is not a string: %v", account.Settings.JWTGroupsClaimName, item) } } @@ -1718,16 +1722,16 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat account.UserGroupsAddToPeers(claims.UserId, addNewGroups...) account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...) account.Network.IncSerial() - if err := am.Store.SaveAccount(account); err != nil { - log.Errorf("failed to save account: %v", err) + if err := am.Store.SaveAccount(ctx, account); err != nil { + log.WithContext(ctx).Errorf("failed to save account: %v", err) } else { - log.Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) - am.updateAccountPeers(account) + log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) + am.updateAccountPeers(ctx, account) unlock() alreadyUnlocked = true for _, g := range addNewGroups { if group := account.GetGroup(g); group != nil { - am.StoreEvent(user.Id, user.Id, account.Id, activity.GroupAddedToUser, + am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser, map[string]any{ "group": group.Name, "group_id": group.ID, @@ -1737,7 +1741,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat } for _, g := range removeOldGroups { if group := account.GetGroup(g); group != nil { - am.StoreEvent(user.Id, user.Id, account.Id, activity.GroupRemovedFromUser, + am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser, map[string]any{ "group": group.Name, "group_id": group.ID, @@ -1748,16 +1752,16 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat } } } else { - if err := am.Store.SaveAccount(account); err != nil { - log.Errorf("failed to save account: %v", err) + if err := am.Store.SaveAccount(ctx, account); err != nil { + log.WithContext(ctx).Errorf("failed to save account: %v", err) } } } } else { - log.Debugf("JWT claim %q is not a string array", account.Settings.JWTGroupsClaimName) + log.WithContext(ctx).Debugf("JWT claim %q is not a string array", account.Settings.JWTGroupsClaimName) } } else { - log.Debugf("JWT claim %q not found", account.Settings.JWTGroupsClaimName) + log.WithContext(ctx).Debugf("JWT claim %q not found", account.Settings.JWTGroupsClaimName) } } @@ -1781,8 +1785,8 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat // Existing user + Existing account + Existing Indexed Domain -> Nothing changes // // Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain) -func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtclaims.AuthorizationClaims) (*Account, error) { - log.Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"", +func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, error) { + log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"", claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory) if claims.UserId == "" { return nil, fmt.Errorf("user ID is empty") @@ -1790,9 +1794,9 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla // if Account ID is part of the claims // it means that we've already classified the domain and user has an account if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) { - return am.GetAccountByUserOrAccountID(claims.UserId, claims.AccountId, claims.Domain) + return am.GetAccountByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain) } else if claims.AccountId != "" { - accountFromID, err := am.Store.GetAccount(claims.AccountId) + accountFromID, err := am.Store.GetAccount(ctx, claims.AccountId) if err != nil { return nil, err } @@ -1805,12 +1809,12 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla } start := time.Now() - unlock := am.Store.AcquireGlobalLock() + unlock := am.Store.AcquireGlobalLock(ctx) defer unlock() - log.Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId) + log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId) // We checked if the domain has a primary account already - domainAccount, err := am.Store.GetAccountByPrivateDomain(claims.Domain) + domainAccount, err := am.Store.GetAccountByPrivateDomain(ctx, claims.Domain) if err != nil { // if NotFound we are good to continue, otherwise return error e, ok := status.FromError(err) @@ -1819,11 +1823,11 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla } } - account, err := am.Store.GetAccountByUser(claims.UserId) + account, err := am.Store.GetAccountByUser(ctx, claims.UserId) if err == nil { - unlockAccount := am.Store.AcquireAccountWriteLock(account.Id) + unlockAccount := am.Store.AcquireAccountWriteLock(ctx, account.Id) defer unlockAccount() - account, err = am.Store.GetAccountByUser(claims.UserId) + account, err = am.Store.GetAccountByUser(ctx, claims.UserId) if err != nil { return nil, err } @@ -1834,29 +1838,29 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla // and peers that shouldn't be lost. primaryDomain := domainAccount == nil || account.Id == domainAccount.Id - err = am.handleExistingUserAccount(account, primaryDomain, claims) + err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims) if err != nil { return nil, err } return account, nil } else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { if domainAccount != nil { - unlockAccount := am.Store.AcquireAccountWriteLock(domainAccount.Id) + unlockAccount := am.Store.AcquireAccountWriteLock(ctx, domainAccount.Id) defer unlockAccount() - domainAccount, err = am.Store.GetAccountByPrivateDomain(claims.Domain) + domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain) if err != nil { return nil, err } } - return am.handleNewUserAccount(domainAccount, claims) + return am.handleNewUserAccount(ctx, domainAccount, claims) } else { // other error return nil, err } } -func (am *DefaultAccountManager) SyncAndMarkPeer(peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { - accountID, err := am.Store.GetAccountIDByPeerPubKey(peerPubKey) +func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { + accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peerPubKey) if err != nil { if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound { return nil, nil, nil, status.Errorf(status.Unauthenticated, "peer not registered") @@ -1864,29 +1868,29 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(peerPubKey string, meta nbpeer. return nil, nil, nil, err } - unlock := am.Store.AcquireAccountReadLock(accountID) + unlock := am.Store.AcquireAccountReadLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, nil, nil, err } - peer, netMap, postureChecks, err := am.SyncPeer(PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, account) + peer, netMap, postureChecks, err := am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, account) if err != nil { return nil, nil, nil, err } - err = am.MarkPeerConnected(peerPubKey, true, realIP, account) + err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, account) if err != nil { - log.Warnf("failed marking peer as connected %s %v", peerPubKey, err) + log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err) } return peer, netMap, postureChecks, nil } -func (am *DefaultAccountManager) CancelPeerRoutines(peer *nbpeer.Peer) error { - accountID, err := am.Store.GetAccountIDByPeerPubKey(peer.Key) +func (am *DefaultAccountManager) CancelPeerRoutines(ctx context.Context, peer *nbpeer.Peer) error { + accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peer.Key) if err != nil { if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound { return status.Errorf(status.Unauthenticated, "peer not registered") @@ -1894,40 +1898,40 @@ func (am *DefaultAccountManager) CancelPeerRoutines(peer *nbpeer.Peer) error { return err } - unlock := am.Store.AcquireAccountWriteLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } - err = am.MarkPeerConnected(peer.Key, false, nil, account) + err = am.MarkPeerConnected(ctx, peer.Key, false, nil, account) if err != nil { - log.Warnf("failed marking peer as connected %s %v", peer.Key, err) + log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peer.Key, err) } return nil } -func (am *DefaultAccountManager) SyncPeerMeta(peerPubKey string, meta nbpeer.PeerSystemMeta) error { - accountID, err := am.Store.GetAccountIDByPeerPubKey(peerPubKey) +func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error { + accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peerPubKey) if err != nil { return err } - unlock := am.Store.AcquireAccountReadLock(accountID) + unlock := am.Store.AcquireAccountReadLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } - _, _, _, err = am.SyncPeer(PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, account) + _, _, _, err = am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, account) if err != nil { - return mapError(err) + return mapError(ctx, err) } return nil } @@ -1955,8 +1959,8 @@ func (am *DefaultAccountManager) GetDNSDomain() string { // CheckUserAccessByJWTGroups checks if the user has access, particularly in cases where the admin enabled JWT // group propagation and set the list of groups with access permissions. -func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(claims jwtclaims.AuthorizationClaims) error { - account, _, err := am.GetAccountFromToken(claims) +func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error { + account, _, err := am.GetAccountFromToken(ctx, claims) if err != nil { return err } @@ -1986,20 +1990,24 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(claims jwtclaims.Aut return nil } -func (am *DefaultAccountManager) onPeersInvalidated(accountID string) { - log.Debugf("validated peers has been invalidated for account %s", accountID) - updatedAccount, err := am.Store.GetAccount(accountID) +func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) { + log.WithContext(ctx).Debugf("validated peers has been invalidated for account %s", accountID) + updatedAccount, err := am.Store.GetAccount(ctx, accountID) if err != nil { - log.Errorf("failed to get account %s: %v", accountID, err) + log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err) return } - am.updateAccountPeers(updatedAccount) + am.updateAccountPeers(ctx, updatedAccount) } func (am *DefaultAccountManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) { return am.Store.GetPostureCheckByChecksDefinition(accountID, checks) } +func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error) { + return am.Store.GetAccountIDByPeerPubKey(ctx, peerKey) +} + // addAllGroup to account object if it doesn't exist func addAllGroup(account *Account) error { if len(account.Groups) == 0 { @@ -2041,8 +2049,8 @@ func addAllGroup(account *Account) error { } // newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id -func newAccountWithId(accountID, userID, domain string) *Account { - log.Debugf("creating new account") +func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Account { + log.WithContext(ctx).Debugf("creating new account") network := NewNetwork() peers := make(map[string]*nbpeer.Peer) @@ -2054,7 +2062,7 @@ func newAccountWithId(accountID, userID, domain string) *Account { dnsSettings := DNSSettings{ DisabledManagementGroups: make([]string, 0), } - log.Debugf("created new account %s", accountID) + log.WithContext(ctx).Debugf("created new account %s", accountID) acc := &Account{ Id: accountID, @@ -2077,7 +2085,7 @@ func newAccountWithId(accountID, userID, domain string) *Account { } if err := addAllGroup(acc); err != nil { - log.Errorf("error adding all group to account %s: %v", acc.Id, err) + log.WithContext(ctx).Errorf("error adding all group to account %s: %v", acc.Id, err) } return acc } diff --git a/management/server/account_test.go b/management/server/account_test.go index eaadb5633..71b43bd65 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1,6 +1,7 @@ package server import ( + "context" "crypto/sha256" b64 "encoding/base64" "encoding/json" @@ -29,11 +30,11 @@ import ( type MocIntegratedValidator struct { } -func (a MocIntegratedValidator) ValidateExtraSettings(newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { +func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { return nil } -func (a MocIntegratedValidator) ValidatePeer(update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) { +func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) { return update, nil } func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { @@ -44,15 +45,15 @@ func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[s return validatedPeers, nil } -func (MocIntegratedValidator) PreparePeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer { +func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer { return peer } -func (MocIntegratedValidator) IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) { +func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) { return false, false, nil } -func (MocIntegratedValidator) PeerDeleted(_, _ string) error { +func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error { return nil } @@ -60,7 +61,7 @@ func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string) } -func (MocIntegratedValidator) Stop() { +func (MocIntegratedValidator) Stop(_ context.Context) { } func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *Account, userID string) { @@ -85,7 +86,7 @@ func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *Ac setupKey = key.Key } - _, _, _, err := manager.AddPeer(setupKey, userID, peer) + _, _, _, err := manager.AddPeer(context.Background(), setupKey, userID, peer) if err != nil { t.Error("expected to add new peer successfully after creating new account, but failed", err) } @@ -395,7 +396,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { } for _, testCase := range tt { - account := newAccountWithId("account-1", userID, "netbird.io") + account := newAccountWithId(context.Background(), "account-1", userID, "netbird.io") account.UpdateSettings(&testCase.accountSettings) account.Network = network account.Peers = testCase.peers @@ -409,7 +410,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { validatedPeers[p] = struct{}{} } - networkMap := account.GetPeerNetworkMap(testCase.peerID, "netbird.io", validatedPeers) + networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, "netbird.io", validatedPeers) assert.Len(t, networkMap.Peers, len(testCase.expectedPeers)) assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers)) } @@ -419,7 +420,7 @@ func TestNewAccount(t *testing.T) { domain := "netbird.io" userId := "account_creator" accountID := "account_id" - account := newAccountWithId(accountID, userId, domain) + account := newAccountWithId(context.Background(), accountID, userId, domain) verifyNewAccountHasDefaultFields(t, account, userId, domain, []string{userId}) } @@ -430,7 +431,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) { return } - account, err := manager.GetOrCreateAccountByUser(userID, "") + account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") if err != nil { t.Fatal(err) } @@ -439,7 +440,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) { return } - account, err = manager.Store.GetAccountByUser(userID) + account, err = manager.Store.GetAccountByUser(context.Background(), userID) if err != nil { t.Errorf("expected to get existing account after creation, no account was found for a user %s", userID) return @@ -630,11 +631,11 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - initAccount, err := manager.GetAccountByUserOrAccountID(testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain) + initAccount, err := manager.GetAccountByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain) require.NoError(t, err, "create init user failed") if testCase.inputUpdateAttrs { - err = manager.updateAccountDomainAttributes(initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true) + err = manager.updateAccountDomainAttributes(context.Background(), initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true) require.NoError(t, err, "update init user failed") } @@ -642,7 +643,7 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) { testCase.inputClaims.AccountId = initAccount.Id } - account, _, err := manager.GetAccountFromToken(testCase.inputClaims) + account, _, err := manager.GetAccountFromToken(context.Background(), testCase.inputClaims) require.NoError(t, err, "support function failed") verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers) verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy) @@ -661,12 +662,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { userId := "user-id" domain := "test.domain" - initAccount := newAccountWithId("", userId, domain) + initAccount := newAccountWithId(context.Background(), "", userId, domain) manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") accountID := initAccount.Id - acc, err := manager.GetAccountByUserOrAccountID(userId, accountID, domain) + acc, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, accountID, domain) require.NoError(t, err, "create init user failed") // as initAccount was created without account id we have to take the id after account initialization // that happens inside the GetAccountByUserOrAccountID where the id is getting generated @@ -682,18 +683,18 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { } t.Run("JWT groups disabled", func(t *testing.T) { - account, _, err := manager.GetAccountFromToken(claims) + account, _, err := manager.GetAccountFromToken(context.Background(), claims) require.NoError(t, err, "get account by token failed") require.Len(t, account.Groups, 1, "only ALL group should exists") }) t.Run("JWT groups enabled without claim name", func(t *testing.T) { initAccount.Settings.JWTGroupsEnabled = true - err := manager.Store.SaveAccount(initAccount) + err := manager.Store.SaveAccount(context.Background(), initAccount) require.NoError(t, err, "save account failed") - require.Len(t, manager.Store.GetAllAccounts(), 1, "only one account should exist") + require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist") - account, _, err := manager.GetAccountFromToken(claims) + account, _, err := manager.GetAccountFromToken(context.Background(), claims) require.NoError(t, err, "get account by token failed") require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT") }) @@ -701,11 +702,11 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { t.Run("JWT groups enabled", func(t *testing.T) { initAccount.Settings.JWTGroupsEnabled = true initAccount.Settings.JWTGroupsClaimName = "idp-groups" - err := manager.Store.SaveAccount(initAccount) + err := manager.Store.SaveAccount(context.Background(), initAccount) require.NoError(t, err, "save account failed") - require.Len(t, manager.Store.GetAllAccounts(), 1, "only one account should exist") + require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist") - account, _, err := manager.GetAccountFromToken(claims) + account, _, err := manager.GetAccountFromToken(context.Background(), claims) require.NoError(t, err, "get account by token failed") require.Len(t, account.Groups, 3, "groups should be added to the account") @@ -728,7 +729,7 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { func TestAccountManager_GetAccountFromPAT(t *testing.T) { store := newStore(t) - account := newAccountWithId("account_id", "testuser", "") + account := newAccountWithId(context.Background(), "account_id", "testuser", "") token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" hashedToken := sha256.Sum256([]byte(token)) @@ -742,7 +743,7 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) { }, }, } - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -751,7 +752,7 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) { Store: store, } - account, user, pat, err := am.GetAccountFromPAT(token) + account, user, pat, err := am.GetAccountFromPAT(context.Background(), token) if err != nil { t.Fatalf("Error when getting Account from PAT: %s", err) } @@ -763,7 +764,7 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) { func TestDefaultAccountManager_MarkPATUsed(t *testing.T) { store := newStore(t) - account := newAccountWithId("account_id", "testuser", "") + account := newAccountWithId(context.Background(), "account_id", "testuser", "") token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" hashedToken := sha256.Sum256([]byte(token)) @@ -778,7 +779,7 @@ func TestDefaultAccountManager_MarkPATUsed(t *testing.T) { }, }, } - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -787,12 +788,12 @@ func TestDefaultAccountManager_MarkPATUsed(t *testing.T) { Store: store, } - err = am.MarkPATUsed("tokenId") + err = am.MarkPATUsed(context.Background(), "tokenId") if err != nil { t.Fatalf("Error when marking PAT used: %s", err) } - account, err = am.Store.GetAccount("account_id") + account, err = am.Store.GetAccount(context.Background(), "account_id") if err != nil { t.Fatalf("Error when getting account: %s", err) } @@ -807,7 +808,7 @@ func TestAccountManager_PrivateAccount(t *testing.T) { } userId := "test_user" - account, err := manager.GetOrCreateAccountByUser(userId, "") + account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, "") if err != nil { t.Fatal(err) } @@ -815,7 +816,7 @@ func TestAccountManager_PrivateAccount(t *testing.T) { t.Fatalf("expected to create an account for a user %s", userId) } - account, err = manager.Store.GetAccountByUser(userId) + account, err = manager.Store.GetAccountByUser(context.Background(), userId) if err != nil { t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId) } @@ -834,7 +835,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) { userId := "test_user" domain := "hotmail.com" - account, err := manager.GetOrCreateAccountByUser(userId, domain) + account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, domain) if err != nil { t.Fatal(err) } @@ -848,7 +849,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) { domain = "gmail.com" - account, err = manager.GetOrCreateAccountByUser(userId, domain) + account, err = manager.GetOrCreateAccountByUser(context.Background(), userId, domain) if err != nil { t.Fatalf("got the following error while retrieving existing acc: %v", err) } @@ -871,7 +872,7 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) { userId := "test_user" - account, err := manager.GetAccountByUserOrAccountID(userId, "", "") + account, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, "", "") if err != nil { t.Fatal(err) } @@ -880,20 +881,20 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) { return } - _, err = manager.GetAccountByUserOrAccountID("", account.Id, "") + _, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "") if err != nil { t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", account.Id) } - _, err = manager.GetAccountByUserOrAccountID("", "", "") + _, err = manager.GetAccountByUserOrAccountID(context.Background(), "", "", "") if err == nil { t.Errorf("expected an error when user and account IDs are empty") } } func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*Account, error) { - account := newAccountWithId(accountID, userID, domain) - err := am.Store.SaveAccount(account) + account := newAccountWithId(context.Background(), accountID, userID, domain) + err := am.Store.SaveAccount(context.Background(), account) if err != nil { return nil, err } @@ -915,7 +916,7 @@ func TestAccountManager_GetAccount(t *testing.T) { } // AddAccount has been already tested so we can assume it is correct and compare results - getAccount, err := manager.Store.GetAccount(account.Id) + getAccount, err := manager.Store.GetAccount(context.Background(), account.Id) if err != nil { t.Fatal(err) return @@ -952,12 +953,12 @@ func TestAccountManager_DeleteAccount(t *testing.T) { t.Fatal(err) } - err = manager.DeleteAccount(account.Id, userId) + err = manager.DeleteAccount(context.Background(), account.Id, userId) if err != nil { t.Fatal(err) } - getAccount, err := manager.Store.GetAccount(account.Id) + getAccount, err := manager.Store.GetAccount(context.Background(), account.Id) if err == nil { t.Fatal(fmt.Errorf("expected to get an error when trying to get deleted account, got %v", getAccount)) } @@ -978,7 +979,7 @@ func TestAccountManager_AddPeer(t *testing.T) { serial := account.Network.CurrentSerial() // should be 0 - setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) + setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) if err != nil { t.Fatal("error creating setup key") return @@ -997,7 +998,7 @@ func TestAccountManager_AddPeer(t *testing.T) { expectedPeerKey := key.PublicKey().String() expectedSetupKey := setupKey.Key - peer, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ + peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ Key: expectedPeerKey, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, }) @@ -1006,7 +1007,7 @@ func TestAccountManager_AddPeer(t *testing.T) { return } - account, err = manager.Store.GetAccount(account.Id) + account, err = manager.Store.GetAccount(context.Background(), account.Id) if err != nil { t.Fatal(err) return @@ -1045,7 +1046,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { return } - account, err := manager.GetOrCreateAccountByUser(userID, "netbird.cloud") + account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "netbird.cloud") if err != nil { t.Fatal(err) } @@ -1065,7 +1066,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { expectedPeerKey := key.PublicKey().String() expectedUserID := userID - peer, _, _, err := manager.AddPeer("", userID, &nbpeer.Peer{ + peer, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{ Key: expectedPeerKey, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, }) @@ -1074,7 +1075,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { return } - account, err = manager.Store.GetAccount(account.Id) + account, err = manager.Store.GetAccount(context.Background(), account.Id) if err != nil { t.Fatal(err) return @@ -1121,7 +1122,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) { t.Fatal(err) } - setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) + setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) if err != nil { t.Fatal("error creating setup key") return @@ -1140,7 +1141,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) { } expectedPeerKey := key.PublicKey().String() - peer, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ + peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ Key: expectedPeerKey, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, }) @@ -1156,14 +1157,14 @@ func TestAccountManager_NetworkUpdates(t *testing.T) { peer2 := getPeer() peer3 := getPeer() - account, err = manager.Store.GetAccount(account.Id) + account, err = manager.Store.GetAccount(context.Background(), account.Id) if err != nil { t.Fatal(err) return } - updMsg := manager.peersUpdateManager.CreateChannel(peer1.ID) - defer manager.peersUpdateManager.CloseChannel(peer1.ID) + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) group := group.Group{ ID: "group-id", @@ -1197,7 +1198,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) { } }() - if err := manager.SaveGroup(account.Id, userID, &group); err != nil { + if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil { t.Errorf("save group: %v", err) return } @@ -1217,7 +1218,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) { } }() - if err := manager.DeletePolicy(account.Id, account.Policies[0].ID, userID); err != nil { + if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil { t.Errorf("delete default rule: %v", err) return } @@ -1237,7 +1238,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) { } }() - if err := manager.SavePolicy(account.Id, userID, &policy); err != nil { + if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy); err != nil { t.Errorf("delete default rule: %v", err) return } @@ -1256,7 +1257,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) { } }() - if err := manager.DeletePeer(account.Id, peer3.ID, userID); err != nil { + if err := manager.DeletePeer(context.Background(), account.Id, peer3.ID, userID); err != nil { t.Errorf("delete peer: %v", err) return } @@ -1277,9 +1278,9 @@ func TestAccountManager_NetworkUpdates(t *testing.T) { }() // clean policy is pre requirement for delete group - _ = manager.DeletePolicy(account.Id, policy.ID, userID) + _ = manager.DeletePolicy(context.Background(), account.Id, policy.ID, userID) - if err := manager.DeleteGroup(account.Id, "", group.ID); err != nil { + if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil { t.Errorf("delete group: %v", err) return } @@ -1301,7 +1302,7 @@ func TestAccountManager_DeletePeer(t *testing.T) { t.Fatal(err) } - setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) + setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) if err != nil { t.Fatal("error creating setup key") return @@ -1315,7 +1316,7 @@ func TestAccountManager_DeletePeer(t *testing.T) { peerKey := key.PublicKey().String() - peer, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ + peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ Key: peerKey, Meta: nbpeer.PeerSystemMeta{Hostname: peerKey}, }) @@ -1324,12 +1325,12 @@ func TestAccountManager_DeletePeer(t *testing.T) { return } - err = manager.DeletePeer(account.Id, peerKey, userID) + err = manager.DeletePeer(context.Background(), account.Id, peerKey, userID) if err != nil { return } - account, err = manager.Store.GetAccount(account.Id) + account, err = manager.Store.GetAccount(context.Background(), account.Id) if err != nil { t.Fatal(err) return @@ -1357,7 +1358,7 @@ func getEvent(t *testing.T, accountID string, manager AccountManager, eventType case <-time.After(time.Second): t.Fatal("no PeerAddedWithSetupKey event was generated") default: - events, err := manager.GetEvents(accountID, userID) + events, err := manager.GetEvents(context.Background(), accountID, userID) if err != nil { t.Fatal(err) } @@ -1389,7 +1390,7 @@ func TestGetUsersFromAccount(t *testing.T) { account.Users[user.Id] = user } - userInfos, err := manager.GetUsersFromAccount(accountId, "1") + userInfos, err := manager.GetUsersFromAccount(context.Background(), accountId, "1") if err != nil { t.Fatal(err) } @@ -1500,7 +1501,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) { }, } - routes := account.getRoutesToSync("peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}}) + routes := account.getRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}}) assert.Len(t, routes, 2) routeIDs := make(map[route.ID]struct{}, 2) @@ -1510,7 +1511,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) { assert.Contains(t, routeIDs, route.ID("route-2")) assert.Contains(t, routeIDs, route.ID("route-3")) - emptyRoutes := account.getRoutesToSync("peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}}) + emptyRoutes := account.getRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}}) assert.Len(t, emptyRoutes, 0) } @@ -1645,7 +1646,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - account, err := manager.GetAccountByUserOrAccountID(userID, "", "") + account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to create an account") assert.NotNil(t, account.Settings) @@ -1657,23 +1658,23 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - _, err = manager.GetAccountByUserOrAccountID(userID, "", "") + _, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to create an account") key, err := wgtypes.GenerateKey() require.NoError(t, err, "unable to generate WireGuard key") - peer, _, _, err := manager.AddPeer("", userID, &nbpeer.Peer{ + peer, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{ Key: key.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"}, LoginExpirationEnabled: true, }) require.NoError(t, err, "unable to add peer") - account, err := manager.GetAccountByUserOrAccountID(userID, "", "") + account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to get the account") - err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account) + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) require.NoError(t, err, "unable to mark peer connected") - account, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{ + account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: true, }) @@ -1682,10 +1683,10 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { wg := &sync.WaitGroup{} wg.Add(2) manager.peerLoginExpiry = &MockScheduler{ - CancelFunc: func(IDs []string) { + CancelFunc: func(ctx context.Context, IDs []string) { wg.Done() }, - ScheduleFunc: func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) { + ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) { wg.Done() }, } @@ -1693,11 +1694,11 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { // disable expiration first update := peer.Copy() update.LoginExpirationEnabled = false - _, err = manager.UpdatePeer(account.Id, userID, update) + _, err = manager.UpdatePeer(context.Background(), account.Id, userID, update) require.NoError(t, err, "unable to update peer") // enabling expiration should trigger the routine update.LoginExpirationEnabled = true - _, err = manager.UpdatePeer(account.Id, userID, update) + _, err = manager.UpdatePeer(context.Background(), account.Id, userID, update) require.NoError(t, err, "unable to update peer") failed := waitTimeout(wg, time.Second) @@ -1710,18 +1711,18 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - account, err := manager.GetAccountByUserOrAccountID(userID, "", "") + account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to create an account") key, err := wgtypes.GenerateKey() require.NoError(t, err, "unable to generate WireGuard key") - _, _, _, err = manager.AddPeer("", userID, &nbpeer.Peer{ + _, _, _, err = manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{ Key: key.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"}, LoginExpirationEnabled: true, }) require.NoError(t, err, "unable to add peer") - _, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{ + _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: true, }) @@ -1730,18 +1731,18 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. wg := &sync.WaitGroup{} wg.Add(2) manager.peerLoginExpiry = &MockScheduler{ - CancelFunc: func(IDs []string) { + CancelFunc: func(ctx context.Context, IDs []string) { wg.Done() }, - ScheduleFunc: func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) { + ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) { wg.Done() }, } - account, err = manager.GetAccountByUserOrAccountID(userID, "", "") + account, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to get the account") // when we mark peer as connected, the peer login expiration routine should trigger - err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account) + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) require.NoError(t, err, "unable to mark peer connected") failed := waitTimeout(wg, time.Second) @@ -1754,35 +1755,35 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - _, err = manager.GetAccountByUserOrAccountID(userID, "", "") + _, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to create an account") key, err := wgtypes.GenerateKey() require.NoError(t, err, "unable to generate WireGuard key") - _, _, _, err = manager.AddPeer("", userID, &nbpeer.Peer{ + _, _, _, err = manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{ Key: key.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"}, LoginExpirationEnabled: true, }) require.NoError(t, err, "unable to add peer") - account, err := manager.GetAccountByUserOrAccountID(userID, "", "") + account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to get the account") - err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account) + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) require.NoError(t, err, "unable to mark peer connected") wg := &sync.WaitGroup{} wg.Add(2) manager.peerLoginExpiry = &MockScheduler{ - CancelFunc: func(IDs []string) { + CancelFunc: func(ctx context.Context, IDs []string) { wg.Done() }, - ScheduleFunc: func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) { + ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) { wg.Done() }, } // enabling PeerLoginExpirationEnabled should trigger the expiration job - account, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{ + account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: true, }) @@ -1795,7 +1796,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test wg.Add(1) // disabling PeerLoginExpirationEnabled should trigger cancel - _, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{ + _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: false, }) @@ -1810,10 +1811,10 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - account, err := manager.GetAccountByUserOrAccountID(userID, "", "") + account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to create an account") - updated, err := manager.UpdateAccountSettings(account.Id, userID, &Settings{ + updated, err := manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: false, }) @@ -1821,19 +1822,19 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { assert.False(t, updated.Settings.PeerLoginExpirationEnabled) assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour) - account, err = manager.GetAccountByUserOrAccountID("", account.Id, "") + account, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "") require.NoError(t, err, "unable to get account by ID") assert.False(t, account.Settings.PeerLoginExpirationEnabled) assert.Equal(t, account.Settings.PeerLoginExpiration, time.Hour) - _, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{ + _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ PeerLoginExpiration: time.Second, PeerLoginExpirationEnabled: false, }) require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour") - _, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{ + _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ PeerLoginExpiration: time.Hour * 24 * 181, PeerLoginExpirationEnabled: false, }) @@ -2294,7 +2295,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) { } eventStore := &activity.InMemoryEventStore{} - manager, err := BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}) + manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}) if err != nil { return nil, err } @@ -2305,7 +2306,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) { func createStore(t *testing.T) (Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromJson(dataDir) + store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir) if err != nil { return nil, err } diff --git a/management/server/activity/sqlite/sqlite.go b/management/server/activity/sqlite/sqlite.go index b54db5276..fadf1eb07 100644 --- a/management/server/activity/sqlite/sqlite.go +++ b/management/server/activity/sqlite/sqlite.go @@ -1,6 +1,7 @@ package sqlite import ( + "context" "database/sql" "encoding/json" "fmt" @@ -86,7 +87,7 @@ type Store struct { } // NewSQLiteStore creates a new Store with an event table if not exists. -func NewSQLiteStore(dataDir string, encryptionKey string) (*Store, error) { +func NewSQLiteStore(ctx context.Context, dataDir string, encryptionKey string) (*Store, error) { dbFile := filepath.Join(dataDir, eventSinkDB) db, err := sql.Open("sqlite3", dbFile) if err != nil { @@ -111,7 +112,7 @@ func NewSQLiteStore(dataDir string, encryptionKey string) (*Store, error) { return nil, err } - err = updateDeletedUsersTable(db) + err = updateDeletedUsersTable(ctx, db) if err != nil { _ = db.Close() return nil, err @@ -153,7 +154,7 @@ func NewSQLiteStore(dataDir string, encryptionKey string) (*Store, error) { return s, nil } -func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) { +func (store *Store) processResult(ctx context.Context, result *sql.Rows) ([]*activity.Event, error) { events := make([]*activity.Event, 0) var cryptErr error for result.Next() { @@ -235,14 +236,14 @@ func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) { } if cryptErr != nil { - log.Warnf("%s", cryptErr) + log.WithContext(ctx).Warnf("%s", cryptErr) } return events, nil } // Get returns "limit" number of events from index ordered descending or ascending by a timestamp -func (store *Store) Get(accountID string, offset, limit int, descending bool) ([]*activity.Event, error) { +func (store *Store) Get(ctx context.Context, accountID string, offset, limit int, descending bool) ([]*activity.Event, error) { stmt := store.selectDescStatement if !descending { stmt = store.selectAscStatement @@ -254,11 +255,11 @@ func (store *Store) Get(accountID string, offset, limit int, descending bool) ([ } defer result.Close() //nolint - return store.processResult(result) + return store.processResult(ctx, result) } // Save an event in the SQLite events table end encrypt the "email" element in meta map -func (store *Store) Save(event *activity.Event) (*activity.Event, error) { +func (store *Store) Save(_ context.Context, event *activity.Event) (*activity.Event, error) { var jsonMeta string meta, err := store.saveDeletedUserEmailAndNameInEncrypted(event) if err != nil { @@ -317,15 +318,15 @@ func (store *Store) saveDeletedUserEmailAndNameInEncrypted(event *activity.Event } // Close the Store -func (store *Store) Close() error { +func (store *Store) Close(_ context.Context) error { if store.db != nil { return store.db.Close() } return nil } -func updateDeletedUsersTable(db *sql.DB) error { - log.Debugf("check deleted_users table version") +func updateDeletedUsersTable(ctx context.Context, db *sql.DB) error { + log.WithContext(ctx).Debugf("check deleted_users table version") rows, err := db.Query(`PRAGMA table_info(deleted_users);`) if err != nil { return err @@ -360,7 +361,7 @@ func updateDeletedUsersTable(db *sql.DB) error { return nil } - log.Debugf("update delted_users table") + log.WithContext(ctx).Debugf("update delted_users table") _, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN name TEXT;`) return err } diff --git a/management/server/activity/sqlite/sqlite_test.go b/management/server/activity/sqlite/sqlite_test.go index f6a6f9467..b10f9b58a 100644 --- a/management/server/activity/sqlite/sqlite_test.go +++ b/management/server/activity/sqlite/sqlite_test.go @@ -1,6 +1,7 @@ package sqlite import ( + "context" "fmt" "testing" "time" @@ -13,17 +14,17 @@ import ( func TestNewSQLiteStore(t *testing.T) { dataDir := t.TempDir() key, _ := GenerateKey() - store, err := NewSQLiteStore(dataDir, key) + store, err := NewSQLiteStore(context.Background(), dataDir, key) if err != nil { t.Fatal(err) return } - defer store.Close() //nolint + defer store.Close(context.Background()) //nolint accountID := "account_1" for i := 0; i < 10; i++ { - _, err = store.Save(&activity.Event{ + _, err = store.Save(context.Background(), &activity.Event{ Timestamp: time.Now().UTC(), Activity: activity.PeerAddedByUser, InitiatorID: "user_" + fmt.Sprint(i), @@ -36,7 +37,7 @@ func TestNewSQLiteStore(t *testing.T) { } } - result, err := store.Get(accountID, 0, 10, false) + result, err := store.Get(context.Background(), accountID, 0, 10, false) if err != nil { t.Fatal(err) return @@ -45,7 +46,7 @@ func TestNewSQLiteStore(t *testing.T) { assert.Len(t, result, 10) assert.True(t, result[0].Timestamp.Before(result[len(result)-1].Timestamp)) - result, err = store.Get(accountID, 0, 5, true) + result, err = store.Get(context.Background(), accountID, 0, 5, true) if err != nil { t.Fatal(err) return diff --git a/management/server/activity/store.go b/management/server/activity/store.go index 77439e2e1..ef08e2b33 100644 --- a/management/server/activity/store.go +++ b/management/server/activity/store.go @@ -1,15 +1,18 @@ package activity -import "sync" +import ( + "context" + "sync" +) // Store provides an interface to store or stream events. type Store interface { // Save an event in the store - Save(event *Event) (*Event, error) + Save(ctx context.Context, event *Event) (*Event, error) // Get returns "limit" number of events from the "offset" index ordered descending or ascending by a timestamp - Get(accountID string, offset, limit int, descending bool) ([]*Event, error) + Get(ctx context.Context, accountID string, offset, limit int, descending bool) ([]*Event, error) // Close the sink flushing events if necessary - Close() error + Close(ctx context.Context) error } // InMemoryEventStore implements the Store interface storing data in-memory @@ -20,7 +23,7 @@ type InMemoryEventStore struct { } // Save sets the Event.ID to 1 -func (store *InMemoryEventStore) Save(event *Event) (*Event, error) { +func (store *InMemoryEventStore) Save(_ context.Context, event *Event) (*Event, error) { store.mu.Lock() defer store.mu.Unlock() if store.events == nil { @@ -33,7 +36,7 @@ func (store *InMemoryEventStore) Save(event *Event) (*Event, error) { } // Get returns a list of ALL events that belong to the given accountID without taking offset, limit and order into consideration -func (store *InMemoryEventStore) Get(accountID string, offset, limit int, descending bool) ([]*Event, error) { +func (store *InMemoryEventStore) Get(_ context.Context, accountID string, offset, limit int, descending bool) ([]*Event, error) { store.mu.Lock() defer store.mu.Unlock() events := make([]*Event, 0) @@ -46,7 +49,7 @@ func (store *InMemoryEventStore) Get(accountID string, offset, limit int, descen } // Close cleans up the event list -func (store *InMemoryEventStore) Close() error { +func (store *InMemoryEventStore) Close(_ context.Context) error { store.mu.Lock() defer store.mu.Unlock() store.events = make([]*Event, 0) diff --git a/management/server/context/keys.go b/management/server/context/keys.go new file mode 100644 index 000000000..c5b5da044 --- /dev/null +++ b/management/server/context/keys.go @@ -0,0 +1,8 @@ +package context + +const ( + RequestIDKey = "requestID" + AccountIDKey = "accountID" + UserIDKey = "userID" + PeerIDKey = "peerID" +) diff --git a/management/server/dns.go b/management/server/dns.go index 5e2febf55..8a889df3f 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -1,6 +1,7 @@ package server import ( + "context" "fmt" "strconv" @@ -34,11 +35,11 @@ func (d DNSSettings) Copy() DNSSettings { } // GetDNSSettings validates a user role and returns the DNS settings for the provided account ID -func (am *DefaultAccountManager) GetDNSSettings(accountID string, userID string) (*DNSSettings, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -56,11 +57,11 @@ func (am *DefaultAccountManager) GetDNSSettings(accountID string, userID string) } // SaveDNSSettings validates a user role and updates the account's DNS settings -func (am *DefaultAccountManager) SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } @@ -89,7 +90,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(accountID string, userID string account.DNSSettings = dnsSettingsToSave.Copy() account.Network.IncSerial() - if err = am.Store.SaveAccount(account); err != nil { + if err = am.Store.SaveAccount(ctx, account); err != nil { return err } @@ -97,17 +98,17 @@ func (am *DefaultAccountManager) SaveDNSSettings(accountID string, userID string for _, id := range addedGroups { group := account.GetGroup(id) meta := map[string]any{"group": group.Name, "group_id": group.ID} - am.StoreEvent(userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta) + am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta) } removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups) for _, id := range removedGroups { group := account.GetGroup(id) meta := map[string]any{"group": group.Name, "group_id": group.ID} - am.StoreEvent(userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta) + am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta) } - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) return nil } @@ -149,9 +150,9 @@ func toProtocolDNSConfig(update nbdns.Config) *proto.DNSConfig { return protoUpdate } -func getPeersCustomZone(account *Account, dnsDomain string) nbdns.CustomZone { +func getPeersCustomZone(ctx context.Context, account *Account, dnsDomain string) nbdns.CustomZone { if dnsDomain == "" { - log.Errorf("no dns domain is set, returning empty zone") + log.WithContext(ctx).Errorf("no dns domain is set, returning empty zone") return nbdns.CustomZone{} } @@ -161,7 +162,7 @@ func getPeersCustomZone(account *Account, dnsDomain string) nbdns.CustomZone { for _, peer := range account.Peers { if peer.DNSLabel == "" { - log.Errorf("found a peer with empty dns label. It was probably caused by a invalid character in its name. Peer Name: %s", peer.Name) + log.WithContext(ctx).Errorf("found a peer with empty dns label. It was probably caused by a invalid character in its name. Peer Name: %s", peer.Name) continue } @@ -210,14 +211,14 @@ func peerIsNameserver(peer *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool { return false } -func addPeerLabelsToAccount(account *Account, peerLabels lookupMap) { +func addPeerLabelsToAccount(ctx context.Context, account *Account, peerLabels lookupMap) { for _, peer := range account.Peers { label, err := getPeerHostLabel(peer.Name, peerLabels) if err != nil { - log.Errorf("got an error while generating a peer host label. Peer name %s, error: %v. Trying with the peer's meta hostname", peer.Name, err) + log.WithContext(ctx).Errorf("got an error while generating a peer host label. Peer name %s, error: %v. Trying with the peer's meta hostname", peer.Name, err) label, err = getPeerHostLabel(peer.Meta.Hostname, peerLabels) if err != nil { - log.Errorf("got another error while generating a peer host label with hostname. Peer hostname %s, error: %v. Skipping", peer.Meta.Hostname, err) + log.WithContext(ctx).Errorf("got another error while generating a peer host label with hostname. Peer hostname %s, error: %v. Skipping", peer.Meta.Hostname, err) continue } } diff --git a/management/server/dns_test.go b/management/server/dns_test.go index a53789526..c6758036f 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -1,6 +1,7 @@ package server import ( + "context" "net/netip" "testing" @@ -35,7 +36,7 @@ func TestGetDNSSettings(t *testing.T) { t.Fatal("failed to init testing account") } - dnsSettings, err := am.GetDNSSettings(account.Id, dnsAdminUserID) + dnsSettings, err := am.GetDNSSettings(context.Background(), account.Id, dnsAdminUserID) if err != nil { t.Fatalf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err) } @@ -48,12 +49,12 @@ func TestGetDNSSettings(t *testing.T) { DisabledManagementGroups: []string{group1ID}, } - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(context.Background(), account) if err != nil { t.Error("failed to save testing account with new DNS settings") } - dnsSettings, err = am.GetDNSSettings(account.Id, dnsAdminUserID) + dnsSettings, err = am.GetDNSSettings(context.Background(), account.Id, dnsAdminUserID) if err != nil { t.Errorf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err) } @@ -62,7 +63,7 @@ func TestGetDNSSettings(t *testing.T) { t.Errorf("DNS settings should have one disabled mgmt group, groups: %s", dnsSettings.DisabledManagementGroups) } - _, err = am.GetDNSSettings(account.Id, dnsRegularUserID) + _, err = am.GetDNSSettings(context.Background(), account.Id, dnsRegularUserID) if err == nil { t.Errorf("An error should be returned when getting the DNS settings with a regular user") } @@ -122,7 +123,7 @@ func TestSaveDNSSettings(t *testing.T) { t.Error("failed to init testing account") } - err = am.SaveDNSSettings(account.Id, testCase.userID, testCase.inputSettings) + err = am.SaveDNSSettings(context.Background(), account.Id, testCase.userID, testCase.inputSettings) if err != nil { if testCase.shouldFail { return @@ -130,7 +131,7 @@ func TestSaveDNSSettings(t *testing.T) { t.Error(err) } - updatedAccount, err := am.Store.GetAccount(account.Id) + updatedAccount, err := am.Store.GetAccount(context.Background(), account.Id) if err != nil { t.Errorf("should be able to retrieve updated account, got err: %s", err) } @@ -164,7 +165,7 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) { t.Error("failed to init testing account") } - newAccountDNSConfig, err := am.GetNetworkMap(peer1.ID) + newAccountDNSConfig, err := am.GetNetworkMap(context.Background(), peer1.ID) require.NoError(t, err) require.Len(t, newAccountDNSConfig.DNSConfig.CustomZones, 1, "default DNS config should have one custom zone for peers") require.True(t, newAccountDNSConfig.DNSConfig.ServiceEnable, "default DNS config should have local DNS service enabled") @@ -173,14 +174,14 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) { dnsSettings := account.DNSSettings.Copy() dnsSettings.DisabledManagementGroups = append(dnsSettings.DisabledManagementGroups, dnsGroup1ID) account.DNSSettings = dnsSettings - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(context.Background(), account) require.NoError(t, err) - updatedAccountDNSConfig, err := am.GetNetworkMap(peer1.ID) + updatedAccountDNSConfig, err := am.GetNetworkMap(context.Background(), peer1.ID) require.NoError(t, err) require.Len(t, updatedAccountDNSConfig.DNSConfig.CustomZones, 0, "updated DNS config should have no custom zone when peer belongs to a disabled group") require.False(t, updatedAccountDNSConfig.DNSConfig.ServiceEnable, "updated DNS config should have local DNS service disabled when peer belongs to a disabled group") - peer2AccountDNSConfig, err := am.GetNetworkMap(peer2.ID) + peer2AccountDNSConfig, err := am.GetNetworkMap(context.Background(), peer2.ID) require.NoError(t, err) require.Len(t, peer2AccountDNSConfig.DNSConfig.CustomZones, 1, "DNS config should have one custom zone for peers not in the disabled group") require.True(t, peer2AccountDNSConfig.DNSConfig.ServiceEnable, "DNS config should have DNS service enabled for peers not in the disabled group") @@ -194,13 +195,13 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { return nil, err } eventStore := &activity.InMemoryEventStore{} - return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}) + return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}) } func createDNSStore(t *testing.T) (Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromJson(dataDir) + store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir) if err != nil { return nil, err } @@ -244,28 +245,28 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro domain := "example.com" - account := newAccountWithId(dnsAccountID, dnsAdminUserID, domain) + account := newAccountWithId(context.Background(), dnsAccountID, dnsAdminUserID, domain) account.Users[dnsRegularUserID] = &User{ Id: dnsRegularUserID, Role: UserRoleUser, } - err := am.Store.SaveAccount(account) + err := am.Store.SaveAccount(context.Background(), account) if err != nil { return nil, err } - savedPeer1, _, _, err := am.AddPeer("", dnsAdminUserID, peer1) + savedPeer1, _, _, err := am.AddPeer(context.Background(), "", dnsAdminUserID, peer1) if err != nil { return nil, err } - _, _, _, err = am.AddPeer("", dnsAdminUserID, peer2) + _, _, _, err = am.AddPeer(context.Background(), "", dnsAdminUserID, peer2) if err != nil { return nil, err } - account, err = am.Store.GetAccount(account.Id) + account, err = am.Store.GetAccount(context.Background(), account.Id) if err != nil { return nil, err } @@ -312,10 +313,10 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro Groups: []string{allGroup.ID}, } - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(context.Background(), account) if err != nil { return nil, err } - return am.Store.GetAccount(account.Id) + return am.Store.GetAccount(context.Background(), account.Id) } diff --git a/management/server/ephemeral.go b/management/server/ephemeral.go index 4fffa024d..590b1d708 100644 --- a/management/server/ephemeral.go +++ b/management/server/ephemeral.go @@ -1,6 +1,7 @@ package server import ( + "context" "sync" "time" @@ -51,13 +52,15 @@ func NewEphemeralManager(store Store, accountManager AccountManager) *EphemeralM // LoadInitialPeers load from the database the ephemeral type of peers and schedule a cleanup procedure to the head // of the linked list (to the most deprecated peer). At the end of cleanup it schedules the next cleanup to the new // head. -func (e *EphemeralManager) LoadInitialPeers() { +func (e *EphemeralManager) LoadInitialPeers(ctx context.Context) { e.peersLock.Lock() defer e.peersLock.Unlock() - e.loadEphemeralPeers() + e.loadEphemeralPeers(ctx) if e.headPeer != nil { - e.timer = time.AfterFunc(ephemeralLifeTime, e.cleanup) + e.timer = time.AfterFunc(ephemeralLifeTime, func() { + e.cleanup(ctx) + }) } } @@ -73,12 +76,12 @@ func (e *EphemeralManager) Stop() { // OnPeerConnected remove the peer from the linked list of ephemeral peers. Because it has been called when the peer // is active the manager will not delete it while it is active. -func (e *EphemeralManager) OnPeerConnected(peer *nbpeer.Peer) { +func (e *EphemeralManager) OnPeerConnected(ctx context.Context, peer *nbpeer.Peer) { if !peer.Ephemeral { return } - log.Tracef("remove peer from ephemeral list: %s", peer.ID) + log.WithContext(ctx).Tracef("remove peer from ephemeral list: %s", peer.ID) e.peersLock.Lock() defer e.peersLock.Unlock() @@ -94,16 +97,16 @@ func (e *EphemeralManager) OnPeerConnected(peer *nbpeer.Peer) { // OnPeerDisconnected add the peer to the linked list of ephemeral peers. Because of the peer // is inactive it will be deleted after the ephemeralLifeTime period. -func (e *EphemeralManager) OnPeerDisconnected(peer *nbpeer.Peer) { +func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.Peer) { if !peer.Ephemeral { return } - log.Tracef("add peer to ephemeral list: %s", peer.ID) + log.WithContext(ctx).Tracef("add peer to ephemeral list: %s", peer.ID) - a, err := e.store.GetAccountByPeerID(peer.ID) + a, err := e.store.GetAccountByPeerID(context.Background(), peer.ID) if err != nil { - log.Errorf("failed to add peer to ephemeral list: %s", err) + log.WithContext(ctx).Errorf("failed to add peer to ephemeral list: %s", err) return } @@ -116,12 +119,14 @@ func (e *EphemeralManager) OnPeerDisconnected(peer *nbpeer.Peer) { e.addPeer(peer.ID, a, newDeadLine()) if e.timer == nil { - e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), e.cleanup) + e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), func() { + e.cleanup(ctx) + }) } } -func (e *EphemeralManager) loadEphemeralPeers() { - accounts := e.store.GetAllAccounts() +func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) { + accounts := e.store.GetAllAccounts(context.Background()) t := newDeadLine() count := 0 for _, a := range accounts { @@ -132,10 +137,10 @@ func (e *EphemeralManager) loadEphemeralPeers() { } } } - log.Debugf("loaded ephemeral peer(s): %d", count) + log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", count) } -func (e *EphemeralManager) cleanup() { +func (e *EphemeralManager) cleanup(ctx context.Context) { log.Tracef("on ephemeral cleanup") deletePeers := make(map[string]*ephemeralPeer) @@ -154,7 +159,9 @@ func (e *EphemeralManager) cleanup() { } if e.headPeer != nil { - e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), e.cleanup) + e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), func() { + e.cleanup(ctx) + }) } else { e.timer = nil } @@ -162,10 +169,10 @@ func (e *EphemeralManager) cleanup() { e.peersLock.Unlock() for id, p := range deletePeers { - log.Debugf("delete ephemeral peer: %s", id) - err := e.accountManager.DeletePeer(p.account.Id, id, activity.SystemInitiator) + log.WithContext(ctx).Debugf("delete ephemeral peer: %s", id) + err := e.accountManager.DeletePeer(ctx, p.account.Id, id, activity.SystemInitiator) if err != nil { - log.Errorf("failed to delete ephemeral peer: %s", err) + log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err) } } } diff --git a/management/server/ephemeral_test.go b/management/server/ephemeral_test.go index 3e36335e3..36c88f1d1 100644 --- a/management/server/ephemeral_test.go +++ b/management/server/ephemeral_test.go @@ -1,6 +1,7 @@ package server import ( + "context" "fmt" "testing" "time" @@ -13,11 +14,11 @@ type MockStore struct { account *Account } -func (s *MockStore) GetAllAccounts() []*Account { +func (s *MockStore) GetAllAccounts(_ context.Context) []*Account { return []*Account{s.account} } -func (s *MockStore) GetAccountByPeerID(peerId string) (*Account, error) { +func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*Account, error) { _, ok := s.account.Peers[peerId] if ok { return s.account, nil @@ -31,7 +32,7 @@ type MocAccountManager struct { store *MockStore } -func (a MocAccountManager) DeletePeer(accountID, peerID, userID string) error { +func (a MocAccountManager) DeletePeer(_ context.Context, accountID, peerID, userID string) error { delete(a.store.account.Peers, peerID) return nil //nolint:nil } @@ -52,9 +53,9 @@ func TestNewManager(t *testing.T) { seedPeers(store, numberOfPeers, numberOfEphemeralPeers) mgr := NewEphemeralManager(store, am) - mgr.loadEphemeralPeers() + mgr.loadEphemeralPeers(context.Background()) startTime = startTime.Add(ephemeralLifeTime + 1) - mgr.cleanup() + mgr.cleanup(context.Background()) if len(store.account.Peers) != numberOfPeers { t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", numberOfPeers, len(store.account.Peers)) @@ -77,11 +78,11 @@ func TestNewManagerPeerConnected(t *testing.T) { seedPeers(store, numberOfPeers, numberOfEphemeralPeers) mgr := NewEphemeralManager(store, am) - mgr.loadEphemeralPeers() - mgr.OnPeerConnected(store.account.Peers["ephemeral_peer_0"]) + mgr.loadEphemeralPeers(context.Background()) + mgr.OnPeerConnected(context.Background(), store.account.Peers["ephemeral_peer_0"]) startTime = startTime.Add(ephemeralLifeTime + 1) - mgr.cleanup() + mgr.cleanup(context.Background()) expected := numberOfPeers + 1 if len(store.account.Peers) != expected { @@ -105,15 +106,15 @@ func TestNewManagerPeerDisconnected(t *testing.T) { seedPeers(store, numberOfPeers, numberOfEphemeralPeers) mgr := NewEphemeralManager(store, am) - mgr.loadEphemeralPeers() + mgr.loadEphemeralPeers(context.Background()) for _, v := range store.account.Peers { - mgr.OnPeerConnected(v) + mgr.OnPeerConnected(context.Background(), v) } - mgr.OnPeerDisconnected(store.account.Peers["ephemeral_peer_0"]) + mgr.OnPeerDisconnected(context.Background(), store.account.Peers["ephemeral_peer_0"]) startTime = startTime.Add(ephemeralLifeTime + 1) - mgr.cleanup() + mgr.cleanup(context.Background()) expected := numberOfPeers + numberOfEphemeralPeers - 1 if len(store.account.Peers) != expected { @@ -122,7 +123,7 @@ func TestNewManagerPeerDisconnected(t *testing.T) { } func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) { - store.account = newAccountWithId("my account", "", "") + store.account = newAccountWithId(context.Background(), "my account", "", "") for i := 0; i < numberOfPeers; i++ { peerId := fmt.Sprintf("peer_%d", i) diff --git a/management/server/event.go b/management/server/event.go index 303f88a79..616cea287 100644 --- a/management/server/event.go +++ b/management/server/event.go @@ -1,6 +1,7 @@ package server import ( + "context" "fmt" "time" @@ -11,11 +12,11 @@ import ( ) // GetEvents returns a list of activity events of an account -func (am *DefaultAccountManager) GetEvents(accountID, userID string) ([]*activity.Event, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -29,7 +30,7 @@ func (am *DefaultAccountManager) GetEvents(accountID, userID string) ([]*activit return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view events") } - events, err := am.eventStore.Get(accountID, 0, 10000, true) + events, err := am.eventStore.Get(ctx, accountID, 0, 10000, true) if err != nil { return nil, err } @@ -54,10 +55,10 @@ func (am *DefaultAccountManager) GetEvents(accountID, userID string) ([]*activit return filtered, nil } -func (am *DefaultAccountManager) StoreEvent(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { +func (am *DefaultAccountManager) StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { go func() { - _, err := am.eventStore.Save(&activity.Event{ + _, err := am.eventStore.Save(ctx, &activity.Event{ Timestamp: time.Now().UTC(), Activity: activityID, InitiatorID: initiatorID, @@ -67,7 +68,7 @@ func (am *DefaultAccountManager) StoreEvent(initiatorID, targetID, accountID str }) if err != nil { // todo add metric - log.Errorf("received an error while storing an activity event, error: %s", err) + log.WithContext(ctx).Errorf("received an error while storing an activity event, error: %s", err) } }() diff --git a/management/server/event_test.go b/management/server/event_test.go index 401c80759..8c56fd3f6 100644 --- a/management/server/event_test.go +++ b/management/server/event_test.go @@ -1,6 +1,7 @@ package server import ( + "context" "testing" "time" @@ -13,7 +14,7 @@ func generateAndStoreEvents(t *testing.T, manager *DefaultAccountManager, typ ac accountID string, count int) { t.Helper() for i := 0; i < count; i++ { - _, err := manager.eventStore.Save(&activity.Event{ + _, err := manager.eventStore.Save(context.Background(), &activity.Event{ Timestamp: time.Now().UTC(), Activity: typ, InitiatorID: initiatorID, @@ -35,32 +36,32 @@ func TestDefaultAccountManager_GetEvents(t *testing.T) { accountID := "accountID" t.Run("get empty events list", func(t *testing.T) { - events, err := manager.GetEvents(accountID, userID) + events, err := manager.GetEvents(context.Background(), accountID, userID) if err != nil { return } assert.Len(t, events, 0) - _ = manager.eventStore.Close() //nolint + _ = manager.eventStore.Close(context.Background()) //nolint }) t.Run("get events", func(t *testing.T) { generateAndStoreEvents(t, manager, activity.PeerAddedByUser, userID, "peer", accountID, 10) - events, err := manager.GetEvents(accountID, userID) + events, err := manager.GetEvents(context.Background(), accountID, userID) if err != nil { return } assert.Len(t, events, 10) - _ = manager.eventStore.Close() //nolint + _ = manager.eventStore.Close(context.Background()) //nolint }) t.Run("get events without duplicates", func(t *testing.T) { generateAndStoreEvents(t, manager, activity.UserJoined, userID, "", accountID, 10) - events, err := manager.GetEvents(accountID, userID) + events, err := manager.GetEvents(context.Background(), accountID, userID) if err != nil { return } assert.Len(t, events, 1) - _ = manager.eventStore.Close() //nolint + _ = manager.eventStore.Close(context.Background()) //nolint }) } diff --git a/management/server/file_store.go b/management/server/file_store.go index 60497824c..3fd543797 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -1,6 +1,7 @@ package server import ( + "context" "os" "path/filepath" "strings" @@ -48,8 +49,8 @@ type FileStore struct { type StoredAccount struct{} // NewFileStore restores a store from the file located in the datadir -func NewFileStore(dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) { - fs, err := restore(filepath.Join(dataDir, storeFileName)) +func NewFileStore(ctx context.Context, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) { + fs, err := restore(ctx, filepath.Join(dataDir, storeFileName)) if err != nil { return nil, err } @@ -58,27 +59,27 @@ func NewFileStore(dataDir string, metrics telemetry.AppMetrics) (*FileStore, err } // NewFilestoreFromSqliteStore restores a store from Sqlite and stores to Filestore json in the file located in datadir -func NewFilestoreFromSqliteStore(sqlStore *SqlStore, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) { - store, err := NewFileStore(dataDir, metrics) +func NewFilestoreFromSqliteStore(ctx context.Context, sqlStore *SqlStore, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) { + store, err := NewFileStore(ctx, dataDir, metrics) if err != nil { return nil, err } - err = store.SaveInstallationID(sqlStore.GetInstallationID()) + err = store.SaveInstallationID(ctx, sqlStore.GetInstallationID()) if err != nil { return nil, err } - for _, account := range sqlStore.GetAllAccounts() { + for _, account := range sqlStore.GetAllAccounts(ctx) { store.Accounts[account.Id] = account } - return store, store.persist(store.storeFile) + return store, store.persist(ctx, store.storeFile) } // restore the state of the store from the file. // Creates a new empty store file if doesn't exist -func restore(file string) (*FileStore, error) { +func restore(ctx context.Context, file string) (*FileStore, error) { if _, err := os.Stat(file); os.IsNotExist(err) { // create a new FileStore if previously didn't exist (e.g. first run) s := &FileStore{ @@ -95,7 +96,7 @@ func restore(file string) (*FileStore, error) { storeFile: file, } - err = s.persist(file) + err = s.persist(ctx, file) if err != nil { return nil, err } @@ -165,7 +166,7 @@ func restore(file string) (*FileStore, error) { // for data migration. Can be removed once most base will be with labels existingLabels := account.getPeerDNSLabels() if len(existingLabels) != len(account.Peers) { - addPeerLabelsToAccount(account, existingLabels) + addPeerLabelsToAccount(ctx, account, existingLabels) } // TODO: delete this block after migration @@ -178,7 +179,7 @@ func restore(file string) (*FileStore, error) { allGroup, err := account.GetGroupAll() if err != nil { - log.Errorf("unable to find the All group, this should happen only when migrate from a version that didn't support groups. Error: %v", err) + log.WithContext(ctx).Errorf("unable to find the All group, this should happen only when migrate from a version that didn't support groups. Error: %v", err) // if the All group didn't exist we probably don't have routes to update continue } @@ -236,7 +237,7 @@ func restore(file string) (*FileStore, error) { } // we need this persist to apply changes we made to account.Peers (we set them to Disconnected) - err = store.persist(store.storeFile) + err = store.persist(ctx, store.storeFile) if err != nil { return nil, err } @@ -246,7 +247,7 @@ func restore(file string) (*FileStore, error) { // persist account data to a file // It is recommended to call it with locking FileStore.mux -func (s *FileStore) persist(file string) error { +func (s *FileStore) persist(ctx context.Context, file string) error { start := time.Now() err := util.WriteJson(file, s) if err != nil { @@ -256,23 +257,23 @@ func (s *FileStore) persist(file string) error { if s.metrics != nil { s.metrics.StoreMetrics().CountPersistenceDuration(took) } - log.Debugf("took %d ms to persist the FileStore", took.Milliseconds()) + log.WithContext(ctx).Debugf("took %d ms to persist the FileStore", took.Milliseconds()) return nil } // AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock -func (s *FileStore) AcquireGlobalLock() (unlock func()) { - log.Debugf("acquiring global lock") +func (s *FileStore) AcquireGlobalLock(ctx context.Context) (unlock func()) { + log.WithContext(ctx).Debugf("acquiring global lock") start := time.Now() s.globalAccountLock.Lock() unlock = func() { s.globalAccountLock.Unlock() - log.Debugf("released global lock in %v", time.Since(start)) + log.WithContext(ctx).Debugf("released global lock in %v", time.Since(start)) } took := time.Since(start) - log.Debugf("took %v to acquire global lock", took) + log.WithContext(ctx).Debugf("took %v to acquire global lock", took) if s.metrics != nil { s.metrics.StoreMetrics().CountGlobalLockAcquisitionDuration(took) } @@ -281,8 +282,8 @@ func (s *FileStore) AcquireGlobalLock() (unlock func()) { } // AcquireAccountWriteLock acquires account lock for writing to a resource and returns a function that releases the lock -func (s *FileStore) AcquireAccountWriteLock(accountID string) (unlock func()) { - log.Debugf("acquiring lock for account %s", accountID) +func (s *FileStore) AcquireAccountWriteLock(ctx context.Context, accountID string) (unlock func()) { + log.WithContext(ctx).Debugf("acquiring lock for account %s", accountID) start := time.Now() value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{}) mtx := value.(*sync.Mutex) @@ -290,7 +291,7 @@ func (s *FileStore) AcquireAccountWriteLock(accountID string) (unlock func()) { unlock = func() { mtx.Unlock() - log.Debugf("released lock for account %s in %v", accountID, time.Since(start)) + log.WithContext(ctx).Debugf("released lock for account %s in %v", accountID, time.Since(start)) } return unlock @@ -298,11 +299,11 @@ func (s *FileStore) AcquireAccountWriteLock(accountID string) (unlock func()) { // AcquireAccountReadLock AcquireAccountWriteLock acquires account lock for reading a resource and returns a function that releases the lock // This method is still returns a write lock as file store can't handle read locks -func (s *FileStore) AcquireAccountReadLock(accountID string) (unlock func()) { - return s.AcquireAccountWriteLock(accountID) +func (s *FileStore) AcquireAccountReadLock(ctx context.Context, accountID string) (unlock func()) { + return s.AcquireAccountWriteLock(ctx, accountID) } -func (s *FileStore) SaveAccount(account *Account) error { +func (s *FileStore) SaveAccount(ctx context.Context, account *Account) error { s.mux.Lock() defer s.mux.Unlock() @@ -338,10 +339,10 @@ func (s *FileStore) SaveAccount(account *Account) error { s.PrivateDomain2AccountID[accountCopy.Domain] = accountCopy.Id } - return s.persist(s.storeFile) + return s.persist(ctx, s.storeFile) } -func (s *FileStore) DeleteAccount(account *Account) error { +func (s *FileStore) DeleteAccount(ctx context.Context, account *Account) error { s.mux.Lock() defer s.mux.Unlock() @@ -373,7 +374,7 @@ func (s *FileStore) DeleteAccount(account *Account) error { delete(s.Accounts, account.Id) - return s.persist(s.storeFile) + return s.persist(ctx, s.storeFile) } // DeleteHashedPAT2TokenIDIndex removes an entry from the indexing map HashedPAT2TokenID @@ -397,7 +398,7 @@ func (s *FileStore) DeleteTokenID2UserIDIndex(tokenID string) error { } // GetAccountByPrivateDomain returns account by private domain -func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) { +func (s *FileStore) GetAccountByPrivateDomain(_ context.Context, domain string) (*Account, error) { s.mux.Lock() defer s.mux.Unlock() @@ -415,7 +416,7 @@ func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) { } // GetAccountBySetupKey returns account by setup key id -func (s *FileStore) GetAccountBySetupKey(setupKey string) (*Account, error) { +func (s *FileStore) GetAccountBySetupKey(_ context.Context, setupKey string) (*Account, error) { s.mux.Lock() defer s.mux.Unlock() @@ -433,7 +434,7 @@ func (s *FileStore) GetAccountBySetupKey(setupKey string) (*Account, error) { } // GetTokenIDByHashedToken returns the id of a personal access token by its hashed secret -func (s *FileStore) GetTokenIDByHashedToken(token string) (string, error) { +func (s *FileStore) GetTokenIDByHashedToken(_ context.Context, token string) (string, error) { s.mux.Lock() defer s.mux.Unlock() @@ -446,7 +447,7 @@ func (s *FileStore) GetTokenIDByHashedToken(token string) (string, error) { } // GetUserByTokenID returns a User object a tokenID belongs to -func (s *FileStore) GetUserByTokenID(tokenID string) (*User, error) { +func (s *FileStore) GetUserByTokenID(_ context.Context, tokenID string) (*User, error) { s.mux.Lock() defer s.mux.Unlock() @@ -469,7 +470,7 @@ func (s *FileStore) GetUserByTokenID(tokenID string) (*User, error) { } // GetAllAccounts returns all accounts -func (s *FileStore) GetAllAccounts() (all []*Account) { +func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) { s.mux.Lock() defer s.mux.Unlock() for _, a := range s.Accounts { @@ -490,7 +491,7 @@ func (s *FileStore) getAccount(accountID string) (*Account, error) { } // GetAccount returns an account for ID -func (s *FileStore) GetAccount(accountID string) (*Account, error) { +func (s *FileStore) GetAccount(_ context.Context, accountID string) (*Account, error) { s.mux.Lock() defer s.mux.Unlock() @@ -503,7 +504,7 @@ func (s *FileStore) GetAccount(accountID string) (*Account, error) { } // GetAccountByUser returns a user account -func (s *FileStore) GetAccountByUser(userID string) (*Account, error) { +func (s *FileStore) GetAccountByUser(_ context.Context, userID string) (*Account, error) { s.mux.Lock() defer s.mux.Unlock() @@ -521,7 +522,7 @@ func (s *FileStore) GetAccountByUser(userID string) (*Account, error) { } // GetAccountByPeerID returns an account for a given peer ID -func (s *FileStore) GetAccountByPeerID(peerID string) (*Account, error) { +func (s *FileStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) { s.mux.Lock() defer s.mux.Unlock() @@ -539,7 +540,7 @@ func (s *FileStore) GetAccountByPeerID(peerID string) (*Account, error) { // check Account.Peers for a match if _, ok := account.Peers[peerID]; !ok { delete(s.PeerID2AccountID, peerID) - log.Warnf("removed stale peerID %s to accountID %s index", peerID, accountID) + log.WithContext(ctx).Warnf("removed stale peerID %s to accountID %s index", peerID, accountID) return nil, status.NewPeerNotFoundError(peerID) } @@ -547,7 +548,7 @@ func (s *FileStore) GetAccountByPeerID(peerID string) (*Account, error) { } // GetAccountByPeerPubKey returns an account for a given peer WireGuard public key -func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) { +func (s *FileStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) { s.mux.Lock() defer s.mux.Unlock() @@ -572,14 +573,14 @@ func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) { } if stale { delete(s.PeerKeyID2AccountID, peerKey) - log.Warnf("removed stale peerKey %s to accountID %s index", peerKey, accountID) + log.WithContext(ctx).Warnf("removed stale peerKey %s to accountID %s index", peerKey, accountID) return nil, status.NewPeerNotFoundError(peerKey) } return account.Copy(), nil } -func (s *FileStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) { +func (s *FileStore) GetAccountIDByPeerPubKey(_ context.Context, peerKey string) (string, error) { s.mux.Lock() defer s.mux.Unlock() @@ -603,7 +604,7 @@ func (s *FileStore) GetAccountIDByUserID(userID string) (string, error) { return accountID, nil } -func (s *FileStore) GetAccountIDBySetupKey(setupKey string) (string, error) { +func (s *FileStore) GetAccountIDBySetupKey(_ context.Context, setupKey string) (string, error) { s.mux.Lock() defer s.mux.Unlock() @@ -615,7 +616,7 @@ func (s *FileStore) GetAccountIDBySetupKey(setupKey string) (string, error) { return accountID, nil } -func (s *FileStore) GetPeerByPeerPubKey(peerKey string) (*nbpeer.Peer, error) { +func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, peerKey string) (*nbpeer.Peer, error) { s.mux.Lock() defer s.mux.Unlock() @@ -638,7 +639,7 @@ func (s *FileStore) GetPeerByPeerPubKey(peerKey string) (*nbpeer.Peer, error) { return nil, status.NewPeerNotFoundError(peerKey) } -func (s *FileStore) GetAccountSettings(accountID string) (*Settings, error) { +func (s *FileStore) GetAccountSettings(_ context.Context, accountID string) (*Settings, error) { s.mux.Lock() defer s.mux.Unlock() @@ -656,13 +657,13 @@ func (s *FileStore) GetInstallationID() string { } // SaveInstallationID saves the installation ID -func (s *FileStore) SaveInstallationID(ID string) error { +func (s *FileStore) SaveInstallationID(ctx context.Context, ID string) error { s.mux.Lock() defer s.mux.Unlock() s.InstallationID = ID - return s.persist(s.storeFile) + return s.persist(ctx, s.storeFile) } // SavePeerStatus stores the PeerStatus in memory. It doesn't attempt to persist data to speed up things. @@ -732,13 +733,13 @@ func (s *FileStore) GetPostureCheckByChecksDefinition(accountID string, checks * } // Close the FileStore persisting data to disk -func (s *FileStore) Close() error { +func (s *FileStore) Close(ctx context.Context) error { s.mux.Lock() defer s.mux.Unlock() - log.Infof("closing FileStore") + log.WithContext(ctx).Infof("closing FileStore") - return s.persist(s.storeFile) + return s.persist(ctx, s.storeFile) } // GetStoreEngine returns FileStoreEngine diff --git a/management/server/file_store_test.go b/management/server/file_store_test.go index 11571b0be..56e46b696 100644 --- a/management/server/file_store_test.go +++ b/management/server/file_store_test.go @@ -1,6 +1,7 @@ package server import ( + "context" "crypto/sha256" "net" "path/filepath" @@ -27,12 +28,12 @@ func TestStalePeerIndices(t *testing.T) { t.Fatal(err) } - store, err := NewFileStore(storeDir, nil) + store, err := NewFileStore(context.Background(), storeDir, nil) if err != nil { return } - account, err := store.GetAccount("bf1c8084-ba50-4ce7-9439-34653001fc3b") + account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") require.NoError(t, err) peerID := "some_peer" @@ -42,24 +43,24 @@ func TestStalePeerIndices(t *testing.T) { Key: peerKey, } - err = store.SaveAccount(account) + err = store.SaveAccount(context.Background(), account) require.NoError(t, err) account.DeletePeer(peerID) - err = store.SaveAccount(account) + err = store.SaveAccount(context.Background(), account) require.NoError(t, err) - _, err = store.GetAccountByPeerID(peerID) + _, err = store.GetAccountByPeerID(context.Background(), peerID) require.Error(t, err, "expecting to get an error when found stale index") - _, err = store.GetAccountByPeerPubKey(peerKey) + _, err = store.GetAccountByPeerPubKey(context.Background(), peerKey) require.Error(t, err, "expecting to get an error when found stale index") } func TestNewStore(t *testing.T) { store := newStore(t) - defer store.Close() + defer store.Close(context.Background()) if store.Accounts == nil || len(store.Accounts) != 0 { t.Errorf("expected to create a new empty Accounts map when creating a new FileStore") @@ -88,9 +89,9 @@ func TestNewStore(t *testing.T) { func TestSaveAccount(t *testing.T) { store := newStore(t) - defer store.Close() + defer store.Close(context.Background()) - account := newAccountWithId("account_id", "testuser", "") + account := newAccountWithId(context.Background(), "account_id", "testuser", "") setupKey := GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ @@ -103,7 +104,7 @@ func TestSaveAccount(t *testing.T) { } // SaveAccount should trigger persist - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { return } @@ -133,11 +134,11 @@ func TestDeleteAccount(t *testing.T) { t.Fatal(err) } - store, err := NewFileStore(storeDir, nil) + store, err := NewFileStore(context.Background(), storeDir, nil) if err != nil { t.Fatal(err) } - defer store.Close() + defer store.Close(context.Background()) var account *Account for _, a := range store.Accounts { @@ -147,7 +148,7 @@ func TestDeleteAccount(t *testing.T) { require.NotNil(t, account, "failed to restore a FileStore file and get at least one account") - err = store.DeleteAccount(account) + err = store.DeleteAccount(context.Background(), account) require.NoError(t, err, "failed to delete account, error: %v", err) _, ok := store.Accounts[account.Id] @@ -183,9 +184,9 @@ func TestDeleteAccount(t *testing.T) { func TestStore(t *testing.T) { store := newStore(t) - defer store.Close() + defer store.Close(context.Background()) - account := newAccountWithId("account_id", "testuser", "") + account := newAccountWithId(context.Background(), "account_id", "testuser", "") account.Peers["testpeer"] = &nbpeer.Peer{ Key: "peerkey", SetupKey: "peerkeysetupkey", @@ -228,12 +229,12 @@ func TestStore(t *testing.T) { }) // SaveAccount should trigger persist - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { return } - restored, err := NewFileStore(store.storeFile, nil) + restored, err := NewFileStore(context.Background(), store.storeFile, nil) if err != nil { return } @@ -281,7 +282,7 @@ func TestRestore(t *testing.T) { t.Fatal(err) } - store, err := NewFileStore(storeDir, nil) + store, err := NewFileStore(context.Background(), storeDir, nil) if err != nil { return } @@ -319,7 +320,7 @@ func TestRestoreGroups_Migration(t *testing.T) { t.Fatal(err) } - store, err := NewFileStore(storeDir, nil) + store, err := NewFileStore(context.Background(), storeDir, nil) if err != nil { return } @@ -332,11 +333,11 @@ func TestRestoreGroups_Migration(t *testing.T) { Name: "All", }, } - err = store.SaveAccount(account) + err = store.SaveAccount(context.Background(), account) require.NoError(t, err, "failed to save account") // restore account with default group with empty Issue field - if store, err = NewFileStore(storeDir, nil); err != nil { + if store, err = NewFileStore(context.Background(), storeDir, nil); err != nil { return } account = store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"] @@ -353,18 +354,18 @@ func TestGetAccountByPrivateDomain(t *testing.T) { t.Fatal(err) } - store, err := NewFileStore(storeDir, nil) + store, err := NewFileStore(context.Background(), storeDir, nil) if err != nil { return } existingDomain := "test.com" - account, err := store.GetAccountByPrivateDomain(existingDomain) + account, err := store.GetAccountByPrivateDomain(context.Background(), existingDomain) require.NoError(t, err, "should found account") require.Equal(t, existingDomain, account.Domain, "domains should match") - _, err = store.GetAccountByPrivateDomain("missing-domain.com") + _, err = store.GetAccountByPrivateDomain(context.Background(), "missing-domain.com") require.Error(t, err, "should return error on domain lookup") } @@ -382,7 +383,7 @@ func TestFileStore_GetAccount(t *testing.T) { t.Fatal(err) } - store, err := NewFileStore(storeDir, nil) + store, err := NewFileStore(context.Background(), storeDir, nil) if err != nil { t.Fatal(err) } @@ -393,7 +394,7 @@ func TestFileStore_GetAccount(t *testing.T) { return } - account, err := store.GetAccount(expected.Id) + account, err := store.GetAccount(context.Background(), expected.Id) if err != nil { t.Fatal(err) } @@ -424,13 +425,13 @@ func TestFileStore_GetTokenIDByHashedToken(t *testing.T) { t.Fatal(err) } - store, err := NewFileStore(storeDir, nil) + store, err := NewFileStore(context.Background(), storeDir, nil) if err != nil { t.Fatal(err) } hashedToken := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].HashedToken - tokenID, err := store.GetTokenIDByHashedToken(hashedToken) + tokenID, err := store.GetTokenIDByHashedToken(context.Background(), hashedToken) if err != nil { t.Fatal(err) } @@ -441,7 +442,7 @@ func TestFileStore_GetTokenIDByHashedToken(t *testing.T) { func TestFileStore_DeleteHashedPAT2TokenIDIndex(t *testing.T) { store := newStore(t) - defer store.Close() + defer store.Close(context.Background()) store.HashedPAT2TokenID["someHashedToken"] = "someTokenId" err := store.DeleteHashedPAT2TokenIDIndex("someHashedToken") @@ -478,13 +479,13 @@ func TestFileStore_GetTokenIDByHashedToken_Failure(t *testing.T) { t.Fatal(err) } - store, err := NewFileStore(storeDir, nil) + store, err := NewFileStore(context.Background(), storeDir, nil) if err != nil { t.Fatal(err) } wrongToken := sha256.Sum256([]byte("someNotValidTokenThatFails1234")) - _, err = store.GetTokenIDByHashedToken(string(wrongToken[:])) + _, err = store.GetTokenIDByHashedToken(context.Background(), string(wrongToken[:])) assert.Error(t, err, "GetTokenIDByHashedToken should throw error if token invalid") } @@ -503,13 +504,13 @@ func TestFileStore_GetUserByTokenID(t *testing.T) { t.Fatal(err) } - store, err := NewFileStore(storeDir, nil) + store, err := NewFileStore(context.Background(), storeDir, nil) if err != nil { t.Fatal(err) } tokenID := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].ID - user, err := store.GetUserByTokenID(tokenID) + user, err := store.GetUserByTokenID(context.Background(), tokenID) if err != nil { t.Fatal(err) } @@ -531,13 +532,13 @@ func TestFileStore_GetUserByTokenID_Failure(t *testing.T) { t.Fatal(err) } - store, err := NewFileStore(storeDir, nil) + store, err := NewFileStore(context.Background(), storeDir, nil) if err != nil { t.Fatal(err) } wrongTokenID := "someNonExistingTokenID" - _, err = store.GetUserByTokenID(wrongTokenID) + _, err = store.GetUserByTokenID(context.Background(), wrongTokenID) assert.Error(t, err, "GetUserByTokenID should throw error if tokenID invalid") } @@ -550,7 +551,7 @@ func TestFileStore_SavePeerStatus(t *testing.T) { t.Fatal(err) } - store, err := NewFileStore(storeDir, nil) + store, err := NewFileStore(context.Background(), storeDir, nil) if err != nil { return } @@ -576,7 +577,7 @@ func TestFileStore_SavePeerStatus(t *testing.T) { Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()}, } - err = store.SaveAccount(account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatal(err) } @@ -602,11 +603,11 @@ func TestFileStore_SavePeerLocation(t *testing.T) { t.Fatal(err) } - store, err := NewFileStore(storeDir, nil) + store, err := NewFileStore(context.Background(), storeDir, nil) if err != nil { return } - account, err := store.GetAccount("bf1c8084-ba50-4ce7-9439-34653001fc3b") + account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") require.NoError(t, err) peer := &nbpeer.Peer{ @@ -625,7 +626,7 @@ func TestFileStore_SavePeerLocation(t *testing.T) { assert.Error(t, err) account.Peers[peer.ID] = peer - err = store.SaveAccount(account) + err = store.SaveAccount(context.Background(), account) require.NoError(t, err) peer.Location.ConnectionIP = net.ParseIP("35.1.1.1") @@ -636,7 +637,7 @@ func TestFileStore_SavePeerLocation(t *testing.T) { err = store.SavePeerLocation(account.Id, account.Peers[peer.ID]) assert.NoError(t, err) - account, err = store.GetAccount(account.Id) + account, err = store.GetAccount(context.Background(), account.Id) require.NoError(t, err) actual := account.Peers[peer.ID].Location @@ -645,7 +646,7 @@ func TestFileStore_SavePeerLocation(t *testing.T) { func newStore(t *testing.T) *FileStore { t.Helper() - store, err := NewFileStore(t.TempDir(), nil) + store, err := NewFileStore(context.Background(), t.TempDir(), nil) if err != nil { t.Errorf("failed creating a new store") } diff --git a/management/server/geolocation/geolocation.go b/management/server/geolocation/geolocation.go index 4fd28806b..794f9d0be 100644 --- a/management/server/geolocation/geolocation.go +++ b/management/server/geolocation/geolocation.go @@ -2,6 +2,7 @@ package geolocation import ( "bytes" + "context" "fmt" "net" "os" @@ -52,7 +53,7 @@ type Country struct { CountryName string } -func NewGeolocation(dataDir string) (*Geolocation, error) { +func NewGeolocation(ctx context.Context, dataDir string) (*Geolocation, error) { if err := loadGeolocationDatabases(dataDir); err != nil { return nil, fmt.Errorf("failed to load MaxMind databases: %v", err) } @@ -68,7 +69,7 @@ func NewGeolocation(dataDir string) (*Geolocation, error) { return nil, err } - locationDB, err := NewSqliteStore(dataDir) + locationDB, err := NewSqliteStore(ctx, dataDir) if err != nil { return nil, err } @@ -83,7 +84,7 @@ func NewGeolocation(dataDir string) (*Geolocation, error) { stopCh: make(chan struct{}), } - go geo.reloader() + go geo.reloader(ctx) return geo, nil } @@ -165,19 +166,19 @@ func (gl *Geolocation) Stop() error { return nil } -func (gl *Geolocation) reloader() { +func (gl *Geolocation) reloader(ctx context.Context) { for { select { case <-gl.stopCh: return case <-time.After(gl.reloadCheckInterval): - if err := gl.locationDB.reload(); err != nil { - log.Errorf("geonames db reload failed: %s", err) + if err := gl.locationDB.reload(ctx); err != nil { + log.WithContext(ctx).Errorf("geonames db reload failed: %s", err) } newSha256sum1, err := calculateFileSHA256(gl.mmdbPath) if err != nil { - log.Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err) + log.WithContext(ctx).Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err) continue } if !bytes.Equal(gl.sha256sum, newSha256sum1) { @@ -186,30 +187,30 @@ func (gl *Geolocation) reloader() { time.Sleep(50 * time.Millisecond) newSha256sum2, err := calculateFileSHA256(gl.mmdbPath) if err != nil { - log.Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err) + log.WithContext(ctx).Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err) continue } if !bytes.Equal(newSha256sum1, newSha256sum2) { - log.Errorf("sha256 sum changed during reloading of '%s'", gl.mmdbPath) + log.WithContext(ctx).Errorf("sha256 sum changed during reloading of '%s'", gl.mmdbPath) continue } - err = gl.reload(newSha256sum2) + err = gl.reload(ctx, newSha256sum2) if err != nil { - log.Errorf("mmdb reload failed: %s", err) + log.WithContext(ctx).Errorf("mmdb reload failed: %s", err) } } else { - log.Tracef("No changes in '%s', no need to reload. Next check is in %.0f seconds.", + log.WithContext(ctx).Tracef("No changes in '%s', no need to reload. Next check is in %.0f seconds.", gl.mmdbPath, gl.reloadCheckInterval.Seconds()) } } } } -func (gl *Geolocation) reload(newSha256sum []byte) error { +func (gl *Geolocation) reload(ctx context.Context, newSha256sum []byte) error { gl.mux.Lock() defer gl.mux.Unlock() - log.Infof("Reloading '%s'", gl.mmdbPath) + log.WithContext(ctx).Infof("Reloading '%s'", gl.mmdbPath) err := gl.db.Close() if err != nil { @@ -224,7 +225,7 @@ func (gl *Geolocation) reload(newSha256sum []byte) error { gl.db = db gl.sha256sum = newSha256sum - log.Infof("Successfully reloaded '%s'", gl.mmdbPath) + log.WithContext(ctx).Infof("Successfully reloaded '%s'", gl.mmdbPath) return nil } diff --git a/management/server/geolocation/store.go b/management/server/geolocation/store.go index 3da7989e1..67d420cfd 100644 --- a/management/server/geolocation/store.go +++ b/management/server/geolocation/store.go @@ -2,6 +2,7 @@ package geolocation import ( "bytes" + "context" "fmt" "path/filepath" "runtime" @@ -50,10 +51,10 @@ type SqliteStore struct { sha256sum []byte } -func NewSqliteStore(dataDir string) (*SqliteStore, error) { +func NewSqliteStore(ctx context.Context, dataDir string) (*SqliteStore, error) { file := filepath.Join(dataDir, GeoSqliteDBFile) - db, err := connectDB(file) + db, err := connectDB(ctx, file) if err != nil { return nil, err } @@ -115,13 +116,13 @@ func (s *SqliteStore) GetCitiesByCountry(countryISOCode string) ([]City, error) } // reload attempts to reload the SqliteStore's database if the database file has changed. -func (s *SqliteStore) reload() error { +func (s *SqliteStore) reload(ctx context.Context) error { s.mux.Lock() defer s.mux.Unlock() newSha256sum1, err := calculateFileSHA256(s.filePath) if err != nil { - log.Errorf("failed to calculate sha256 sum for '%s': %s", s.filePath, err) + log.WithContext(ctx).Errorf("failed to calculate sha256 sum for '%s': %s", s.filePath, err) } if !bytes.Equal(s.sha256sum, newSha256sum1) { @@ -136,11 +137,11 @@ func (s *SqliteStore) reload() error { return fmt.Errorf("sha256 sum changed during reloading of '%s'", s.filePath) } - log.Infof("Reloading '%s'", s.filePath) + log.WithContext(ctx).Infof("Reloading '%s'", s.filePath) _ = s.close() s.closed = true - newDb, err := connectDB(s.filePath) + newDb, err := connectDB(ctx, s.filePath) if err != nil { return err } @@ -148,9 +149,9 @@ func (s *SqliteStore) reload() error { s.closed = false s.db = newDb - log.Infof("Successfully reloaded '%s'", s.filePath) + log.WithContext(ctx).Infof("Successfully reloaded '%s'", s.filePath) } else { - log.Tracef("No changes in '%s', no need to reload", s.filePath) + log.WithContext(ctx).Tracef("No changes in '%s', no need to reload", s.filePath) } return nil @@ -168,10 +169,10 @@ func (s *SqliteStore) close() error { } // connectDB connects to an SQLite database and prepares it by setting up an in-memory database. -func connectDB(filePath string) (*gorm.DB, error) { +func connectDB(ctx context.Context, filePath string) (*gorm.DB, error) { start := time.Now() defer func() { - log.Debugf("took %v to setup geoname db", time.Since(start)) + log.WithContext(ctx).Debugf("took %v to setup geoname db", time.Since(start)) }() _, err := fileExists(filePath) diff --git a/management/server/group.go b/management/server/group.go index 7ede2120d..ea512924b 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -1,6 +1,7 @@ package server import ( + "context" "fmt" "github.com/rs/xid" @@ -21,11 +22,11 @@ func (e *GroupLinkError) Error() string { } // GetGroup object of the peers -func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*nbgroup.Group, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -48,11 +49,11 @@ func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*n } // GetAllGroups returns all groups in an account -func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) ([]*nbgroup.Group, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID string, userID string) ([]*nbgroup.Group, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -75,11 +76,11 @@ func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) ( } // GetGroupByName filters all groups in an account by name and returns the one with the most peers -func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*nbgroup.Group, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -108,11 +109,11 @@ func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*n } // SaveGroup object of the peers -func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *nbgroup.Group) error { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *nbgroup.Group) error { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } @@ -150,11 +151,11 @@ func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *n account.Groups[newGroup.ID] = newGroup account.Network.IncSerial() - if err = am.Store.SaveAccount(account); err != nil { + if err = am.Store.SaveAccount(ctx, account); err != nil { return err } - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) // the following snippet tracks the activity and stores the group events in the event store. // It has to happen after all the operations have been successfully performed. @@ -165,16 +166,16 @@ func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *n removedPeers = difference(oldGroup.Peers, newGroup.Peers) } else { addedPeers = append(addedPeers, newGroup.Peers...) - am.StoreEvent(userID, newGroup.ID, accountID, activity.GroupCreated, newGroup.EventMeta()) + am.StoreEvent(ctx, userID, newGroup.ID, accountID, activity.GroupCreated, newGroup.EventMeta()) } for _, p := range addedPeers { peer := account.Peers[p] if peer == nil { - log.Errorf("peer %s not found under account %s while saving group", p, accountID) + log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID) continue } - am.StoreEvent(userID, peer.ID, accountID, activity.GroupAddedToPeer, + am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer, map[string]any{ "group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), @@ -184,10 +185,10 @@ func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *n for _, p := range removedPeers { peer := account.Peers[p] if peer == nil { - log.Errorf("peer %s not found under account %s while saving group", p, accountID) + log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID) continue } - am.StoreEvent(userID, peer.ID, accountID, activity.GroupRemovedFromPeer, + am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer, map[string]any{ "group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), @@ -213,11 +214,11 @@ func difference(a, b []string) []string { } // DeleteGroup object of the peers -func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) error { - unlock := am.Store.AcquireAccountWriteLock(accountId) +func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountId) defer unlock() - account, err := am.Store.GetAccount(accountId) + account, err := am.Store.GetAccount(ctx, accountId) if err != nil { return err } @@ -315,23 +316,23 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) delete(account.Groups, groupID) account.Network.IncSerial() - if err = am.Store.SaveAccount(account); err != nil { + if err = am.Store.SaveAccount(ctx, account); err != nil { return err } - am.StoreEvent(userId, groupID, accountId, activity.GroupDeleted, g.EventMeta()) + am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, g.EventMeta()) - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) return nil } // ListGroups objects of the peers -func (am *DefaultAccountManager) ListGroups(accountID string) ([]*nbgroup.Group, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -345,11 +346,11 @@ func (am *DefaultAccountManager) ListGroups(accountID string) ([]*nbgroup.Group, } // GroupAddPeer appends peer to the group -func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerID string) error { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } @@ -371,21 +372,21 @@ func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerID string) } account.Network.IncSerial() - if err = am.Store.SaveAccount(account); err != nil { + if err = am.Store.SaveAccount(ctx, account); err != nil { return err } - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) return nil } // GroupDeletePeer removes peer from the group -func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerID string) error { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } @@ -399,13 +400,13 @@ func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerID stri for i, itemID := range group.Peers { if itemID == peerID { group.Peers = append(group.Peers[:i], group.Peers[i+1:]...) - if err := am.Store.SaveAccount(account); err != nil { + if err := am.Store.SaveAccount(ctx, account); err != nil { return err } } } - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) return nil } diff --git a/management/server/group_test.go b/management/server/group_test.go index 1c718715d..373d72964 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -1,6 +1,7 @@ package server import ( + "context" "errors" "testing" @@ -26,7 +27,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) { } for _, group := range account.Groups { group.Issued = nbgroup.GroupIssuedIntegration - err = am.SaveGroup(account.Id, groupAdminUserID, group) + err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group) if err != nil { t.Errorf("should allow to create %s groups", nbgroup.GroupIssuedIntegration) } @@ -34,7 +35,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) { for _, group := range account.Groups { group.Issued = nbgroup.GroupIssuedJWT - err = am.SaveGroup(account.Id, groupAdminUserID, group) + err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group) if err != nil { t.Errorf("should allow to create %s groups", nbgroup.GroupIssuedJWT) } @@ -42,7 +43,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) { for _, group := range account.Groups { group.Issued = nbgroup.GroupIssuedAPI group.ID = "" - err = am.SaveGroup(account.Id, groupAdminUserID, group) + err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group) if err == nil { t.Errorf("should not create api group with the same name, %s", group.Name) } @@ -104,7 +105,7 @@ func TestDefaultAccountManager_DeleteGroup(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - err = am.DeleteGroup(account.Id, groupAdminUserID, testCase.groupID) + err = am.DeleteGroup(context.Background(), account.Id, groupAdminUserID, testCase.groupID) if err == nil { t.Errorf("delete %s group successfully", testCase.groupID) return @@ -225,7 +226,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) { Id: "example user", AutoGroups: []string{groupForUsers.ID}, } - account := newAccountWithId(accountID, groupAdminUserID, domain) + account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain) account.Routes[routeResource.ID] = routeResource account.Routes[routePeerGroupResource.ID] = routePeerGroupResource account.NameServerGroups[nameServerGroup.ID] = nameServerGroup @@ -233,18 +234,18 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) { account.SetupKeys[setupKey.Id] = setupKey account.Users[user.Id] = user - err := am.Store.SaveAccount(account) + err := am.Store.SaveAccount(context.Background(), account) if err != nil { return nil, err } - _ = am.SaveGroup(accountID, groupAdminUserID, groupForRoute) - _ = am.SaveGroup(accountID, groupAdminUserID, groupForRoute2) - _ = am.SaveGroup(accountID, groupAdminUserID, groupForNameServerGroups) - _ = am.SaveGroup(accountID, groupAdminUserID, groupForPolicies) - _ = am.SaveGroup(accountID, groupAdminUserID, groupForSetupKeys) - _ = am.SaveGroup(accountID, groupAdminUserID, groupForUsers) - _ = am.SaveGroup(accountID, groupAdminUserID, groupForIntegration) + _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute) + _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute2) + _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForNameServerGroups) + _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForPolicies) + _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForSetupKeys) + _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForUsers) + _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration) - return am.Store.GetAccount(account.Id) + return am.Store.GetAccount(context.Background(), account.Id) } diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index bf0c3009a..170e72dd0 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -11,12 +11,14 @@ import ( pb "github.com/golang/protobuf/proto" // nolint "github.com/golang/protobuf/ptypes/timestamp" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip" - "github.com/netbirdio/netbird/management/server/posture" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + nbContext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -40,7 +42,7 @@ type GRPCServer struct { } // NewServer creates a new Management server -func NewServer(config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager, turnCredentialsManager TURNCredentialsManager, appMetrics telemetry.AppMetrics, ephemeralManager *EphemeralManager) (*GRPCServer, error) { +func NewServer(ctx context.Context, config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager, turnCredentialsManager TURNCredentialsManager, appMetrics telemetry.AppMetrics, ephemeralManager *EphemeralManager) (*GRPCServer, error) { key, err := wgtypes.GeneratePrivateKey() if err != nil { return nil, err @@ -50,6 +52,7 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) { jwtValidator, err = jwtclaims.NewJWTValidator( + ctx, config.HttpConfig.AuthIssuer, config.GetAuthAudiences(), config.HttpConfig.AuthKeysLocation, @@ -59,7 +62,7 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager return nil, status.Errorf(codes.Internal, "unable to create new jwt middleware, err: %v", err) } } else { - log.Debug("unable to use http config to create new jwt middleware") + log.WithContext(ctx).Debug("unable to use http config to create new jwt middleware") } if appMetrics != nil { @@ -126,47 +129,61 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi if s.appMetrics != nil { s.appMetrics.GRPCMetrics().CountSyncRequest() } - realIP := getRealIP(srv.Context()) - log.Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, realIP.String()) + + ctx := srv.Context() + + realIP := getRealIP(ctx) syncReq := &proto.SyncRequest{} - peerKey, err := s.parseRequest(req, syncReq) + peerKey, err := s.parseRequest(ctx, req, syncReq) if err != nil { return err } + //nolint + ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String()) + accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String()) + if err != nil { + // this case should not happen and already indicates an issue but we don't want the system to fail due to being unable to log in detail + accountID = "UNKNOWN" + } + //nolint + ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID) + + log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, realIP.String()) + if syncReq.GetMeta() == nil { - log.Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP) + log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP) } - peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(peerKey.String(), extractPeerMeta(syncReq.GetMeta()), realIP) + peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP) if err != nil { - return mapError(err) + return mapError(ctx, err) } - err = s.sendInitialSync(peerKey, peer, netMap, postureChecks, srv) + err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv) if err != nil { - log.Debugf("error while sending initial sync for %s: %v", peerKey.String(), err) + log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err) return err } - updates := s.peersUpdateManager.CreateChannel(peer.ID) + updates := s.peersUpdateManager.CreateChannel(ctx, peer.ID) - s.ephemeralManager.OnPeerConnected(peer) + s.ephemeralManager.OnPeerConnected(ctx, peer) if s.config.TURNConfig.TimeBasedCredentials { - s.turnCredentialsManager.SetupRefresh(peer.ID) + s.turnCredentialsManager.SetupRefresh(ctx, peer.ID) } if s.appMetrics != nil { s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart)) } - return s.handleUpdates(peerKey, peer, updates, srv) + return s.handleUpdates(ctx, peerKey, peer, updates, srv) } // handleUpdates sends updates to the connected peer until the updates channel is closed. -func (s *GRPCServer) handleUpdates(peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error { +func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error { for { select { // condition when there are some updates @@ -176,21 +193,21 @@ func (s *GRPCServer) handleUpdates(peerKey wgtypes.Key, peer *nbpeer.Peer, updat } if !open { - log.Debugf("updates channel for peer %s was closed", peerKey.String()) - s.cancelPeerRoutines(peer) + log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String()) + s.cancelPeerRoutines(ctx, peer) return nil } - log.Debugf("received an update for peer %s", peerKey.String()) + log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String()) - if err := s.sendUpdate(peerKey, peer, update, srv); err != nil { + if err := s.sendUpdate(ctx, peerKey, peer, update, srv); err != nil { return err } // condition when client <-> server connection has been terminated case <-srv.Context().Done(): // happens when connection drops, e.g. client disconnects - log.Debugf("stream of peer %s has been closed", peerKey.String()) - s.cancelPeerRoutines(peer) + log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String()) + s.cancelPeerRoutines(ctx, peer) return srv.Context().Err() } } @@ -198,10 +215,10 @@ func (s *GRPCServer) handleUpdates(peerKey wgtypes.Key, peer *nbpeer.Peer, updat // sendUpdate encrypts the update message using the peer key and the server's wireguard key, // then sends the encrypted message to the connected peer via the sync server. -func (s *GRPCServer) sendUpdate(peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error { +func (s *GRPCServer) sendUpdate(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error { encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update) if err != nil { - s.cancelPeerRoutines(peer) + s.cancelPeerRoutines(ctx, peer) return status.Errorf(codes.Internal, "failed processing update message") } err = srv.SendMsg(&proto.EncryptedMessage{ @@ -209,37 +226,37 @@ func (s *GRPCServer) sendUpdate(peerKey wgtypes.Key, peer *nbpeer.Peer, update * Body: encryptedResp, }) if err != nil { - s.cancelPeerRoutines(peer) + s.cancelPeerRoutines(ctx, peer) return status.Errorf(codes.Internal, "failed sending update message") } - log.Debugf("sent an update to peer %s", peerKey.String()) + log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String()) return nil } -func (s *GRPCServer) cancelPeerRoutines(peer *nbpeer.Peer) { - s.peersUpdateManager.CloseChannel(peer.ID) +func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, peer *nbpeer.Peer) { + s.peersUpdateManager.CloseChannel(ctx, peer.ID) s.turnCredentialsManager.CancelRefresh(peer.ID) - _ = s.accountManager.CancelPeerRoutines(peer) - s.ephemeralManager.OnPeerDisconnected(peer) + _ = s.accountManager.CancelPeerRoutines(ctx, peer) + s.ephemeralManager.OnPeerDisconnected(ctx, peer) } -func (s *GRPCServer) validateToken(jwtToken string) (string, error) { +func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) { if s.jwtValidator == nil { return "", status.Error(codes.Internal, "no jwt validator set") } - token, err := s.jwtValidator.ValidateAndParse(jwtToken) + token, err := s.jwtValidator.ValidateAndParse(ctx, jwtToken) if err != nil { return "", status.Errorf(codes.InvalidArgument, "invalid jwt token, err: %v", err) } claims := s.jwtClaimsExtractor.FromToken(token) // we need to call this method because if user is new, we will automatically add it to existing or create a new account - _, _, err = s.accountManager.GetAccountFromToken(claims) + _, _, err = s.accountManager.GetAccountFromToken(ctx, claims) if err != nil { return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err) } - if err := s.accountManager.CheckUserAccessByJWTGroups(claims); err != nil { + if err := s.accountManager.CheckUserAccessByJWTGroups(ctx, claims); err != nil { return "", status.Errorf(codes.PermissionDenied, err.Error()) } @@ -247,7 +264,7 @@ func (s *GRPCServer) validateToken(jwtToken string) (string, error) { } // maps internal internalStatus.Error to gRPC status.Error -func mapError(err error) error { +func mapError(ctx context.Context, err error) error { if e, ok := internalStatus.FromError(err); ok { switch e.Type() { case internalStatus.PermissionDenied: @@ -263,11 +280,11 @@ func mapError(err error) error { default: } } - log.Errorf("got an unhandled error: %s", err) + log.WithContext(ctx).Errorf("got an unhandled error: %s", err) return status.Errorf(codes.Internal, "failed handling request") } -func extractPeerMeta(meta *proto.PeerSystemMeta) nbpeer.PeerSystemMeta { +func extractPeerMeta(ctx context.Context, meta *proto.PeerSystemMeta) nbpeer.PeerSystemMeta { if meta == nil { return nbpeer.PeerSystemMeta{} } @@ -281,7 +298,7 @@ func extractPeerMeta(meta *proto.PeerSystemMeta) nbpeer.PeerSystemMeta { for _, addr := range meta.GetNetworkAddresses() { netAddr, err := netip.ParsePrefix(addr.GetNetIP()) if err != nil { - log.Warnf("failed to parse netip address, %s: %v", addr.GetNetIP(), err) + log.WithContext(ctx).Warnf("failed to parse netip address, %s: %v", addr.GetNetIP(), err) continue } networkAddresses = append(networkAddresses, nbpeer.NetworkAddress{ @@ -321,10 +338,10 @@ func extractPeerMeta(meta *proto.PeerSystemMeta) nbpeer.PeerSystemMeta { } } -func (s *GRPCServer) parseRequest(req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) { +func (s *GRPCServer) parseRequest(ctx context.Context, req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) { peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) if err != nil { - log.Warnf("error while parsing peer's WireGuard public key %s.", req.WgPubKey) + log.WithContext(ctx).Warnf("error while parsing peer's WireGuard public key %s.", req.WgPubKey) return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", req.WgPubKey) } @@ -351,22 +368,32 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p s.appMetrics.GRPCMetrics().CountLoginRequest() } realIP := getRealIP(ctx) - log.Debugf("Login request from peer [%s] [%s]", req.WgPubKey, realIP.String()) + log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, realIP.String()) loginReq := &proto.LoginRequest{} - peerKey, err := s.parseRequest(req, loginReq) + peerKey, err := s.parseRequest(ctx, req, loginReq) if err != nil { return nil, err } + //nolint + ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String()) + accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String()) + if err != nil { + // this case should not happen and already indicates an issue but we don't want the system to fail due to being unable to log in detail + accountID = "UNKNOWN" + } + //nolint + ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID) + if loginReq.GetMeta() == nil { msg := status.Errorf(codes.FailedPrecondition, "peer system meta has to be provided to log in. Peer %s, remote addr %s", peerKey.String(), realIP) - log.Warn(msg) + log.WithContext(ctx).Warn(msg) return nil, msg } - userID, err := s.processJwtToken(loginReq, peerKey) + userID, err := s.processJwtToken(ctx, loginReq, peerKey) if err != nil { return nil, err } @@ -376,33 +403,33 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p sshKey = loginReq.GetPeerKeys().GetSshPubKey() } - peer, netMap, postureChecks, err := s.accountManager.LoginPeer(PeerLogin{ + peer, netMap, postureChecks, err := s.accountManager.LoginPeer(ctx, PeerLogin{ WireGuardPubKey: peerKey.String(), SSHKey: string(sshKey), - Meta: extractPeerMeta(loginReq.GetMeta()), + Meta: extractPeerMeta(ctx, loginReq.GetMeta()), UserID: userID, SetupKey: loginReq.GetSetupKey(), ConnectionIP: realIP, }) if err != nil { - log.Warnf("failed logging in peer %s: %s", peerKey, err) - return nil, mapError(err) + log.WithContext(ctx).Warnf("failed logging in peer %s: %s", peerKey, err) + return nil, mapError(ctx, err) } // if the login request contains setup key then it is a registration request if loginReq.GetSetupKey() != "" { - s.ephemeralManager.OnPeerDisconnected(peer) + s.ephemeralManager.OnPeerDisconnected(ctx, peer) } // if peer has reached this point then it has logged in loginResp := &proto.LoginResponse{ WiretrusteeConfig: toWiretrusteeConfig(s.config, nil), PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain()), - Checks: toProtocolChecks(postureChecks), + Checks: toProtocolChecks(ctx, postureChecks), } encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp) if err != nil { - log.Warnf("failed encrypting peer %s message", peer.ID) + log.WithContext(ctx).Warnf("failed encrypting peer %s message", peer.ID) return nil, status.Errorf(codes.Internal, "failed logging in peer") } @@ -417,16 +444,16 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p // // The user ID can be empty if the token is not provided, which is acceptable if the peer is already // registered or if it uses a setup key to register. -func (s *GRPCServer) processJwtToken(loginReq *proto.LoginRequest, peerKey wgtypes.Key) (string, error) { +func (s *GRPCServer) processJwtToken(ctx context.Context, loginReq *proto.LoginRequest, peerKey wgtypes.Key) (string, error) { userID := "" if loginReq.GetJwtToken() != "" { var err error for i := 0; i < 3; i++ { - userID, err = s.validateToken(loginReq.GetJwtToken()) + userID, err = s.validateToken(ctx, loginReq.GetJwtToken()) if err == nil { break } - log.Warnf("failed validating JWT token sent from peer %s with error %v. "+ + log.WithContext(ctx).Warnf("failed validating JWT token sent from peer %s with error %v. "+ "Trying again as it may be due to the IdP cache issue", peerKey.String(), err) time.Sleep(200 * time.Millisecond) } @@ -520,7 +547,7 @@ func toRemotePeerConfig(peers []*nbpeer.Peer, dnsName string) []*proto.RemotePee return remotePeers } -func toSyncResponse(config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string, checks []*posture.Checks) *proto.SyncResponse { +func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string, checks []*posture.Checks) *proto.SyncResponse { wtConfig := toWiretrusteeConfig(config, turnCredentials) pConfig := toPeerConfig(peer, networkMap.Network, dnsName) @@ -551,7 +578,7 @@ func toSyncResponse(config *Config, peer *nbpeer.Peer, turnCredentials *TURNCred FirewallRules: firewallRules, FirewallRulesIsEmpty: len(firewallRules) == 0, }, - Checks: toProtocolChecks(checks), + Checks: toProtocolChecks(ctx, checks), } } @@ -561,7 +588,7 @@ func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Em } // sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization -func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error { +func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error { // make secret time based TURN credentials optional var turnCredentials *TURNCredentials if s.config.TURNConfig.TimeBasedCredentials { @@ -570,7 +597,7 @@ func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *nbpeer.Peer, net } else { turnCredentials = nil } - plainResp := toSyncResponse(s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain(), postureChecks) + plainResp := toSyncResponse(ctx, s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain(), postureChecks) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp) if err != nil { @@ -583,7 +610,7 @@ func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *nbpeer.Peer, net }) if err != nil { - log.Errorf("failed sending SyncResponse %v", err) + log.WithContext(ctx).Errorf("failed sending SyncResponse %v", err) return status.Errorf(codes.Internal, "error handling request") } @@ -597,14 +624,14 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto. peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) if err != nil { errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetDeviceAuthorizationFlow request.", req.WgPubKey) - log.Warn(errMSG) + log.WithContext(ctx).Warn(errMSG) return nil, status.Error(codes.InvalidArgument, errMSG) } err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.DeviceAuthorizationFlowRequest{}) if err != nil { errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey) - log.Warn(errMSG) + log.WithContext(ctx).Warn(errMSG) return nil, status.Error(codes.InvalidArgument, errMSG) } @@ -645,18 +672,18 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto. // GetPKCEAuthorizationFlow returns a pkce authorization flow information // This is used for initiating an Oauth 2 pkce authorization grant flow // which will be used by our clients to Login -func (s *GRPCServer) GetPKCEAuthorizationFlow(_ context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { +func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) if err != nil { errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetPKCEAuthorizationFlow request.", req.WgPubKey) - log.Warn(errMSG) + log.WithContext(ctx).Warn(errMSG) return nil, status.Error(codes.InvalidArgument, errMSG) } err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.PKCEAuthorizationFlowRequest{}) if err != nil { errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey) - log.Warn(errMSG) + log.WithContext(ctx).Warn(errMSG) return nil, status.Error(codes.InvalidArgument, errMSG) } @@ -692,10 +719,10 @@ func (s *GRPCServer) GetPKCEAuthorizationFlow(_ context.Context, req *proto.Encr // peer's under the same account of any updates. func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) { realIP := getRealIP(ctx) - log.Debugf("Sync meta request from peer [%s] [%s]", req.WgPubKey, realIP.String()) + log.WithContext(ctx).Debugf("Sync meta request from peer [%s] [%s]", req.WgPubKey, realIP.String()) syncMetaReq := &proto.SyncMetaRequest{} - peerKey, err := s.parseRequest(req, syncMetaReq) + peerKey, err := s.parseRequest(ctx, req, syncMetaReq) if err != nil { return nil, err } @@ -703,20 +730,20 @@ func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) if syncMetaReq.GetMeta() == nil { msg := status.Errorf(codes.FailedPrecondition, "peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP) - log.Warn(msg) + log.WithContext(ctx).Warn(msg) return nil, msg } - err = s.accountManager.SyncPeerMeta(peerKey.String(), extractPeerMeta(syncMetaReq.GetMeta())) + err = s.accountManager.SyncPeerMeta(ctx, peerKey.String(), extractPeerMeta(ctx, syncMetaReq.GetMeta())) if err != nil { - return nil, mapError(err) + return nil, mapError(ctx, err) } return &proto.Empty{}, nil } // toProtocolChecks converts posture checks to protocol checks. -func toProtocolChecks(postureChecks []*posture.Checks) []*proto.Checks { +func toProtocolChecks(ctx context.Context, postureChecks []*posture.Checks) []*proto.Checks { protoChecks := make([]*proto.Checks, 0, len(postureChecks)) for _, postureCheck := range postureChecks { protoChecks = append(protoChecks, toProtocolCheck(postureCheck)) diff --git a/management/server/http/accounts_handler.go b/management/server/http/accounts_handler.go index d3c9954d3..ffa5b9a28 100644 --- a/management/server/http/accounts_handler.go +++ b/management/server/http/accounts_handler.go @@ -35,34 +35,34 @@ func NewAccountsHandler(accountManager server.AccountManager, authCfg AuthCfg) * // GetAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account. func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } if !(user.HasAdminPower() || user.IsServiceUser) { - util.WriteError(status.Errorf(status.PermissionDenied, "the user has no permission to access account data"), w) + util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "the user has no permission to access account data"), w) return } resp := toAccountResponse(account) - util.WriteJSONObject(w, []*api.Account{resp}) + util.WriteJSONObject(r.Context(), w, []*api.Account{resp}) } // UpdateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings) func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - _, user, err := h.accountManager.GetAccountFromToken(claims) + _, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) accountID := vars["accountId"] if len(accountID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid accountID ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid accountID ID"), w) return } @@ -96,15 +96,15 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) settings.JWTAllowGroups = *req.Settings.JwtAllowGroups } - updatedAccount, err := h.accountManager.UpdateAccountSettings(accountID, user.Id, settings) + updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, user.Id, settings) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } resp := toAccountResponse(updatedAccount) - util.WriteJSONObject(w, &resp) + util.WriteJSONObject(r.Context(), w, &resp) } // DeleteAccount is a HTTP DELETE handler to delete an account @@ -118,17 +118,17 @@ func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) vars := mux.Vars(r) targetAccountID := vars["accountId"] if len(targetAccountID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid account ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid account ID"), w) return } - err := h.accountManager.DeleteAccount(targetAccountID, claims.UserId) + err := h.accountManager.DeleteAccount(r.Context(), targetAccountID, claims.UserId) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, emptyObject{}) } func toAccountResponse(account *server.Account) *api.Account { diff --git a/management/server/http/accounts_handler_test.go b/management/server/http/accounts_handler_test.go index 9d174d0be..45c7679e5 100644 --- a/management/server/http/accounts_handler_test.go +++ b/management/server/http/accounts_handler_test.go @@ -2,6 +2,7 @@ package http import ( "bytes" + "context" "encoding/json" "io" "net/http" @@ -22,10 +23,10 @@ import ( func initAccountsTestData(account *server.Account, admin *server.User) *AccountsHandler { return &AccountsHandler{ accountManager: &mock_server.MockAccountManager{ - GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { + GetAccountFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { return account, admin, nil }, - UpdateAccountSettingsFunc: func(accountID, userID string, newSettings *server.Settings) (*server.Account, error) { + UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) { halfYearLimit := 180 * 24 * time.Hour if newSettings.PeerLoginExpiration > halfYearLimit { return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days") diff --git a/management/server/http/dns_settings_handler.go b/management/server/http/dns_settings_handler.go index baaf7ba69..74b0e1a55 100644 --- a/management/server/http/dns_settings_handler.go +++ b/management/server/http/dns_settings_handler.go @@ -32,16 +32,16 @@ func NewDNSSettingsHandler(accountManager server.AccountManager, authCfg AuthCfg // GetDNSSettings returns the DNS settings for the account func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - log.Error(err) + log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } - dnsSettings, err := h.accountManager.GetDNSSettings(account.Id, user.Id) + dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), account.Id, user.Id) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -49,15 +49,15 @@ func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Reque DisabledManagementGroups: dnsSettings.DisabledManagementGroups, } - util.WriteJSONObject(w, apiDNSSettings) + util.WriteJSONObject(r.Context(), w, apiDNSSettings) } // UpdateDNSSettings handles update to DNS settings of an account func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -72,9 +72,9 @@ func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Re DisabledManagementGroups: req.DisabledManagementGroups, } - err = h.accountManager.SaveDNSSettings(account.Id, user.Id, updateDNSSettings) + err = h.accountManager.SaveDNSSettings(r.Context(), account.Id, user.Id, updateDNSSettings) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -82,5 +82,5 @@ func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Re DisabledManagementGroups: updateDNSSettings.DisabledManagementGroups, } - util.WriteJSONObject(w, &resp) + util.WriteJSONObject(r.Context(), w, &resp) } diff --git a/management/server/http/dns_settings_handler_test.go b/management/server/http/dns_settings_handler_test.go index a2f65a521..897ae63dc 100644 --- a/management/server/http/dns_settings_handler_test.go +++ b/management/server/http/dns_settings_handler_test.go @@ -2,6 +2,7 @@ package http import ( "bytes" + "context" "encoding/json" "io" "net/http" @@ -42,16 +43,16 @@ var testingDNSSettingsAccount = &server.Account{ func initDNSSettingsTestData() *DNSSettingsHandler { return &DNSSettingsHandler{ accountManager: &mock_server.MockAccountManager{ - GetDNSSettingsFunc: func(accountID string, userID string) (*server.DNSSettings, error) { + GetDNSSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.DNSSettings, error) { return &testingDNSSettingsAccount.DNSSettings, nil }, - SaveDNSSettingsFunc: func(accountID string, userID string, dnsSettingsToSave *server.DNSSettings) error { + SaveDNSSettingsFunc: func(ctx context.Context, accountID string, userID string, dnsSettingsToSave *server.DNSSettings) error { if dnsSettingsToSave != nil { return nil } return status.Errorf(status.InvalidArgument, "the dns settings provided are nil") }, - GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { + GetAccountFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { return testingDNSSettingsAccount, testingDNSSettingsAccount.Users[testDNSSettingsUserID], nil }, }, diff --git a/management/server/http/events_handler.go b/management/server/http/events_handler.go index a89c206a3..428b4c164 100644 --- a/management/server/http/events_handler.go +++ b/management/server/http/events_handler.go @@ -1,6 +1,7 @@ package http import ( + "context" "fmt" "net/http" @@ -33,16 +34,16 @@ func NewEventsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ev // GetAllEvents list of the given account func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - log.Error(err) + log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } - accountEvents, err := h.accountManager.GetEvents(account.Id, user.Id) + accountEvents, err := h.accountManager.GetEvents(r.Context(), account.Id, user.Id) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } events := make([]*api.Event, len(accountEvents)) @@ -50,20 +51,20 @@ func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) { events[i] = toEventResponse(e) } - err = h.fillEventsWithUserInfo(events, account.Id, user.Id) + err = h.fillEventsWithUserInfo(r.Context(), events, account.Id, user.Id) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(w, events) + util.WriteJSONObject(r.Context(), w, events) } -func (h *EventsHandler) fillEventsWithUserInfo(events []*api.Event, accountId, userId string) error { +func (h *EventsHandler) fillEventsWithUserInfo(ctx context.Context, events []*api.Event, accountId, userId string) error { // build email, name maps based on users - userInfos, err := h.accountManager.GetUsersFromAccount(accountId, userId) + userInfos, err := h.accountManager.GetUsersFromAccount(ctx, accountId, userId) if err != nil { - log.Errorf("failed to get users from account: %s", err) + log.WithContext(ctx).Errorf("failed to get users from account: %s", err) return err } @@ -80,7 +81,7 @@ func (h *EventsHandler) fillEventsWithUserInfo(events []*api.Event, accountId, u if event.InitiatorEmail == "" { event.InitiatorEmail, ok = emails[event.InitiatorId] if !ok { - log.Warnf("failed to resolve email for initiator: %s", event.InitiatorId) + log.WithContext(ctx).Warnf("failed to resolve email for initiator: %s", event.InitiatorId) } } diff --git a/management/server/http/events_handler_test.go b/management/server/http/events_handler_test.go index 4cfad922b..8bdd508bf 100644 --- a/management/server/http/events_handler_test.go +++ b/management/server/http/events_handler_test.go @@ -1,6 +1,7 @@ package http import ( + "context" "encoding/json" "io" "net/http" @@ -22,13 +23,13 @@ import ( func initEventsTestData(account string, user *server.User, events ...*activity.Event) *EventsHandler { return &EventsHandler{ accountManager: &mock_server.MockAccountManager{ - GetEventsFunc: func(accountID, userID string) ([]*activity.Event, error) { + GetEventsFunc: func(_ context.Context, accountID, userID string) ([]*activity.Event, error) { if accountID == account { return events, nil } return []*activity.Event{}, nil }, - GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { + GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { return &server.Account{ Id: claims.AccountId, Domain: "hotmail.com", @@ -37,7 +38,7 @@ func initEventsTestData(account string, user *server.User, events ...*activity.E }, }, user, nil }, - GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) { + GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) { return make([]*server.UserInfo, 0), nil }, }, diff --git a/management/server/http/geolocation_handler_test.go b/management/server/http/geolocation_handler_test.go index 226711002..b8247f78d 100644 --- a/management/server/http/geolocation_handler_test.go +++ b/management/server/http/geolocation_handler_test.go @@ -1,6 +1,7 @@ package http import ( + "context" "encoding/json" "io" "net/http" @@ -35,13 +36,13 @@ func initGeolocationTestData(t *testing.T) *GeolocationsHandler { err = util.CopyFileContents(geonamesDBPath, path.Join(tempDir, geolocation.GeoSqliteDBFile)) assert.NoError(t, err) - geo, err := geolocation.NewGeolocation(tempDir) + geo, err := geolocation.NewGeolocation(context.Background(), tempDir) assert.NoError(t, err) t.Cleanup(func() { _ = geo.Stop() }) return &GeolocationsHandler{ accountManager: &mock_server.MockAccountManager{ - GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { + GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { user := server.NewAdminUser("test_user") return &server.Account{ Id: claims.AccountId, diff --git a/management/server/http/geolocations_handler.go b/management/server/http/geolocations_handler.go index cf961267b..af4d3116f 100644 --- a/management/server/http/geolocations_handler.go +++ b/management/server/http/geolocations_handler.go @@ -40,19 +40,19 @@ func NewGeolocationsHandlerHandler(accountManager server.AccountManager, geoloca // GetAllCountries retrieves a list of all countries func (l *GeolocationsHandler) GetAllCountries(w http.ResponseWriter, r *http.Request) { if err := l.authenticateUser(r); err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } if l.geolocationManager == nil { // TODO: update error message to include geo db self hosted doc link when ready - util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w) + util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w) return } allCountries, err := l.geolocationManager.GetAllCountries() if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -60,32 +60,32 @@ func (l *GeolocationsHandler) GetAllCountries(w http.ResponseWriter, r *http.Req for _, country := range allCountries { countries = append(countries, toCountryResponse(country)) } - util.WriteJSONObject(w, countries) + util.WriteJSONObject(r.Context(), w, countries) } // GetCitiesByCountry retrieves a list of cities based on the given country code func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http.Request) { if err := l.authenticateUser(r); err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) countryCode := vars["country"] if !countryCodeRegex.MatchString(countryCode) { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid country code"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid country code"), w) return } if l.geolocationManager == nil { - util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized. "+ + util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized. "+ "Check the self-hosted Geo database documentation at https://docs.netbird.io/selfhosted/geo-support"), w) return } allCities, err := l.geolocationManager.GetCitiesByCountry(countryCode) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -93,12 +93,12 @@ func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http. for _, city := range allCities { cities = append(cities, toCityResponse(city)) } - util.WriteJSONObject(w, cities) + util.WriteJSONObject(r.Context(), w, cities) } func (l *GeolocationsHandler) authenticateUser(r *http.Request) error { claims := l.claimsExtractor.FromRequestContext(r) - _, user, err := l.accountManager.GetAccountFromToken(claims) + _, user, err := l.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { return err } diff --git a/management/server/http/groups_handler.go b/management/server/http/groups_handler.go index 47bcf2f32..c622d873a 100644 --- a/management/server/http/groups_handler.go +++ b/management/server/http/groups_handler.go @@ -35,16 +35,16 @@ func NewGroupsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Gr // GetAllGroups list for the account func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - log.Error(err) + log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } - groups, err := h.accountManager.GetAllGroups(account.Id, user.Id) + groups, err := h.accountManager.GetAllGroups(r.Context(), account.Id, user.Id) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -53,42 +53,42 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) { groupsResponse = append(groupsResponse, toGroupResponse(account, group)) } - util.WriteJSONObject(w, groupsResponse) + util.WriteJSONObject(r.Context(), w, groupsResponse) } // UpdateGroup handles update to a group identified by a given ID func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) groupID, ok := vars["groupId"] if !ok { - util.WriteError(status.Errorf(status.InvalidArgument, "group ID field is missing"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "group ID field is missing"), w) return } if len(groupID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "group ID can't be empty"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "group ID can't be empty"), w) return } eg, ok := account.Groups[groupID] if !ok { - util.WriteError(status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w) + util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w) return } allGroup, err := account.GetGroupAll() if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } if allGroup.ID == groupID { - util.WriteError(status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w) return } @@ -100,7 +100,7 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { } if req.Name == "" { - util.WriteError(status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w) return } @@ -118,21 +118,21 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { IntegrationReference: eg.IntegrationReference, } - if err := h.accountManager.SaveGroup(account.Id, user.Id, &group); err != nil { - log.Errorf("failed updating group %s under account %s %v", groupID, account.Id, err) - util.WriteError(err, w) + if err := h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group); err != nil { + log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, account.Id, err) + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(w, toGroupResponse(account, &group)) + util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group)) } // CreateGroup handles group creation request func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -144,7 +144,7 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { } if req.Name == "" { - util.WriteError(status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w) return } @@ -160,62 +160,62 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { Issued: nbgroup.GroupIssuedAPI, } - err = h.accountManager.SaveGroup(account.Id, user.Id, &group) + err = h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(w, toGroupResponse(account, &group)) + util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group)) } // DeleteGroup handles group deletion request func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } aID := account.Id groupID := mux.Vars(r)["groupId"] if len(groupID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w) return } allGroup, err := account.GetGroupAll() if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } if allGroup.ID == groupID { - util.WriteError(status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed"), w) return } - err = h.accountManager.DeleteGroup(aID, user.Id, groupID) + err = h.accountManager.DeleteGroup(r.Context(), aID, user.Id, groupID) if err != nil { _, ok := err.(*server.GroupLinkError) if ok { util.WriteErrorResponse(err.Error(), http.StatusBadRequest, w) return } - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, emptyObject{}) } // GetGroup returns a group func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -223,19 +223,19 @@ func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) { case http.MethodGet: groupID := mux.Vars(r)["groupId"] if len(groupID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w) return } - group, err := h.accountManager.GetGroup(account.Id, groupID, user.Id) + group, err := h.accountManager.GetGroup(r.Context(), account.Id, groupID, user.Id) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(w, toGroupResponse(account, group)) + util.WriteJSONObject(r.Context(), w, toGroupResponse(account, group)) default: - util.WriteError(status.Errorf(status.NotFound, "HTTP method not found"), w) + util.WriteError(r.Context(), status.Errorf(status.NotFound, "HTTP method not found"), w) return } } diff --git a/management/server/http/groups_handler_test.go b/management/server/http/groups_handler_test.go index 3d74b848c..d5ed07c9e 100644 --- a/management/server/http/groups_handler_test.go +++ b/management/server/http/groups_handler_test.go @@ -2,6 +2,7 @@ package http import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -32,13 +33,13 @@ var TestPeers = map[string]*nbpeer.Peer{ func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler { return &GroupsHandler{ accountManager: &mock_server.MockAccountManager{ - SaveGroupFunc: func(accountID, userID string, group *nbgroup.Group) error { + SaveGroupFunc: func(_ context.Context, accountID, userID string, group *nbgroup.Group) error { if !strings.HasPrefix(group.ID, "id-") { group.ID = "id-was-set" } return nil }, - GetGroupFunc: func(_, groupID, _ string) (*nbgroup.Group, error) { + GetGroupFunc: func(_ context.Context, _, groupID, _ string) (*nbgroup.Group, error) { if groupID != "idofthegroup" { return nil, status.Errorf(status.NotFound, "not found") } @@ -55,7 +56,7 @@ func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler { Issued: nbgroup.GroupIssuedAPI, }, nil }, - GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { + GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { return &server.Account{ Id: claims.AccountId, Domain: "hotmail.com", @@ -70,7 +71,7 @@ func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler { }, }, user, nil }, - DeleteGroupFunc: func(accountID, userId, groupID string) error { + DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error { if groupID == "linked-grp" { return &server.GroupLinkError{ Resource: "something", diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 4405d295c..3fe26d0ce 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -9,6 +9,7 @@ import ( "github.com/rs/cors" "github.com/netbirdio/management-integrations/integrations" + s "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/middleware" @@ -57,6 +58,11 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa corsMiddleware := cors.AllowAll() + claimsExtractor = jwtclaims.NewClaimsExtractor( + jwtclaims.WithAudience(authCfg.Audience), + jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), + ) + acMiddleware := middleware.NewAccessControl( authCfg.Audience, authCfg.UserIDClaim, diff --git a/management/server/http/middleware/access_control.go b/management/server/http/middleware/access_control.go index de386f173..0ad250f43 100644 --- a/management/server/http/middleware/access_control.go +++ b/management/server/http/middleware/access_control.go @@ -1,6 +1,7 @@ package middleware import ( + "context" "net/http" "regexp" @@ -15,7 +16,7 @@ import ( ) // GetUser function defines a function to fetch user from Account by jwtclaims.AuthorizationClaims -type GetUser func(claims jwtclaims.AuthorizationClaims) (*server.User, error) +type GetUser func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) // AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only type AccessControl struct { @@ -46,15 +47,15 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler { claims := a.claimsExtract.FromRequestContext(r) - user, err := a.getUser(claims) + user, err := a.getUser(r.Context(), claims) if err != nil { - log.Errorf("failed to get user from claims: %s", err) - util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w) + log.WithContext(r.Context()).Errorf("failed to get user from claims: %s", err) + util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "invalid JWT"), w) return } if user.IsBlocked() { - util.WriteError(status.Errorf(status.PermissionDenied, "the user has no access to the API or is blocked"), w) + util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "the user has no access to the API or is blocked"), w) return } @@ -63,12 +64,12 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler { case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut: if tokenPathRegexp.MatchString(r.URL.Path) { - log.Debugf("valid Path") + log.WithContext(r.Context()).Debugf("valid Path") h.ServeHTTP(w, r) return } - util.WriteError(status.Errorf(status.PermissionDenied, "only users with admin power can perform this operation"), w) + util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only users with admin power can perform this operation"), w) return } } diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 204c9f4eb..b25aad99c 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -12,6 +12,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server" + nbContext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -19,16 +20,16 @@ import ( ) // GetAccountFromPATFunc function -type GetAccountFromPATFunc func(token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) +type GetAccountFromPATFunc func(ctx context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) // ValidateAndParseTokenFunc function -type ValidateAndParseTokenFunc func(token string) (*jwt.Token, error) +type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error) // MarkPATUsedFunc function -type MarkPATUsedFunc func(token string) error +type MarkPATUsedFunc func(ctx context.Context, token string) error // CheckUserAccessByJWTGroupsFunc function -type CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error +type CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error // AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens type AuthMiddleware struct { @@ -85,23 +86,27 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { case "bearer": err := m.checkJWTFromRequest(w, r, auth) if err != nil { - log.Errorf("Error when validating JWT claims: %s", err.Error()) - util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w) + log.WithContext(r.Context()).Errorf("Error when validating JWT claims: %s", err.Error()) + util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w) return } - h.ServeHTTP(w, r) case "token": err := m.checkPATFromRequest(w, r, auth) if err != nil { - log.Debugf("Error when validating PAT claims: %s", err.Error()) - util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w) + log.WithContext(r.Context()).Debugf("Error when validating PAT claims: %s", err.Error()) + util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w) return } - h.ServeHTTP(w, r) default: - util.WriteError(status.Errorf(status.Unauthorized, "no valid authentication provided"), w) + util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w) return } + claims := m.claimsExtractor.FromRequestContext(r) + //nolint + ctx := context.WithValue(r.Context(), nbContext.UserIDKey, claims.UserId) + //nolint + ctx = context.WithValue(ctx, nbContext.AccountIDKey, claims.AccountId) + h.ServeHTTP(w, r.WithContext(ctx)) }) } @@ -114,7 +119,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ return fmt.Errorf("Error extracting token: %w", err) } - validatedToken, err := m.validateAndParseToken(token) + validatedToken, err := m.validateAndParseToken(r.Context(), token) if err != nil { return err } @@ -123,7 +128,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ return nil } - if err := m.verifyUserAccess(validatedToken); err != nil { + if err := m.verifyUserAccess(r.Context(), validatedToken); err != nil { return err } @@ -138,9 +143,9 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ // verifyUserAccess checks if a user, based on a validated JWT token, // is allowed access, particularly in cases where the admin enabled JWT // group propagation and designated certain groups with access permissions. -func (m *AuthMiddleware) verifyUserAccess(validatedToken *jwt.Token) error { +func (m *AuthMiddleware) verifyUserAccess(ctx context.Context, validatedToken *jwt.Token) error { authClaims := m.claimsExtractor.FromToken(validatedToken) - return m.checkUserAccessByJWTGroups(authClaims) + return m.checkUserAccessByJWTGroups(ctx, authClaims) } // CheckPATFromRequest checks if the PAT is valid @@ -152,7 +157,7 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ return fmt.Errorf("Error extracting token: %w", err) } - account, user, pat, err := m.getAccountFromPAT(token) + account, user, pat, err := m.getAccountFromPAT(r.Context(), token) if err != nil { return fmt.Errorf("invalid Token: %w", err) } @@ -160,7 +165,7 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ return fmt.Errorf("token expired") } - err = m.markPATUsed(pat.ID) + err = m.markPATUsed(r.Context(), pat.ID) if err != nil { return err } diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index 588bcaf02..fdfb0ea24 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -1,6 +1,7 @@ package middleware import ( + "context" "fmt" "net/http" "net/http/httptest" @@ -15,15 +16,16 @@ import ( ) const ( - audience = "audience" - userIDClaim = "userIDClaim" - accountID = "accountID" - domain = "domain" - userID = "userID" - tokenID = "tokenID" - PAT = "nbp_PAT" - JWT = "JWT" - wrongToken = "wrongToken" + audience = "audience" + userIDClaim = "userIDClaim" + accountID = "accountID" + domain = "domain" + domainCategory = "domainCategory" + userID = "userID" + tokenID = "tokenID" + PAT = "nbp_PAT" + JWT = "JWT" + wrongToken = "wrongToken" ) var testAccount = &server.Account{ @@ -47,14 +49,14 @@ var testAccount = &server.Account{ }, } -func mockGetAccountFromPAT(token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) { +func mockGetAccountFromPAT(_ context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) { if token == PAT { return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil } return nil, nil, nil, fmt.Errorf("PAT invalid") } -func mockValidateAndParseToken(token string) (*jwt.Token, error) { +func mockValidateAndParseToken(_ context.Context, token string) (*jwt.Token, error) { if token == JWT { return &jwt.Token{ Claims: jwt.MapClaims{ @@ -67,14 +69,14 @@ func mockValidateAndParseToken(token string) (*jwt.Token, error) { return nil, fmt.Errorf("JWT invalid") } -func mockMarkPATUsed(token string) error { +func mockMarkPATUsed(_ context.Context, token string) error { if token == tokenID { return nil } return fmt.Errorf("Should never get reached") } -func mockCheckUserAccessByJWTGroups(claims jwtclaims.AuthorizationClaims) error { +func mockCheckUserAccessByJWTGroups(_ context.Context, claims jwtclaims.AuthorizationClaims) error { if testAccount.Id != claims.AccountId { return fmt.Errorf("account with id %s does not exist", claims.AccountId) } diff --git a/management/server/http/middleware/bypass/bypass.go b/management/server/http/middleware/bypass/bypass.go index 87b41c6fc..9447704cb 100644 --- a/management/server/http/middleware/bypass/bypass.go +++ b/management/server/http/middleware/bypass/bypass.go @@ -56,7 +56,7 @@ func ShouldBypass(requestPath string, h http.Handler, w http.ResponseWriter, r * for bypassPath := range bypassPaths { matched, err := path.Match(bypassPath, requestPath) if err != nil { - log.Errorf("Error matching path %s with %s from %s: %v", bypassPath, requestPath, GetList(), err) + log.WithContext(r.Context()).Errorf("Error matching path %s with %s from %s: %v", bypassPath, requestPath, GetList(), err) continue } if matched { diff --git a/management/server/http/nameservers_handler.go b/management/server/http/nameservers_handler.go index 8d9f0d717..c6e00bb2d 100644 --- a/management/server/http/nameservers_handler.go +++ b/management/server/http/nameservers_handler.go @@ -36,16 +36,16 @@ func NewNameserversHandler(accountManager server.AccountManager, authCfg AuthCfg // GetAllNameservers returns the list of nameserver groups for the account func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - log.Error(err) + log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } - nsGroups, err := h.accountManager.ListNameServerGroups(account.Id, user.Id) + nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), account.Id, user.Id) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -54,15 +54,15 @@ func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Re apiNameservers = append(apiNameservers, toNameserverGroupResponse(r)) } - util.WriteJSONObject(w, apiNameservers) + util.WriteJSONObject(r.Context(), w, apiNameservers) } // CreateNameserverGroup handles nameserver group creation request func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -75,33 +75,33 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt nsList, err := toServerNSList(req.Nameservers) if err != nil { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid NS servers format"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid NS servers format"), w) return } - nsGroup, err := h.accountManager.CreateNameServerGroup(account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, user.Id, req.SearchDomainsEnabled) + nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, user.Id, req.SearchDomainsEnabled) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } resp := toNameserverGroupResponse(nsGroup) - util.WriteJSONObject(w, &resp) + util.WriteJSONObject(r.Context(), w, &resp) } // UpdateNameserverGroup handles update to a nameserver group identified by a given ID func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } nsGroupID := mux.Vars(r)["nsgroupId"] if len(nsGroupID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w) return } @@ -114,7 +114,7 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt nsList, err := toServerNSList(req.Nameservers) if err != nil { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid NS servers format"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid NS servers format"), w) return } @@ -130,66 +130,66 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt SearchDomainsEnabled: req.SearchDomainsEnabled, } - err = h.accountManager.SaveNameServerGroup(account.Id, user.Id, updatedNSGroup) + err = h.accountManager.SaveNameServerGroup(r.Context(), account.Id, user.Id, updatedNSGroup) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } resp := toNameserverGroupResponse(updatedNSGroup) - util.WriteJSONObject(w, &resp) + util.WriteJSONObject(r.Context(), w, &resp) } // DeleteNameserverGroup handles nameserver group deletion request func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } nsGroupID := mux.Vars(r)["nsgroupId"] if len(nsGroupID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w) return } - err = h.accountManager.DeleteNameServerGroup(account.Id, nsGroupID, user.Id) + err = h.accountManager.DeleteNameServerGroup(r.Context(), account.Id, nsGroupID, user.Id) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, emptyObject{}) } // GetNameserverGroup handles a nameserver group Get request identified by ID func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - log.Error(err) + log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } nsGroupID := mux.Vars(r)["nsgroupId"] if len(nsGroupID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w) return } - nsGroup, err := h.accountManager.GetNameServerGroup(account.Id, user.Id, nsGroupID) + nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), account.Id, user.Id, nsGroupID) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } resp := toNameserverGroupResponse(nsGroup) - util.WriteJSONObject(w, &resp) + util.WriteJSONObject(r.Context(), w, &resp) } func toServerNSList(apiNSList []api.Nameserver) ([]nbdns.NameServer, error) { diff --git a/management/server/http/nameservers_handler_test.go b/management/server/http/nameservers_handler_test.go index e1fabb198..28b080571 100644 --- a/management/server/http/nameservers_handler_test.go +++ b/management/server/http/nameservers_handler_test.go @@ -2,6 +2,7 @@ package http import ( "bytes" + "context" "encoding/json" "io" "net/http" @@ -61,13 +62,13 @@ var baseExistingNSGroup = &nbdns.NameServerGroup{ func initNameserversTestData() *NameserversHandler { return &NameserversHandler{ accountManager: &mock_server.MockAccountManager{ - GetNameServerGroupFunc: func(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { + GetNameServerGroupFunc: func(_ context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { if nsGroupID == existingNSGroupID { return baseExistingNSGroup.Copy(), nil } return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID) }, - CreateNameServerGroupFunc: func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, _ string, searchDomains bool) (*nbdns.NameServerGroup, error) { + CreateNameServerGroupFunc: func(_ context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, _ string, searchDomains bool) (*nbdns.NameServerGroup, error) { return &nbdns.NameServerGroup{ ID: existingNSGroupID, Name: name, @@ -80,16 +81,16 @@ func initNameserversTestData() *NameserversHandler { SearchDomainsEnabled: searchDomains, }, nil }, - DeleteNameServerGroupFunc: func(accountID, nsGroupID, _ string) error { + DeleteNameServerGroupFunc: func(_ context.Context, accountID, nsGroupID, _ string) error { return nil }, - SaveNameServerGroupFunc: func(accountID, _ string, nsGroupToSave *nbdns.NameServerGroup) error { + SaveNameServerGroupFunc: func(_ context.Context, accountID, _ string, nsGroupToSave *nbdns.NameServerGroup) error { if nsGroupToSave.ID == existingNSGroupID { return nil } return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID) }, - GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { + GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { return testingNSAccount, testingAccount.Users["test_user"], nil }, }, diff --git a/management/server/http/pat_handler.go b/management/server/http/pat_handler.go index d2398a7e1..9d8448d3d 100644 --- a/management/server/http/pat_handler.go +++ b/management/server/http/pat_handler.go @@ -34,22 +34,22 @@ func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATH // GetAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) userID := vars["userId"] if len(userID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w) return } - pats, err := h.accountManager.GetAllPATs(account.Id, user.Id, userID) + pats, err := h.accountManager.GetAllPATs(r.Context(), account.Id, user.Id, userID) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -58,53 +58,53 @@ func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) { patResponse = append(patResponse, toPATResponse(pat)) } - util.WriteJSONObject(w, patResponse) + util.WriteJSONObject(r.Context(), w, patResponse) } // GetToken is HTTP GET handler that returns a personal access token for the given user func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) targetUserID := vars["userId"] if len(targetUserID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w) return } tokenID := vars["tokenId"] if len(tokenID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid token ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid token ID"), w) return } - pat, err := h.accountManager.GetPAT(account.Id, user.Id, targetUserID, tokenID) + pat, err := h.accountManager.GetPAT(r.Context(), account.Id, user.Id, targetUserID, tokenID) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(w, toPATResponse(pat)) + util.WriteJSONObject(r.Context(), w, toPATResponse(pat)) } // CreateToken is HTTP POST handler that creates a personal access token for the given user func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) targetUserID := vars["userId"] if len(targetUserID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w) return } @@ -115,44 +115,44 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { return } - pat, err := h.accountManager.CreatePAT(account.Id, user.Id, targetUserID, req.Name, req.ExpiresIn) + pat, err := h.accountManager.CreatePAT(r.Context(), account.Id, user.Id, targetUserID, req.Name, req.ExpiresIn) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(w, toPATGeneratedResponse(pat)) + util.WriteJSONObject(r.Context(), w, toPATGeneratedResponse(pat)) } // DeleteToken is HTTP DELETE handler that deletes a personal access token for the given user func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) targetUserID := vars["userId"] if len(targetUserID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w) return } tokenID := vars["tokenId"] if len(tokenID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid token ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid token ID"), w) return } - err = h.accountManager.DeletePAT(account.Id, user.Id, targetUserID, tokenID) + err = h.accountManager.DeletePAT(r.Context(), account.Id, user.Id, targetUserID, tokenID) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, emptyObject{}) } func toPATResponse(pat *server.PersonalAccessToken) *api.PersonalAccessToken { diff --git a/management/server/http/pat_handler_test.go b/management/server/http/pat_handler_test.go index 5058b4110..b72f71468 100644 --- a/management/server/http/pat_handler_test.go +++ b/management/server/http/pat_handler_test.go @@ -2,6 +2,7 @@ package http import ( "bytes" + "context" "encoding/json" "io" "net/http" @@ -63,7 +64,7 @@ var testAccount = &server.Account{ func initPATTestData() *PATHandler { return &PATHandler{ accountManager: &mock_server.MockAccountManager{ - CreatePATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) { + CreatePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) { if accountID != existingAccountID { return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID) } @@ -76,10 +77,10 @@ func initPATTestData() *PATHandler { }, nil }, - GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { + GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { return testAccount, testAccount.Users[existingUserID], nil }, - DeletePATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenID string) error { + DeletePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error { if accountID != existingAccountID { return status.Errorf(status.NotFound, "account with ID %s not found", accountID) } @@ -91,7 +92,7 @@ func initPATTestData() *PATHandler { } return nil }, - GetPATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) { + GetPATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) { if accountID != existingAccountID { return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID) } @@ -103,7 +104,7 @@ func initPATTestData() *PATHandler { } return testAccount.Users[existingUserID].PATs[existingTokenID], nil }, - GetAllPATsFunc: func(accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) { + GetAllPATsFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) { if accountID != existingAccountID { return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID) } diff --git a/management/server/http/peers_handler.go b/management/server/http/peers_handler.go index 762576506..1fb18669c 100644 --- a/management/server/http/peers_handler.go +++ b/management/server/http/peers_handler.go @@ -1,6 +1,7 @@ package http import ( + "context" "encoding/json" "fmt" "net/http" @@ -47,16 +48,16 @@ func (h *PeersHandler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) return peerToReturn, nil } -func (h *PeersHandler) getPeer(account *server.Account, peerID, userID string, w http.ResponseWriter) { - peer, err := h.accountManager.GetPeer(account.Id, peerID, userID) +func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, peerID, userID string, w http.ResponseWriter) { + peer, err := h.accountManager.GetPeer(ctx, account.Id, peerID, userID) if err != nil { - util.WriteError(err, w) + util.WriteError(ctx, err, w) return } peerToReturn, err := h.checkPeerStatus(peer) if err != nil { - util.WriteError(err, w) + util.WriteError(ctx, err, w) return } dnsDomain := h.accountManager.GetDNSDomain() @@ -65,19 +66,19 @@ func (h *PeersHandler) getPeer(account *server.Account, peerID, userID string, w validPeers, err := h.accountManager.GetValidatedPeers(account) if err != nil { - log.Errorf("failed to list appreoved peers: %v", err) - util.WriteError(fmt.Errorf("internal error"), w) + log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err) + util.WriteError(ctx, fmt.Errorf("internal error"), w) return } - netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain(), validPeers) + netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validPeers) accessiblePeers := toAccessiblePeers(netMap, dnsDomain) _, valid := validPeers[peer.ID] - util.WriteJSONObject(w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, accessiblePeers, valid)) + util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, accessiblePeers, valid)) } -func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) { +func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) { req := &api.PeerRequest{} err := json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -99,9 +100,9 @@ func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, pe } } - peer, err := h.accountManager.UpdatePeer(account.Id, user.Id, update) + peer, err := h.accountManager.UpdatePeer(ctx, account.Id, user.Id, update) if err != nil { - util.WriteError(err, w) + util.WriteError(ctx, err, w) return } dnsDomain := h.accountManager.GetDNSDomain() @@ -110,75 +111,75 @@ func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, pe validPeers, err := h.accountManager.GetValidatedPeers(account) if err != nil { - log.Errorf("failed to list appreoved peers: %v", err) - util.WriteError(fmt.Errorf("internal error"), w) + log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err) + util.WriteError(ctx, fmt.Errorf("internal error"), w) return } - netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain(), validPeers) + netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validPeers) accessiblePeers := toAccessiblePeers(netMap, dnsDomain) _, valid := validPeers[peer.ID] - util.WriteJSONObject(w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, accessiblePeers, valid)) + util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, accessiblePeers, valid)) } -func (h *PeersHandler) deletePeer(accountID, userID string, peerID string, w http.ResponseWriter) { - err := h.accountManager.DeletePeer(accountID, peerID, userID) +func (h *PeersHandler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) { + err := h.accountManager.DeletePeer(ctx, accountID, peerID, userID) if err != nil { - log.Errorf("failed to delete peer: %v", err) - util.WriteError(err, w) + log.WithContext(ctx).Errorf("failed to delete peer: %v", err) + util.WriteError(ctx, err, w) return } - util.WriteJSONObject(w, emptyObject{}) + util.WriteJSONObject(ctx, w, emptyObject{}) } // HandlePeer handles all peer requests for GET, PUT and DELETE operations func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) peerID := vars["peerId"] if len(peerID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid peer ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid peer ID"), w) return } switch r.Method { case http.MethodDelete: - h.deletePeer(account.Id, user.Id, peerID, w) + h.deletePeer(r.Context(), account.Id, user.Id, peerID, w) return case http.MethodPut: - h.updatePeer(account, user, peerID, w, r) + h.updatePeer(r.Context(), account, user, peerID, w, r) return case http.MethodGet: - h.getPeer(account, peerID, user.Id, w) + h.getPeer(r.Context(), account, peerID, user.Id, w) return default: - util.WriteError(status.Errorf(status.NotFound, "unknown METHOD"), w) + util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w) } } // GetAllPeers returns a list of all peers associated with a provided account func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { - util.WriteError(status.Errorf(status.NotFound, "unknown METHOD"), w) + util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w) return } claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - peers, err := h.accountManager.GetPeers(account.Id, user.Id) + peers, err := h.accountManager.GetPeers(r.Context(), account.Id, user.Id) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -188,34 +189,34 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { for _, peer := range peers { peerToReturn, err := h.checkPeerStatus(peer) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID) - accessiblePeerNumbers, _ := h.accessiblePeersNumber(account, peer.ID) + accessiblePeerNumbers, _ := h.accessiblePeersNumber(r.Context(), account, peer.ID) respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, accessiblePeerNumbers)) } validPeersMap, err := h.accountManager.GetValidatedPeers(account) if err != nil { - log.Errorf("failed to list appreoved peers: %v", err) - util.WriteError(fmt.Errorf("internal error"), w) + log.WithContext(r.Context()).Errorf("failed to list appreoved peers: %v", err) + util.WriteError(r.Context(), fmt.Errorf("internal error"), w) return } h.setApprovalRequiredFlag(respBody, validPeersMap) - util.WriteJSONObject(w, respBody) + util.WriteJSONObject(r.Context(), w, respBody) } -func (h *PeersHandler) accessiblePeersNumber(account *server.Account, peerID string) (int, error) { +func (h *PeersHandler) accessiblePeersNumber(ctx context.Context, account *server.Account, peerID string) (int, error) { validatedPeersMap, err := h.accountManager.GetValidatedPeers(account) if err != nil { return 0, err } - netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain(), validatedPeersMap) + netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validatedPeersMap) return len(netMap.Peers) + len(netMap.OfflinePeers), nil } diff --git a/management/server/http/peers_handler_test.go b/management/server/http/peers_handler_test.go index 53df5cb00..153c8f03a 100644 --- a/management/server/http/peers_handler_test.go +++ b/management/server/http/peers_handler_test.go @@ -2,6 +2,7 @@ package http import ( "bytes" + "context" "encoding/json" "io" "net" @@ -29,7 +30,7 @@ const noUpdateChannelTestPeerID = "no-update-channel" func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { return &PeersHandler{ accountManager: &mock_server.MockAccountManager{ - UpdatePeerFunc: func(accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { + UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { var p *nbpeer.Peer for _, peer := range peers { if update.ID == peer.ID { @@ -42,7 +43,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { p.Name = update.Name return p, nil }, - GetPeerFunc: func(accountID, peerID, userID string) (*nbpeer.Peer, error) { + GetPeerFunc: func(_ context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) { var p *nbpeer.Peer for _, peer := range peers { if peerID == peer.ID { @@ -52,13 +53,13 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { } return p, nil }, - GetPeersFunc: func(accountID, userID string) ([]*nbpeer.Peer, error) { + GetPeersFunc: func(_ context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { return peers, nil }, GetDNSDomainFunc: func() string { return "netbird.selfhosted" }, - GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { + GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { user := server.NewAdminUser("test_user") return &server.Account{ Id: claims.AccountId, diff --git a/management/server/http/policies_handler.go b/management/server/http/policies_handler.go index e163e63b9..9622668f4 100644 --- a/management/server/http/policies_handler.go +++ b/management/server/http/policies_handler.go @@ -35,15 +35,15 @@ func NewPoliciesHandler(accountManager server.AccountManager, authCfg AuthCfg) * // GetAllPolicies list for the account func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - accountPolicies, err := h.accountManager.ListPolicies(account.Id, user.Id) + accountPolicies, err := h.accountManager.ListPolicies(r.Context(), account.Id, user.Id) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -51,28 +51,28 @@ func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) { for _, policy := range accountPolicies { resp := toPolicyResponse(account, policy) if len(resp.Rules) == 0 { - util.WriteError(status.Errorf(status.Internal, "no rules in the policy"), w) + util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w) return } policies = append(policies, resp) } - util.WriteJSONObject(w, policies) + util.WriteJSONObject(r.Context(), w, policies) } // UpdatePolicy handles update to a policy identified by a given ID func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) policyID := vars["policyId"] if len(policyID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid policy ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w) return } @@ -84,7 +84,7 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) { } } if policyIdx < 0 { - util.WriteError(status.Errorf(status.NotFound, "couldn't find policy id %s", policyID), w) + util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find policy id %s", policyID), w) return } @@ -94,9 +94,9 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) { // CreatePolicy handles policy creation request func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -118,12 +118,12 @@ func (h *Policies) savePolicy( } if req.Name == "" { - util.WriteError(status.Errorf(status.InvalidArgument, "policy name shouldn't be empty"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "policy name shouldn't be empty"), w) return } if len(req.Rules) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "policy rules shouldn't be empty"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "policy rules shouldn't be empty"), w) return } @@ -137,31 +137,31 @@ func (h *Policies) savePolicy( Enabled: req.Enabled, Description: req.Description, } - for _, r := range req.Rules { + for _, rule := range req.Rules { pr := server.PolicyRule{ - ID: policyID, //TODO: when policy can contain multiple rules, need refactor - Name: r.Name, - Destinations: groupMinimumsToStrings(account, r.Destinations), - Sources: groupMinimumsToStrings(account, r.Sources), - Bidirectional: r.Bidirectional, + ID: policyID, // TODO: when policy can contain multiple rules, need refactor + Name: rule.Name, + Destinations: groupMinimumsToStrings(account, rule.Destinations), + Sources: groupMinimumsToStrings(account, rule.Sources), + Bidirectional: rule.Bidirectional, } - pr.Enabled = r.Enabled - if r.Description != nil { - pr.Description = *r.Description + pr.Enabled = rule.Enabled + if rule.Description != nil { + pr.Description = *rule.Description } - switch r.Action { + switch rule.Action { case api.PolicyRuleUpdateActionAccept: pr.Action = server.PolicyTrafficActionAccept case api.PolicyRuleUpdateActionDrop: pr.Action = server.PolicyTrafficActionDrop default: - util.WriteError(status.Errorf(status.InvalidArgument, "unknown action type"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown action type"), w) return } - switch r.Protocol { + switch rule.Protocol { case api.PolicyRuleUpdateProtocolAll: pr.Protocol = server.PolicyRuleProtocolALL case api.PolicyRuleUpdateProtocolTcp: @@ -171,14 +171,14 @@ func (h *Policies) savePolicy( case api.PolicyRuleUpdateProtocolIcmp: pr.Protocol = server.PolicyRuleProtocolICMP default: - util.WriteError(status.Errorf(status.InvalidArgument, "unknown protocol type: %v", r.Protocol), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown protocol type: %v", rule.Protocol), w) return } - if r.Ports != nil && len(*r.Ports) != 0 { - for _, v := range *r.Ports { + if rule.Ports != nil && len(*rule.Ports) != 0 { + for _, v := range *rule.Ports { if port, err := strconv.Atoi(v); err != nil || port < 1 || port > 65535 { - util.WriteError(status.Errorf(status.InvalidArgument, "valid port value is in 1..65535 range"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "valid port value is in 1..65535 range"), w) return } pr.Ports = append(pr.Ports, v) @@ -189,16 +189,16 @@ func (h *Policies) savePolicy( switch pr.Protocol { case server.PolicyRuleProtocolALL, server.PolicyRuleProtocolICMP: if len(pr.Ports) != 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol ports is not allowed"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol ports is not allowed"), w) return } if !pr.Bidirectional { - util.WriteError(status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w) return } case server.PolicyRuleProtocolTCP, server.PolicyRuleProtocolUDP: if !pr.Bidirectional && len(pr.Ports) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w) return } } @@ -210,26 +210,26 @@ func (h *Policies) savePolicy( policy.SourcePostureChecks = sourcePostureChecksToStrings(account, *req.SourcePostureChecks) } - if err := h.accountManager.SavePolicy(account.Id, user.Id, &policy); err != nil { - util.WriteError(err, w) + if err := h.accountManager.SavePolicy(r.Context(), account.Id, user.Id, &policy); err != nil { + util.WriteError(r.Context(), err, w) return } resp := toPolicyResponse(account, &policy) if len(resp.Rules) == 0 { - util.WriteError(status.Errorf(status.Internal, "no rules in the policy"), w) + util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w) return } - util.WriteJSONObject(w, resp) + util.WriteJSONObject(r.Context(), w, resp) } // DeletePolicy handles policy deletion request func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } aID := account.Id @@ -237,24 +237,24 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) policyID := vars["policyId"] if len(policyID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid policy ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w) return } - if err = h.accountManager.DeletePolicy(aID, policyID, user.Id); err != nil { - util.WriteError(err, w) + if err = h.accountManager.DeletePolicy(r.Context(), aID, policyID, user.Id); err != nil { + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, emptyObject{}) } // GetPolicy handles a group Get request identified by ID func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -263,25 +263,25 @@ func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) policyID := vars["policyId"] if len(policyID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid policy ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w) return } - policy, err := h.accountManager.GetPolicy(account.Id, policyID, user.Id) + policy, err := h.accountManager.GetPolicy(r.Context(), account.Id, policyID, user.Id) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } resp := toPolicyResponse(account, policy) if len(resp.Rules) == 0 { - util.WriteError(status.Errorf(status.Internal, "no rules in the policy"), w) + util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w) return } - util.WriteJSONObject(w, resp) + util.WriteJSONObject(r.Context(), w, resp) default: - util.WriteError(status.Errorf(status.NotFound, "method not found"), w) + util.WriteError(r.Context(), status.Errorf(status.NotFound, "method not found"), w) } } diff --git a/management/server/http/policies_handler_test.go b/management/server/http/policies_handler_test.go index 74e682854..06274fb07 100644 --- a/management/server/http/policies_handler_test.go +++ b/management/server/http/policies_handler_test.go @@ -2,6 +2,7 @@ package http import ( "bytes" + "context" "encoding/json" "io" "net/http" @@ -30,21 +31,21 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies { } return &Policies{ accountManager: &mock_server.MockAccountManager{ - GetPolicyFunc: func(_, policyID, _ string) (*server.Policy, error) { + GetPolicyFunc: func(_ context.Context, _, policyID, _ string) (*server.Policy, error) { policy, ok := testPolicies[policyID] if !ok { return nil, status.Errorf(status.NotFound, "policy not found") } return policy, nil }, - SavePolicyFunc: func(_, _ string, policy *server.Policy) error { + SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) error { if !strings.HasPrefix(policy.ID, "id-") { policy.ID = "id-was-set" policy.Rules[0].ID = "id-was-set" } return nil }, - GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { + GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { user := server.NewAdminUser("test_user") return &server.Account{ Id: claims.AccountId, diff --git a/management/server/http/posture_checks_handler.go b/management/server/http/posture_checks_handler.go index 9051e8d18..059cb3b80 100644 --- a/management/server/http/posture_checks_handler.go +++ b/management/server/http/posture_checks_handler.go @@ -37,15 +37,15 @@ func NewPostureChecksHandler(accountManager server.AccountManager, geolocationMa // GetAllPostureChecks list for the account func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) - account, user, err := p.accountManager.GetAccountFromToken(claims) + account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - accountPostureChecks, err := p.accountManager.ListPostureChecks(account.Id, user.Id) + accountPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), account.Id, user.Id) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -54,22 +54,22 @@ func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *htt postureChecks = append(postureChecks, postureCheck.ToAPIResponse()) } - util.WriteJSONObject(w, postureChecks) + util.WriteJSONObject(r.Context(), w, postureChecks) } // UpdatePostureCheck handles update to a posture check identified by a given ID func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) - account, user, err := p.accountManager.GetAccountFromToken(claims) + account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) postureChecksID := vars["postureCheckId"] if len(postureChecksID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w) return } @@ -81,7 +81,7 @@ func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http } } if postureChecksIdx < 0 { - util.WriteError(status.Errorf(status.NotFound, "couldn't find posture checks id %s", postureChecksID), w) + util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find posture checks id %s", postureChecksID), w) return } @@ -91,9 +91,9 @@ func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http // CreatePostureCheck handles posture check creation request func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) - account, user, err := p.accountManager.GetAccountFromToken(claims) + account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -103,50 +103,50 @@ func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http // GetPostureCheck handles a posture check Get request identified by ID func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) - account, user, err := p.accountManager.GetAccountFromToken(claims) + account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) postureChecksID := vars["postureCheckId"] if len(postureChecksID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w) return } - postureChecks, err := p.accountManager.GetPostureChecks(account.Id, postureChecksID, user.Id) + postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), account.Id, postureChecksID, user.Id) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(w, postureChecks.ToAPIResponse()) + util.WriteJSONObject(r.Context(), w, postureChecks.ToAPIResponse()) } // DeletePostureCheck handles posture check deletion request func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) - account, user, err := p.accountManager.GetAccountFromToken(claims) + account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) postureChecksID := vars["postureCheckId"] if len(postureChecksID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w) return } - if err = p.accountManager.DeletePostureChecks(account.Id, postureChecksID, user.Id); err != nil { - util.WriteError(err, w) + if err = p.accountManager.DeletePostureChecks(r.Context(), account.Id, postureChecksID, user.Id); err != nil { + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, emptyObject{}) } // savePostureChecks handles posture checks create and update @@ -169,7 +169,7 @@ func (p *PostureChecksHandler) savePostureChecks( if geoLocationCheck := req.Checks.GeoLocationCheck; geoLocationCheck != nil { if p.geolocationManager == nil { - util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized. "+ + util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized. "+ "Check the self-hosted Geo database documentation at https://docs.netbird.io/selfhosted/geo-support"), w) return } @@ -177,14 +177,14 @@ func (p *PostureChecksHandler) savePostureChecks( postureChecks, err := posture.NewChecksFromAPIPostureCheckUpdate(req, postureChecksID) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - if err := p.accountManager.SavePostureChecks(account.Id, user.Id, postureChecks); err != nil { - util.WriteError(err, w) + if err := p.accountManager.SavePostureChecks(r.Context(), account.Id, user.Id, postureChecks); err != nil { + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(w, postureChecks.ToAPIResponse()) + util.WriteJSONObject(r.Context(), w, postureChecks.ToAPIResponse()) } diff --git a/management/server/http/posture_checks_handler_test.go b/management/server/http/posture_checks_handler_test.go index f0b503f1a..dcb6e4924 100644 --- a/management/server/http/posture_checks_handler_test.go +++ b/management/server/http/posture_checks_handler_test.go @@ -2,6 +2,7 @@ package http import ( "bytes" + "context" "encoding/json" "io" "net/http" @@ -33,14 +34,14 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH return &PostureChecksHandler{ accountManager: &mock_server.MockAccountManager{ - GetPostureChecksFunc: func(accountID, postureChecksID, userID string) (*posture.Checks, error) { + GetPostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) { p, ok := testPostureChecks[postureChecksID] if !ok { return nil, status.Errorf(status.NotFound, "posture checks not found") } return p, nil }, - SavePostureChecksFunc: func(accountID, userID string, postureChecks *posture.Checks) error { + SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) error { postureChecks.ID = "postureCheck" testPostureChecks[postureChecks.ID] = postureChecks @@ -50,7 +51,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH return nil }, - DeletePostureChecksFunc: func(accountID, postureChecksID, userID string) error { + DeletePostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) error { _, ok := testPostureChecks[postureChecksID] if !ok { return status.Errorf(status.NotFound, "posture checks not found") @@ -59,14 +60,14 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH return nil }, - ListPostureChecksFunc: func(accountID, userID string) ([]*posture.Checks, error) { + ListPostureChecksFunc: func(_ context.Context, accountID, userID string) ([]*posture.Checks, error) { accountPostureChecks := make([]*posture.Checks, len(testPostureChecks)) for _, p := range testPostureChecks { accountPostureChecks = append(accountPostureChecks, p) } return accountPostureChecks, nil }, - GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { + GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { user := server.NewAdminUser("test_user") return &server.Account{ Id: claims.AccountId, diff --git a/management/server/http/routes_handler.go b/management/server/http/routes_handler.go index a48c6d61d..18c347334 100644 --- a/management/server/http/routes_handler.go +++ b/management/server/http/routes_handler.go @@ -43,36 +43,36 @@ func NewRoutesHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ro // GetAllRoutes returns the list of routes for the account func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - routes, err := h.accountManager.ListRoutes(account.Id, user.Id) + routes, err := h.accountManager.ListRoutes(r.Context(), account.Id, user.Id) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } apiRoutes := make([]*api.Route, 0) - for _, r := range routes { - route, err := toRouteResponse(r) + for _, route := range routes { + route, err := toRouteResponse(route) if err != nil { - util.WriteError(status.Errorf(status.Internal, failedToConvertRoute, err), w) + util.WriteError(r.Context(), status.Errorf(status.Internal, failedToConvertRoute, err), w) return } apiRoutes = append(apiRoutes, route) } - util.WriteJSONObject(w, apiRoutes) + util.WriteJSONObject(r.Context(), w, apiRoutes) } // CreateRoute handles route creation request func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -84,7 +84,7 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) { } if err := h.validateRoute(req); err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -94,7 +94,7 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) { if req.Domains != nil { d, err := validateDomains(*req.Domains) if err != nil { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w) return } domains = d @@ -102,7 +102,7 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) { } else if req.Network != nil { networkType, newPrefix, err = route.ParseNetwork(*req.Network) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } } @@ -120,24 +120,24 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) { // Do not allow non-Linux peers if peer := account.GetPeer(peerId); peer != nil { if peer.Meta.GoOS != "linux" { - util.WriteError(status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes"), w) return } } - newRoute, err := h.accountManager.CreateRoute(account.Id, newPrefix, networkType, domains, peerId, peerGroupIds, req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id, req.KeepRoute) + newRoute, err := h.accountManager.CreateRoute(r.Context(), account.Id, newPrefix, networkType, domains, peerId, peerGroupIds, req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id, req.KeepRoute) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } routes, err := toRouteResponse(newRoute) if err != nil { - util.WriteError(status.Errorf(status.Internal, failedToConvertRoute, err), w) + util.WriteError(r.Context(), status.Errorf(status.Internal, failedToConvertRoute, err), w) return } - util.WriteJSONObject(w, routes) + util.WriteJSONObject(r.Context(), w, routes) } func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) error { @@ -168,22 +168,22 @@ func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) erro // UpdateRoute handles update to a route identified by a given ID func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) routeID := vars["routeId"] if len(routeID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w) return } - _, err = h.accountManager.GetRoute(account.Id, route.ID(routeID), user.Id) + _, err = h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -195,7 +195,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { } if err := h.validateRoute(req); err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -207,7 +207,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { // do not allow non Linux peers if peer := account.GetPeer(peerID); peer != nil { if peer.Meta.GoOS != "linux" { - util.WriteError(status.Errorf(status.InvalidArgument, "non-linux peers are non supported as network routes"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "non-linux peers are non supported as network routes"), w) return } } @@ -226,7 +226,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { if req.Domains != nil { d, err := validateDomains(*req.Domains) if err != nil { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w) return } newRoute.Domains = d @@ -234,7 +234,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { } else if req.Network != nil { newRoute.NetworkType, newRoute.Network, err = route.ParseNetwork(*req.Network) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } } @@ -247,73 +247,73 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { newRoute.PeerGroups = *req.PeerGroups } - err = h.accountManager.SaveRoute(account.Id, user.Id, newRoute) + err = h.accountManager.SaveRoute(r.Context(), account.Id, user.Id, newRoute) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } routes, err := toRouteResponse(newRoute) if err != nil { - util.WriteError(status.Errorf(status.Internal, failedToConvertRoute, err), w) + util.WriteError(r.Context(), status.Errorf(status.Internal, failedToConvertRoute, err), w) return } - util.WriteJSONObject(w, routes) + util.WriteJSONObject(r.Context(), w, routes) } // DeleteRoute handles route deletion request func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } routeID := mux.Vars(r)["routeId"] if len(routeID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w) return } - err = h.accountManager.DeleteRoute(account.Id, route.ID(routeID), user.Id) + err = h.accountManager.DeleteRoute(r.Context(), account.Id, route.ID(routeID), user.Id) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, emptyObject{}) } // GetRoute handles a route Get request identified by ID func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } routeID := mux.Vars(r)["routeId"] if len(routeID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w) return } - foundRoute, err := h.accountManager.GetRoute(account.Id, route.ID(routeID), user.Id) + foundRoute, err := h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id) if err != nil { - util.WriteError(status.Errorf(status.NotFound, "route not found"), w) + util.WriteError(r.Context(), status.Errorf(status.NotFound, "route not found"), w) return } routes, err := toRouteResponse(foundRoute) if err != nil { - util.WriteError(status.Errorf(status.Internal, failedToConvertRoute, err), w) + util.WriteError(r.Context(), status.Errorf(status.Internal, failedToConvertRoute, err), w) return } - util.WriteJSONObject(w, routes) + util.WriteJSONObject(r.Context(), w, routes) } func toRouteResponse(serverRoute *route.Route) (*api.Route, error) { diff --git a/management/server/http/routes_handler_test.go b/management/server/http/routes_handler_test.go index 261d0c231..40075eb9d 100644 --- a/management/server/http/routes_handler_test.go +++ b/management/server/http/routes_handler_test.go @@ -2,6 +2,7 @@ package http import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -89,7 +90,7 @@ var testingAccount = &server.Account{ func initRoutesTestData() *RoutesHandler { return &RoutesHandler{ accountManager: &mock_server.MockAccountManager{ - GetRouteFunc: func(_ string, routeID route.ID, _ string) (*route.Route, error) { + GetRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) (*route.Route, error) { if routeID == existingRouteID { return baseExistingRoute, nil } @@ -104,7 +105,7 @@ func initRoutesTestData() *RoutesHandler { } return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID) }, - CreateRouteFunc: func(accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, _ string, keepRoute bool) (*route.Route, error) { + CreateRouteFunc: func(_ context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, _ string, keepRoute bool) (*route.Route, error) { if peerID == notFoundPeerID { return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) } @@ -126,19 +127,19 @@ func initRoutesTestData() *RoutesHandler { KeepRoute: keepRoute, }, nil }, - SaveRouteFunc: func(_, _ string, r *route.Route) error { + SaveRouteFunc: func(_ context.Context, _, _ string, r *route.Route) error { if r.Peer == notFoundPeerID { return status.Errorf(status.InvalidArgument, "peer with ID %s not found", r.Peer) } return nil }, - DeleteRouteFunc: func(_ string, routeID route.ID, _ string) error { + DeleteRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) error { if routeID != existingRouteID { return status.Errorf(status.NotFound, "Peer with ID %s not found", routeID) } return nil }, - GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { + GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { return testingAccount, testingAccount.Users["test_user"], nil }, }, diff --git a/management/server/http/setupkeys_handler.go b/management/server/http/setupkeys_handler.go index 5faedea13..8ee7dfaba 100644 --- a/management/server/http/setupkeys_handler.go +++ b/management/server/http/setupkeys_handler.go @@ -1,6 +1,7 @@ package http import ( + "context" "encoding/json" "net/http" "time" @@ -34,9 +35,9 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authCfg AuthCfg) // CreateSetupKey is a POST requests that creates a new SetupKey func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -48,13 +49,13 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request } if req.Name == "" { - util.WriteError(status.Errorf(status.InvalidArgument, "setup key name shouldn't be empty"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key name shouldn't be empty"), w) return } if !(server.SetupKeyType(req.Type) == server.SetupKeyReusable || server.SetupKeyType(req.Type) == server.SetupKeyOneOff) { - util.WriteError(status.Errorf(status.InvalidArgument, "unknown setup key type %s", req.Type), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown setup key type %s", req.Type), w) return } @@ -63,7 +64,7 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request day := time.Hour * 24 year := day * 365 if expiresIn < day || expiresIn > year { - util.WriteError(status.Errorf(status.InvalidArgument, "expiresIn should be between 1 day and 365 days"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "expiresIn should be between 1 day and 365 days"), w) return } @@ -75,54 +76,54 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request if req.Ephemeral != nil { ephemeral = *req.Ephemeral } - setupKey, err := h.accountManager.CreateSetupKey(account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn, + setupKey, err := h.accountManager.CreateSetupKey(r.Context(), account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn, req.AutoGroups, req.UsageLimit, user.Id, ephemeral) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - writeSuccess(w, setupKey) + writeSuccess(r.Context(), w, setupKey) } // GetSetupKey is a GET request to get a SetupKey by ID func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) keyID := vars["keyId"] if len(keyID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid key ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid key ID"), w) return } - key, err := h.accountManager.GetSetupKey(account.Id, user.Id, keyID) + key, err := h.accountManager.GetSetupKey(r.Context(), account.Id, user.Id, keyID) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - writeSuccess(w, key) + writeSuccess(r.Context(), w, key) } // UpdateSetupKey is a PUT request to update server.SetupKey func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) keyID := vars["keyId"] if len(keyID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid key ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid key ID"), w) return } @@ -134,12 +135,12 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request } if req.Name == "" { - util.WriteError(status.Errorf(status.InvalidArgument, "setup key name field is invalid: %s", req.Name), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key name field is invalid: %s", req.Name), w) return } if req.AutoGroups == nil { - util.WriteError(status.Errorf(status.InvalidArgument, "setup key AutoGroups field is invalid"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key AutoGroups field is invalid"), w) return } @@ -149,26 +150,26 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request newKey.Name = req.Name newKey.Id = keyID - newKey, err = h.accountManager.SaveSetupKey(account.Id, newKey, user.Id) + newKey, err = h.accountManager.SaveSetupKey(r.Context(), account.Id, newKey, user.Id) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - writeSuccess(w, newKey) + writeSuccess(r.Context(), w, newKey) } // GetAllSetupKeys is a GET request that returns a list of SetupKey func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - setupKeys, err := h.accountManager.ListSetupKeys(account.Id, user.Id) + setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), account.Id, user.Id) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -177,15 +178,15 @@ func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Reques apiSetupKeys = append(apiSetupKeys, toResponseBody(key)) } - util.WriteJSONObject(w, apiSetupKeys) + util.WriteJSONObject(r.Context(), w, apiSetupKeys) } -func writeSuccess(w http.ResponseWriter, key *server.SetupKey) { +func writeSuccess(ctx context.Context, w http.ResponseWriter, key *server.SetupKey) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(200) err := json.NewEncoder(w).Encode(toResponseBody(key)) if err != nil { - util.WriteError(err, w) + util.WriteError(ctx, err, w) return } } diff --git a/management/server/http/setupkeys_handler_test.go b/management/server/http/setupkeys_handler_test.go index ebbd5954f..bfa0ec008 100644 --- a/management/server/http/setupkeys_handler_test.go +++ b/management/server/http/setupkeys_handler_test.go @@ -2,6 +2,7 @@ package http import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -33,7 +34,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup ) *SetupKeysHandler { return &SetupKeysHandler{ accountManager: &mock_server.MockAccountManager{ - GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { + GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { return &server.Account{ Id: testAccountID, Domain: "hotmail.com", @@ -49,7 +50,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup }, }, user, nil }, - CreateSetupKeyFunc: func(_ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string, + CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string, _ int, _ string, ephemeral bool, ) (*server.SetupKey, error) { if keyName == newKey.Name || typ != newKey.Type { @@ -59,7 +60,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup } return nil, fmt.Errorf("failed creating setup key") }, - GetSetupKeyFunc: func(accountID, userID, keyID string) (*server.SetupKey, error) { + GetSetupKeyFunc: func(_ context.Context, accountID, userID, keyID string) (*server.SetupKey, error) { switch keyID { case defaultKey.Id: return defaultKey, nil @@ -70,14 +71,14 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup } }, - SaveSetupKeyFunc: func(accountID string, key *server.SetupKey, _ string) (*server.SetupKey, error) { + SaveSetupKeyFunc: func(_ context.Context, accountID string, key *server.SetupKey, _ string) (*server.SetupKey, error) { if key.Id == updatedSetupKey.Id { return updatedSetupKey, nil } return nil, status.Errorf(status.NotFound, "key %s not found", key.Id) }, - ListSetupKeysFunc: func(accountID, userID string) ([]*server.SetupKey, error) { + ListSetupKeysFunc: func(_ context.Context, accountID, userID string) ([]*server.SetupKey, error) { return []*server.SetupKey{defaultKey}, nil }, }, diff --git a/management/server/http/users_handler.go b/management/server/http/users_handler.go index 531822668..2c2aed842 100644 --- a/management/server/http/users_handler.go +++ b/management/server/http/users_handler.go @@ -41,22 +41,22 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { } claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) userID := vars["userId"] if len(userID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w) return } existingUser, ok := account.Users[userID] if !ok { - util.WriteError(status.Errorf(status.NotFound, "couldn't find user with ID %s", userID), w) + util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find user with ID %s", userID), w) return } @@ -74,11 +74,11 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { userRole := server.StrRoleToUserRole(req.Role) if userRole == server.UserRoleUnknown { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid user role"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user role"), w) return } - newUser, err := h.accountManager.SaveUser(account.Id, user.Id, &server.User{ + newUser, err := h.accountManager.SaveUser(r.Context(), account.Id, user.Id, &server.User{ Id: userID, Role: userRole, AutoGroups: req.AutoGroups, @@ -88,10 +88,10 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { }) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(w, toUserResponse(newUser, claims.UserId)) + util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId)) } // DeleteUser is a DELETE request to delete a user @@ -102,26 +102,26 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) { } claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) targetUserID := vars["userId"] if len(targetUserID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w) return } - err = h.accountManager.DeleteUser(account.Id, user.Id, targetUserID) + err = h.accountManager.DeleteUser(r.Context(), account.Id, user.Id, targetUserID) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, emptyObject{}) } // CreateUser creates a User in the system with a status "invited" (effectively this is a user invite). @@ -132,9 +132,9 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) { } claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } @@ -146,7 +146,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) { } if server.StrRoleToUserRole(req.Role) == server.UserRoleUnknown { - util.WriteError(status.Errorf(status.InvalidArgument, "unknown user role %s", req.Role), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown user role %s", req.Role), w) return } @@ -160,7 +160,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) { name = *req.Name } - newUser, err := h.accountManager.CreateUser(account.Id, user.Id, &server.UserInfo{ + newUser, err := h.accountManager.CreateUser(r.Context(), account.Id, user.Id, &server.UserInfo{ Email: email, Name: name, Role: req.Role, @@ -169,10 +169,10 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) { Issued: server.UserIssuedAPI, }) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(w, toUserResponse(newUser, claims.UserId)) + util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId)) } // GetAllUsers returns a list of users of the account this user belongs to. @@ -184,42 +184,42 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) { } claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - data, err := h.accountManager.GetUsersFromAccount(account.Id, user.Id) + data, err := h.accountManager.GetUsersFromAccount(r.Context(), account.Id, user.Id) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } serviceUser := r.URL.Query().Get("service_user") users := make([]*api.User, 0) - for _, r := range data { - if r.NonDeletable { + for _, d := range data { + if d.NonDeletable { continue } if serviceUser == "" { - users = append(users, toUserResponse(r, claims.UserId)) + users = append(users, toUserResponse(d, claims.UserId)) continue } includeServiceUser, err := strconv.ParseBool(serviceUser) - log.Debugf("Should include service user: %v", includeServiceUser) + log.WithContext(r.Context()).Debugf("Should include service user: %v", includeServiceUser) if err != nil { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid service_user query parameter"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid service_user query parameter"), w) return } - if includeServiceUser == r.IsServiceUser { - users = append(users, toUserResponse(r, claims.UserId)) + if includeServiceUser == d.IsServiceUser { + users = append(users, toUserResponse(d, claims.UserId)) } } - util.WriteJSONObject(w, users) + util.WriteJSONObject(r.Context(), w, users) } // InviteUser resend invitations to users who haven't activated their accounts, @@ -231,26 +231,26 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) { } claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) targetUserID := vars["userId"] if len(targetUserID) == 0 { - util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w) return } - err = h.accountManager.InviteUser(account.Id, user.Id, targetUserID) + err = h.accountManager.InviteUser(r.Context(), account.Id, user.Id, targetUserID) if err != nil { - util.WriteError(err, w) + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, emptyObject{}) } func toUserResponse(user *server.UserInfo, currenUserID string) *api.User { diff --git a/management/server/http/users_handler_test.go b/management/server/http/users_handler_test.go index 8a78188be..a78ac3a4e 100644 --- a/management/server/http/users_handler_test.go +++ b/management/server/http/users_handler_test.go @@ -2,6 +2,7 @@ package http import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -63,10 +64,10 @@ var usersTestAccount = &server.Account{ func initUsersTestData() *UsersHandler { return &UsersHandler{ accountManager: &mock_server.MockAccountManager{ - GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { + GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { return usersTestAccount, usersTestAccount.Users[claims.UserId], nil }, - GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) { + GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) { users := make([]*server.UserInfo, 0) for _, v := range usersTestAccount.Users { users = append(users, &server.UserInfo{ @@ -81,13 +82,13 @@ func initUsersTestData() *UsersHandler { } return users, nil }, - CreateUserFunc: func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) { + CreateUserFunc: func(_ context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) { if userID != existingUserID { return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID) } return key, nil }, - DeleteUserFunc: func(accountID string, initiatorUserID string, targetUserID string) error { + DeleteUserFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string) error { if targetUserID == notFoundUserID { return status.Errorf(status.NotFound, "user with ID %s does not exists", targetUserID) } @@ -96,7 +97,7 @@ func initUsersTestData() *UsersHandler { } return nil }, - SaveUserFunc: func(accountID, userID string, update *server.User) (*server.UserInfo, error) { + SaveUserFunc: func(_ context.Context, accountID, userID string, update *server.User) (*server.UserInfo, error) { if update.Id == notFoundUserID { return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", update.Id) } @@ -111,7 +112,7 @@ func initUsersTestData() *UsersHandler { } return info, nil }, - InviteUserFunc: func(accountID string, initiatorUserID string, targetUserID string) error { + InviteUserFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string) error { if initiatorUserID != existingUserID { return status.Errorf(status.NotFound, "user with ID %s does not exists", initiatorUserID) } diff --git a/management/server/http/util/util.go b/management/server/http/util/util.go index acaa2838c..603c1c696 100644 --- a/management/server/http/util/util.go +++ b/management/server/http/util/util.go @@ -1,6 +1,7 @@ package util import ( + "context" "encoding/json" "errors" "fmt" @@ -19,12 +20,12 @@ type ErrorResponse struct { } // WriteJSONObject simply writes object to the HTTP response in JSON format -func WriteJSONObject(w http.ResponseWriter, obj interface{}) { +func WriteJSONObject(ctx context.Context, w http.ResponseWriter, obj interface{}) { w.Header().Set("Content-Type", "application/json; charset=UTF-8") w.WriteHeader(http.StatusOK) err := json.NewEncoder(w).Encode(obj) if err != nil { - WriteError(err, w) + WriteError(ctx, err, w) return } } @@ -76,8 +77,8 @@ func WriteErrorResponse(errMsg string, httpStatus int, w http.ResponseWriter) { // WriteError converts an error to an JSON error response. // If it is known internal error of type server.Error then it sets the messages from the error, a generic message otherwise -func WriteError(err error, w http.ResponseWriter) { - log.Errorf("got a handler error: %s", err.Error()) +func WriteError(ctx context.Context, err error, w http.ResponseWriter) { + log.WithContext(ctx).Errorf("got a handler error: %s", err.Error()) errStatus, ok := status.FromError(err) httpStatus := http.StatusInternalServerError msg := "internal server error" @@ -106,7 +107,7 @@ func WriteError(err error, w http.ResponseWriter) { msg = strings.ToLower(err.Error()) } else { unhandledMSG := fmt.Sprintf("got unhandled error code, error: %s", err.Error()) - log.Error(unhandledMSG) + log.WithContext(ctx).Error(unhandledMSG) } WriteErrorResponse(msg, httpStatus, w) diff --git a/management/server/idp/auth0.go b/management/server/idp/auth0.go index 34a5c0de5..497f1944f 100644 --- a/management/server/idp/auth0.go +++ b/management/server/idp/auth0.go @@ -183,7 +183,7 @@ func (c *Auth0Credentials) jwtStillValid() bool { } // requestJWTToken performs request to get jwt token -func (c *Auth0Credentials) requestJWTToken() (*http.Response, error) { +func (c *Auth0Credentials) requestJWTToken(ctx context.Context) (*http.Response, error) { var res *http.Response reqURL := c.clientConfig.AuthIssuer + "/oauth/token" @@ -200,7 +200,7 @@ func (c *Auth0Credentials) requestJWTToken() (*http.Response, error) { req.Header.Add("content-type", "application/json") - log.Debug("requesting new jwt token for idp manager") + log.WithContext(ctx).Debug("requesting new jwt token for idp manager") res, err = c.httpClient.Do(req) if err != nil { @@ -247,7 +247,7 @@ func (c *Auth0Credentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTTo } // Authenticate retrieves access token to use the Auth0 Management API -func (c *Auth0Credentials) Authenticate() (JWTToken, error) { +func (c *Auth0Credentials) Authenticate(ctx context.Context) (JWTToken, error) { c.mux.Lock() defer c.mux.Unlock() @@ -260,14 +260,14 @@ func (c *Auth0Credentials) Authenticate() (JWTToken, error) { return c.jwtToken, nil } - res, err := c.requestJWTToken() + res, err := c.requestJWTToken(ctx) if err != nil { return c.jwtToken, err } defer func() { err = res.Body.Close() if err != nil { - log.Errorf("error while closing get jwt token response body: %v", err) + log.WithContext(ctx).Errorf("error while closing get jwt token response body: %v", err) } }() @@ -301,8 +301,8 @@ func requestByUserIDURL(authIssuer, userID string) string { } // GetAccount returns all the users for a given profile. Calls Auth0 API. -func (am *Auth0Manager) GetAccount(accountID string) ([]*UserData, error) { - jwtToken, err := am.credentials.Authenticate() +func (am *Auth0Manager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) { + jwtToken, err := am.credentials.Authenticate(ctx) if err != nil { return nil, err } @@ -353,7 +353,7 @@ func (am *Auth0Manager) GetAccount(accountID string) ([]*UserData, error) { return nil, err } - log.Debugf("returned user batch for accountID %s on page %d, batch length %d", accountID, page, len(batch)) + log.WithContext(ctx).Debugf("returned user batch for accountID %s on page %d, batch length %d", accountID, page, len(batch)) err = res.Body.Close() if err != nil { @@ -365,7 +365,7 @@ func (am *Auth0Manager) GetAccount(accountID string) ([]*UserData, error) { } if len(batch) == 0 || len(batch) < resultsPerPage { - log.Debugf("finished loading users for accountID %s", accountID) + log.WithContext(ctx).Debugf("finished loading users for accountID %s", accountID) return list, nil } } @@ -374,8 +374,8 @@ func (am *Auth0Manager) GetAccount(accountID string) ([]*UserData, error) { } // GetUserDataByID requests user data from auth0 via ID -func (am *Auth0Manager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) { - jwtToken, err := am.credentials.Authenticate() +func (am *Auth0Manager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) { + jwtToken, err := am.credentials.Authenticate(ctx) if err != nil { return nil, err } @@ -414,7 +414,7 @@ func (am *Auth0Manager) GetUserDataByID(userID string, appMetadata AppMetadata) defer func() { err = res.Body.Close() if err != nil { - log.Errorf("error while closing update user app metadata response body: %v", err) + log.WithContext(ctx).Errorf("error while closing update user app metadata response body: %v", err) } }() @@ -426,9 +426,9 @@ func (am *Auth0Manager) GetUserDataByID(userID string, appMetadata AppMetadata) } // UpdateUserAppMetadata updates user app metadata based on userId and metadata map -func (am *Auth0Manager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error { +func (am *Auth0Manager) UpdateUserAppMetadata(ctx context.Context, userID string, appMetadata AppMetadata) error { - jwtToken, err := am.credentials.Authenticate() + jwtToken, err := am.credentials.Authenticate(ctx) if err != nil { return err } @@ -449,7 +449,7 @@ func (am *Auth0Manager) UpdateUserAppMetadata(userID string, appMetadata AppMeta req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken) req.Header.Add("content-type", "application/json") - log.Debugf("updating IdP metadata for user %s", userID) + log.WithContext(ctx).Debugf("updating IdP metadata for user %s", userID) res, err := am.httpClient.Do(req) if err != nil { @@ -466,7 +466,7 @@ func (am *Auth0Manager) UpdateUserAppMetadata(userID string, appMetadata AppMeta defer func() { err = res.Body.Close() if err != nil { - log.Errorf("error while closing update user app metadata response body: %v", err) + log.WithContext(ctx).Errorf("error while closing update user app metadata response body: %v", err) } }() @@ -530,9 +530,9 @@ func buildUserExportRequest() (string, error) { } func (am *Auth0Manager) createRequest( - method string, endpoint string, body io.Reader, + ctx context.Context, method string, endpoint string, body io.Reader, ) (*http.Request, error) { - jwtToken, err := am.credentials.Authenticate() + jwtToken, err := am.credentials.Authenticate(ctx) if err != nil { return nil, err } @@ -548,8 +548,8 @@ func (am *Auth0Manager) createRequest( return req, nil } -func (am *Auth0Manager) createPostRequest(endpoint string, payloadStr string) (*http.Request, error) { - req, err := am.createRequest("POST", endpoint, strings.NewReader(payloadStr)) +func (am *Auth0Manager) createPostRequest(ctx context.Context, endpoint string, payloadStr string) (*http.Request, error) { + req, err := am.createRequest(ctx, "POST", endpoint, strings.NewReader(payloadStr)) if err != nil { return nil, err } @@ -560,20 +560,20 @@ func (am *Auth0Manager) createPostRequest(endpoint string, payloadStr string) (* // GetAllAccounts gets all registered accounts with corresponding user data. // It returns a list of users indexed by accountID. -func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) { +func (am *Auth0Manager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) { payloadString, err := buildUserExportRequest() if err != nil { return nil, err } - exportJobReq, err := am.createPostRequest("/api/v2/jobs/users-exports", payloadString) + exportJobReq, err := am.createPostRequest(ctx, "/api/v2/jobs/users-exports", payloadString) if err != nil { return nil, err } jobResp, err := am.httpClient.Do(exportJobReq) if err != nil { - log.Debugf("Couldn't get job response %v", err) + log.WithContext(ctx).Debugf("Couldn't get job response %v", err) if am.appMetrics != nil { am.appMetrics.IDPMetrics().CountRequestError() } @@ -583,7 +583,7 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) { defer func() { err = jobResp.Body.Close() if err != nil { - log.Errorf("error while closing update user app metadata response body: %v", err) + log.WithContext(ctx).Errorf("error while closing update user app metadata response body: %v", err) } }() if jobResp.StatusCode != 200 { @@ -597,13 +597,13 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) { body, err := io.ReadAll(jobResp.Body) if err != nil { - log.Debugf("Couldn't read export job response; %v", err) + log.WithContext(ctx).Debugf("Couldn't read export job response; %v", err) return nil, err } err = am.helper.Unmarshal(body, &exportJobResp) if err != nil { - log.Debugf("Couldn't unmarshal export job response; %v", err) + log.WithContext(ctx).Debugf("Couldn't unmarshal export job response; %v", err) return nil, err } @@ -614,16 +614,16 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) { return nil, fmt.Errorf("couldn't get an batch id status %d, %s, response body: %v", jobResp.StatusCode, jobResp.Status, exportJobResp) } - log.Debugf("batch id status %d, %s, response body: %v", jobResp.StatusCode, jobResp.Status, exportJobResp) + log.WithContext(ctx).Debugf("batch id status %d, %s, response body: %v", jobResp.StatusCode, jobResp.Status, exportJobResp) - done, downloadLink, err := am.checkExportJobStatus(exportJobResp.ID) + done, downloadLink, err := am.checkExportJobStatus(ctx, exportJobResp.ID) if err != nil { - log.Debugf("Failed at getting status checks from exportJob; %v", err) + log.WithContext(ctx).Debugf("Failed at getting status checks from exportJob; %v", err) return nil, err } if done { - return am.downloadProfileExport(downloadLink) + return am.downloadProfileExport(ctx, downloadLink) } return nil, fmt.Errorf("failed extracting user profiles from auth0") @@ -632,13 +632,13 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) { // GetUserByEmail searches users with a given email. If no users have been found, this function returns an empty list. // This function can return multiple users. This is due to the Auth0 internals - there could be multiple users with // the same email but different connections that are considered as separate accounts (e.g., Google and username/password). -func (am *Auth0Manager) GetUserByEmail(email string) ([]*UserData, error) { - jwtToken, err := am.credentials.Authenticate() +func (am *Auth0Manager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) { + jwtToken, err := am.credentials.Authenticate(ctx) if err != nil { return nil, err } reqURL := am.authIssuer + "/api/v2/users-by-email?email=" + url.QueryEscape(email) - body, err := doGetReq(am.httpClient, reqURL, jwtToken.AccessToken) + body, err := doGetReq(ctx, am.httpClient, reqURL, jwtToken.AccessToken) if err != nil { return nil, err } @@ -651,7 +651,7 @@ func (am *Auth0Manager) GetUserByEmail(email string) ([]*UserData, error) { err = am.helper.Unmarshal(body, &userResp) if err != nil { - log.Debugf("Couldn't unmarshal export job response; %v", err) + log.WithContext(ctx).Debugf("Couldn't unmarshal export job response; %v", err) return nil, err } @@ -659,13 +659,13 @@ func (am *Auth0Manager) GetUserByEmail(email string) ([]*UserData, error) { } // CreateUser creates a new user in Auth0 Idp and sends an invite -func (am *Auth0Manager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) { +func (am *Auth0Manager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) { payloadString, err := buildCreateUserRequestPayload(email, name, accountID, invitedByEmail) if err != nil { return nil, err } - req, err := am.createPostRequest("/api/v2/users", payloadString) + req, err := am.createPostRequest(ctx, "/api/v2/users", payloadString) if err != nil { return nil, err } @@ -676,7 +676,7 @@ func (am *Auth0Manager) CreateUser(email, name, accountID, invitedByEmail string resp, err := am.httpClient.Do(req) if err != nil { - log.Debugf("Couldn't get job response %v", err) + log.WithContext(ctx).Debugf("Couldn't get job response %v", err) if am.appMetrics != nil { am.appMetrics.IDPMetrics().CountRequestError() } @@ -686,7 +686,7 @@ func (am *Auth0Manager) CreateUser(email, name, accountID, invitedByEmail string defer func() { err = resp.Body.Close() if err != nil { - log.Errorf("error while closing create user response body: %v", err) + log.WithContext(ctx).Errorf("error while closing create user response body: %v", err) } }() if !(resp.StatusCode == 200 || resp.StatusCode == 201) { @@ -700,13 +700,13 @@ func (am *Auth0Manager) CreateUser(email, name, accountID, invitedByEmail string body, err := io.ReadAll(resp.Body) if err != nil { - log.Debugf("Couldn't read export job response; %v", err) + log.WithContext(ctx).Debugf("Couldn't read export job response; %v", err) return nil, err } err = am.helper.Unmarshal(body, &createResp) if err != nil { - log.Debugf("Couldn't unmarshal export job response; %v", err) + log.WithContext(ctx).Debugf("Couldn't unmarshal export job response; %v", err) return nil, err } @@ -714,14 +714,14 @@ func (am *Auth0Manager) CreateUser(email, name, accountID, invitedByEmail string return nil, fmt.Errorf("couldn't create user: response %v", resp) } - log.Debugf("created user %s in account %s", createResp.ID, accountID) + log.WithContext(ctx).Debugf("created user %s in account %s", createResp.ID, accountID) return &createResp, nil } // InviteUserByID resend invitations to users who haven't activated, // their accounts prior to the expiration period. -func (am *Auth0Manager) InviteUserByID(userID string) error { +func (am *Auth0Manager) InviteUserByID(ctx context.Context, userID string) error { userVerificationReq := userVerificationJobRequest{ UserID: userID, } @@ -731,14 +731,14 @@ func (am *Auth0Manager) InviteUserByID(userID string) error { return err } - req, err := am.createPostRequest("/api/v2/jobs/verification-email", string(payload)) + req, err := am.createPostRequest(ctx, "/api/v2/jobs/verification-email", string(payload)) if err != nil { return err } resp, err := am.httpClient.Do(req) if err != nil { - log.Debugf("Couldn't get job response %v", err) + log.WithContext(ctx).Debugf("Couldn't get job response %v", err) if am.appMetrics != nil { am.appMetrics.IDPMetrics().CountRequestError() } @@ -748,7 +748,7 @@ func (am *Auth0Manager) InviteUserByID(userID string) error { defer func() { err = resp.Body.Close() if err != nil { - log.Errorf("error while closing invite user response body: %v", err) + log.WithContext(ctx).Errorf("error while closing invite user response body: %v", err) } }() if !(resp.StatusCode == 200 || resp.StatusCode == 201) { @@ -762,15 +762,15 @@ func (am *Auth0Manager) InviteUserByID(userID string) error { } // DeleteUser from Auth0 -func (am *Auth0Manager) DeleteUser(userID string) error { - req, err := am.createRequest(http.MethodDelete, "/api/v2/users/"+url.QueryEscape(userID), nil) +func (am *Auth0Manager) DeleteUser(ctx context.Context, userID string) error { + req, err := am.createRequest(ctx, http.MethodDelete, "/api/v2/users/"+url.QueryEscape(userID), nil) if err != nil { return err } resp, err := am.httpClient.Do(req) if err != nil { - log.Debugf("execute delete request: %v", err) + log.WithContext(ctx).Debugf("execute delete request: %v", err) if am.appMetrics != nil { am.appMetrics.IDPMetrics().CountRequestError() } @@ -780,7 +780,7 @@ func (am *Auth0Manager) DeleteUser(userID string) error { defer func() { err = resp.Body.Close() if err != nil { - log.Errorf("close delete request body: %v", err) + log.WithContext(ctx).Errorf("close delete request body: %v", err) } }() if resp.StatusCode != 204 { @@ -795,20 +795,20 @@ func (am *Auth0Manager) DeleteUser(userID string) error { // GetAllConnections returns detailed list of all connections filtered by given params. // Note this method is not part of the IDP Manager interface as this is Auth0 specific. -func (am *Auth0Manager) GetAllConnections(strategy []string) ([]Connection, error) { +func (am *Auth0Manager) GetAllConnections(ctx context.Context, strategy []string) ([]Connection, error) { var connections []Connection q := make(url.Values) q.Set("strategy", strings.Join(strategy, ",")) - req, err := am.createRequest(http.MethodGet, "/api/v2/connections?"+q.Encode(), nil) + req, err := am.createRequest(ctx, http.MethodGet, "/api/v2/connections?"+q.Encode(), nil) if err != nil { return connections, err } resp, err := am.httpClient.Do(req) if err != nil { - log.Debugf("execute get connections request: %v", err) + log.WithContext(ctx).Debugf("execute get connections request: %v", err) if am.appMetrics != nil { am.appMetrics.IDPMetrics().CountRequestError() } @@ -818,7 +818,7 @@ func (am *Auth0Manager) GetAllConnections(strategy []string) ([]Connection, erro defer func() { err = resp.Body.Close() if err != nil { - log.Errorf("close get connections request body: %v", err) + log.WithContext(ctx).Errorf("close get connections request body: %v", err) } }() if resp.StatusCode != 200 { @@ -830,13 +830,13 @@ func (am *Auth0Manager) GetAllConnections(strategy []string) ([]Connection, erro body, err := io.ReadAll(resp.Body) if err != nil { - log.Debugf("Couldn't read get connections response; %v", err) + log.WithContext(ctx).Debugf("Couldn't read get connections response; %v", err) return connections, err } err = am.helper.Unmarshal(body, &connections) if err != nil { - log.Debugf("Couldn't unmarshal get connection response; %v", err) + log.WithContext(ctx).Debugf("Couldn't unmarshal get connection response; %v", err) return connections, err } @@ -845,23 +845,23 @@ func (am *Auth0Manager) GetAllConnections(strategy []string) ([]Connection, erro // checkExportJobStatus checks the status of the job created at CreateExportUsersJob. // If the status is "completed", then return the downloadLink -func (am *Auth0Manager) checkExportJobStatus(jobID string) (bool, string, error) { - ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second) +func (am *Auth0Manager) checkExportJobStatus(ctx context.Context, jobID string) (bool, string, error) { + ctx, cancel := context.WithTimeout(ctx, 90*time.Second) defer cancel() retry := time.NewTicker(10 * time.Second) for { select { case <-ctx.Done(): - log.Debugf("Export job status stopped...\n") + log.WithContext(ctx).Debugf("Export job status stopped...\n") return false, "", ctx.Err() case <-retry.C: - jwtToken, err := am.credentials.Authenticate() + jwtToken, err := am.credentials.Authenticate(ctx) if err != nil { return false, "", err } statusURL := am.authIssuer + "/api/v2/jobs/" + jobID - body, err := doGetReq(am.httpClient, statusURL, jwtToken.AccessToken) + body, err := doGetReq(ctx, am.httpClient, statusURL, jwtToken.AccessToken) if err != nil { return false, "", err } @@ -872,7 +872,7 @@ func (am *Auth0Manager) checkExportJobStatus(jobID string) (bool, string, error) return false, "", err } - log.Debugf("current export job status is %v", status.Status) + log.WithContext(ctx).Debugf("current export job status is %v", status.Status) if status.Status != "completed" { continue @@ -884,8 +884,8 @@ func (am *Auth0Manager) checkExportJobStatus(jobID string) (bool, string, error) } // downloadProfileExport downloads user profiles from auth0 batch job -func (am *Auth0Manager) downloadProfileExport(location string) (map[string][]*UserData, error) { - body, err := doGetReq(am.httpClient, location, "") +func (am *Auth0Manager) downloadProfileExport(ctx context.Context, location string) (map[string][]*UserData, error) { + body, err := doGetReq(ctx, am.httpClient, location, "") if err != nil { return nil, err } @@ -927,7 +927,7 @@ func (am *Auth0Manager) downloadProfileExport(location string) (map[string][]*Us } // Boilerplate implementation for Get Requests. -func doGetReq(client ManagerHTTPClient, url, accessToken string) ([]byte, error) { +func doGetReq(ctx context.Context, client ManagerHTTPClient, url, accessToken string) ([]byte, error) { req, err := http.NewRequest("GET", url, nil) if err != nil { return nil, err @@ -945,7 +945,7 @@ func doGetReq(client ManagerHTTPClient, url, accessToken string) ([]byte, error) defer func() { err = res.Body.Close() if err != nil { - log.Errorf("error while closing body for url %s: %v", url, err) + log.WithContext(ctx).Errorf("error while closing body for url %s: %v", url, err) } }() body, err := io.ReadAll(res.Body) diff --git a/management/server/idp/auth0_test.go b/management/server/idp/auth0_test.go index febc0ab89..de42ced99 100644 --- a/management/server/idp/auth0_test.go +++ b/management/server/idp/auth0_test.go @@ -1,6 +1,7 @@ package idp import ( + "context" "encoding/json" "fmt" "io" @@ -60,7 +61,7 @@ type mockAuth0Credentials struct { err error } -func (mc *mockAuth0Credentials) Authenticate() (JWTToken, error) { +func (mc *mockAuth0Credentials) Authenticate(_ context.Context) (JWTToken, error) { return mc.jwtToken, mc.err } @@ -126,7 +127,7 @@ func TestAuth0_RequestJWTToken(t *testing.T) { helper: testCase.helper, } - res, err := creds.requestJWTToken() + res, err := creds.requestJWTToken(context.Background()) if err != nil { if testCase.expectedFuncExitErrDiff != nil { assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same") @@ -295,7 +296,7 @@ func TestAuth0_Authenticate(t *testing.T) { creds.jwtToken.expiresInTime = testCase.inputExpireToken - _, err := creds.Authenticate() + _, err := creds.Authenticate(context.Background()) if err != nil { if testCase.expectedFuncExitErrDiff != nil { assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same") @@ -417,7 +418,7 @@ func TestAuth0_UpdateUserAppMetadata(t *testing.T) { helper: testCase.helper, } - err := manager.UpdateUserAppMetadata("1", testCase.appMetadata) + err := manager.UpdateUserAppMetadata(context.Background(), "1", testCase.appMetadata) testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage) assert.Equal(t, testCase.expectedReqBody, jwtReqClient.reqBody, "request body should match") diff --git a/management/server/idp/authentik.go b/management/server/idp/authentik.go index b39f2b5cb..00d30d645 100644 --- a/management/server/idp/authentik.go +++ b/management/server/idp/authentik.go @@ -116,7 +116,7 @@ func (ac *AuthentikCredentials) jwtStillValid() bool { } // requestJWTToken performs request to get jwt token. -func (ac *AuthentikCredentials) requestJWTToken() (*http.Response, error) { +func (ac *AuthentikCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) { data := url.Values{} data.Set("client_id", ac.clientConfig.ClientID) data.Set("username", ac.clientConfig.Username) @@ -131,7 +131,7 @@ func (ac *AuthentikCredentials) requestJWTToken() (*http.Response, error) { } req.Header.Add("content-type", "application/x-www-form-urlencoded") - log.Debug("requesting new jwt token for authentik idp manager") + log.WithContext(ctx).Debug("requesting new jwt token for authentik idp manager") resp, err := ac.httpClient.Do(req) if err != nil { @@ -183,7 +183,7 @@ func (ac *AuthentikCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) ( } // Authenticate retrieves access token to use the authentik management API. -func (ac *AuthentikCredentials) Authenticate() (JWTToken, error) { +func (ac *AuthentikCredentials) Authenticate(ctx context.Context) (JWTToken, error) { ac.mux.Lock() defer ac.mux.Unlock() @@ -197,7 +197,7 @@ func (ac *AuthentikCredentials) Authenticate() (JWTToken, error) { return ac.jwtToken, nil } - resp, err := ac.requestJWTToken() + resp, err := ac.requestJWTToken(ctx) if err != nil { return ac.jwtToken, err } @@ -214,13 +214,13 @@ func (ac *AuthentikCredentials) Authenticate() (JWTToken, error) { } // UpdateUserAppMetadata updates user app metadata based on userID and metadata map. -func (am *AuthentikManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error { +func (am *AuthentikManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error { return nil } // GetUserDataByID requests user data from authentik via ID. -func (am *AuthentikManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) { - ctx, err := am.authenticationContext() +func (am *AuthentikManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) { + ctx, err := am.authenticationContext(ctx) if err != nil { return nil, err } @@ -254,8 +254,8 @@ func (am *AuthentikManager) GetUserDataByID(userID string, appMetadata AppMetada } // GetAccount returns all the users for a given profile. -func (am *AuthentikManager) GetAccount(accountID string) ([]*UserData, error) { - users, err := am.getAllUsers() +func (am *AuthentikManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) { + users, err := am.getAllUsers(ctx) if err != nil { return nil, err } @@ -274,8 +274,8 @@ func (am *AuthentikManager) GetAccount(accountID string) ([]*UserData, error) { // GetAllAccounts gets all registered accounts with corresponding user data. // It returns a list of users indexed by accountID. -func (am *AuthentikManager) GetAllAccounts() (map[string][]*UserData, error) { - users, err := am.getAllUsers() +func (am *AuthentikManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) { + users, err := am.getAllUsers(ctx) if err != nil { return nil, err } @@ -291,12 +291,12 @@ func (am *AuthentikManager) GetAllAccounts() (map[string][]*UserData, error) { } // getAllUsers returns all users in a Authentik account. -func (am *AuthentikManager) getAllUsers() ([]*UserData, error) { +func (am *AuthentikManager) getAllUsers(ctx context.Context) ([]*UserData, error) { users := make([]*UserData, 0) page := int32(1) for { - ctx, err := am.authenticationContext() + ctx, err := am.authenticationContext(ctx) if err != nil { return nil, err } @@ -329,14 +329,14 @@ func (am *AuthentikManager) getAllUsers() ([]*UserData, error) { } // CreateUser creates a new user in authentik Idp and sends an invitation. -func (am *AuthentikManager) CreateUser(_, _, _, _ string) (*UserData, error) { +func (am *AuthentikManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) { return nil, fmt.Errorf("method CreateUser not implemented") } // GetUserByEmail searches users with a given email. // If no users have been found, this function returns an empty list. -func (am *AuthentikManager) GetUserByEmail(email string) ([]*UserData, error) { - ctx, err := am.authenticationContext() +func (am *AuthentikManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) { + ctx, err := am.authenticationContext(ctx) if err != nil { return nil, err } @@ -368,13 +368,13 @@ func (am *AuthentikManager) GetUserByEmail(email string) ([]*UserData, error) { // InviteUserByID resend invitations to users who haven't activated, // their accounts prior to the expiration period. -func (am *AuthentikManager) InviteUserByID(_ string) error { +func (am *AuthentikManager) InviteUserByID(_ context.Context, _ string) error { return fmt.Errorf("method InviteUserByID not implemented") } // DeleteUser from Authentik -func (am *AuthentikManager) DeleteUser(userID string) error { - ctx, err := am.authenticationContext() +func (am *AuthentikManager) DeleteUser(ctx context.Context, userID string) error { + ctx, err := am.authenticationContext(ctx) if err != nil { return err } @@ -404,8 +404,8 @@ func (am *AuthentikManager) DeleteUser(userID string) error { return nil } -func (am *AuthentikManager) authenticationContext() (context.Context, error) { - jwtToken, err := am.credentials.Authenticate() +func (am *AuthentikManager) authenticationContext(ctx context.Context) (context.Context, error) { + jwtToken, err := am.credentials.Authenticate(ctx) if err != nil { return nil, err } diff --git a/management/server/idp/authentik_test.go b/management/server/idp/authentik_test.go index 342e16384..029acdce3 100644 --- a/management/server/idp/authentik_test.go +++ b/management/server/idp/authentik_test.go @@ -1,6 +1,7 @@ package idp import ( + "context" "fmt" "io" "strings" @@ -138,7 +139,7 @@ func TestAuthentikRequestJWTToken(t *testing.T) { helper: testCase.helper, } - resp, err := creds.requestJWTToken() + resp, err := creds.requestJWTToken(context.Background()) if err != nil { if testCase.expectedFuncExitErrDiff != nil { assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same") @@ -304,7 +305,7 @@ func TestAuthentikAuthenticate(t *testing.T) { } creds.jwtToken.expiresInTime = testCase.inputExpireToken - _, err := creds.Authenticate() + _, err := creds.Authenticate(context.Background()) if err != nil { if testCase.expectedFuncExitErrDiff != nil { assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same") diff --git a/management/server/idp/azure.go b/management/server/idp/azure.go index 2f21b3b54..35b86764d 100644 --- a/management/server/idp/azure.go +++ b/management/server/idp/azure.go @@ -1,6 +1,7 @@ package idp import ( + "context" "fmt" "io" "net/http" @@ -110,7 +111,7 @@ func (ac *AzureCredentials) jwtStillValid() bool { } // requestJWTToken performs request to get jwt token. -func (ac *AzureCredentials) requestJWTToken() (*http.Response, error) { +func (ac *AzureCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) { data := url.Values{} data.Set("client_id", ac.clientConfig.ClientID) data.Set("client_secret", ac.clientConfig.ClientSecret) @@ -132,7 +133,7 @@ func (ac *AzureCredentials) requestJWTToken() (*http.Response, error) { } req.Header.Add("content-type", "application/x-www-form-urlencoded") - log.Debug("requesting new jwt token for azure idp manager") + log.WithContext(ctx).Debug("requesting new jwt token for azure idp manager") resp, err := ac.httpClient.Do(req) if err != nil { @@ -184,7 +185,7 @@ func (ac *AzureCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTT } // Authenticate retrieves access token to use the azure Management API. -func (ac *AzureCredentials) Authenticate() (JWTToken, error) { +func (ac *AzureCredentials) Authenticate(ctx context.Context) (JWTToken, error) { ac.mux.Lock() defer ac.mux.Unlock() @@ -198,7 +199,7 @@ func (ac *AzureCredentials) Authenticate() (JWTToken, error) { return ac.jwtToken, nil } - resp, err := ac.requestJWTToken() + resp, err := ac.requestJWTToken(ctx) if err != nil { return ac.jwtToken, err } @@ -215,16 +216,16 @@ func (ac *AzureCredentials) Authenticate() (JWTToken, error) { } // CreateUser creates a new user in azure AD Idp. -func (am *AzureManager) CreateUser(_, _, _, _ string) (*UserData, error) { +func (am *AzureManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) { return nil, fmt.Errorf("method CreateUser not implemented") } // GetUserDataByID requests user data from keycloak via ID. -func (am *AzureManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) { +func (am *AzureManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) { q := url.Values{} q.Add("$select", profileFields) - body, err := am.get("users/"+userID, q) + body, err := am.get(ctx, "users/"+userID, q) if err != nil { return nil, err } @@ -247,11 +248,11 @@ func (am *AzureManager) GetUserDataByID(userID string, appMetadata AppMetadata) // GetUserByEmail searches users with a given email. // If no users have been found, this function returns an empty list. -func (am *AzureManager) GetUserByEmail(email string) ([]*UserData, error) { +func (am *AzureManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) { q := url.Values{} q.Add("$select", profileFields) - body, err := am.get("users/"+email, q) + body, err := am.get(ctx, "users/"+email, q) if err != nil { return nil, err } @@ -273,8 +274,8 @@ func (am *AzureManager) GetUserByEmail(email string) ([]*UserData, error) { } // GetAccount returns all the users for a given profile. -func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) { - users, err := am.getAllUsers() +func (am *AzureManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) { + users, err := am.getAllUsers(ctx) if err != nil { return nil, err } @@ -293,8 +294,8 @@ func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) { // GetAllAccounts gets all registered accounts with corresponding user data. // It returns a list of users indexed by accountID. -func (am *AzureManager) GetAllAccounts() (map[string][]*UserData, error) { - users, err := am.getAllUsers() +func (am *AzureManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) { + users, err := am.getAllUsers(ctx) if err != nil { return nil, err } @@ -310,19 +311,19 @@ func (am *AzureManager) GetAllAccounts() (map[string][]*UserData, error) { } // UpdateUserAppMetadata updates user app metadata based on userID. -func (am *AzureManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error { +func (am *AzureManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error { return nil } // InviteUserByID resend invitations to users who haven't activated, // their accounts prior to the expiration period. -func (am *AzureManager) InviteUserByID(_ string) error { +func (am *AzureManager) InviteUserByID(_ context.Context, _ string) error { return fmt.Errorf("method InviteUserByID not implemented") } // DeleteUser from Azure. -func (am *AzureManager) DeleteUser(userID string) error { - jwtToken, err := am.credentials.Authenticate() +func (am *AzureManager) DeleteUser(ctx context.Context, userID string) error { + jwtToken, err := am.credentials.Authenticate(ctx) if err != nil { return err } @@ -335,7 +336,7 @@ func (am *AzureManager) DeleteUser(userID string) error { req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken) req.Header.Add("content-type", "application/json") - log.Debugf("delete idp user %s", userID) + log.WithContext(ctx).Debugf("delete idp user %s", userID) resp, err := am.httpClient.Do(req) if err != nil { @@ -358,7 +359,7 @@ func (am *AzureManager) DeleteUser(userID string) error { } // getAllUsers returns all users in an Azure AD account. -func (am *AzureManager) getAllUsers() ([]*UserData, error) { +func (am *AzureManager) getAllUsers(ctx context.Context) ([]*UserData, error) { users := make([]*UserData, 0) q := url.Values{} @@ -366,7 +367,7 @@ func (am *AzureManager) getAllUsers() ([]*UserData, error) { q.Add("$top", "500") for nextLink := "users"; nextLink != ""; { - body, err := am.get(nextLink, q) + body, err := am.get(ctx, nextLink, q) if err != nil { return nil, err } @@ -391,8 +392,8 @@ func (am *AzureManager) getAllUsers() ([]*UserData, error) { } // get perform Get requests. -func (am *AzureManager) get(resource string, q url.Values) ([]byte, error) { - jwtToken, err := am.credentials.Authenticate() +func (am *AzureManager) get(ctx context.Context, resource string, q url.Values) ([]byte, error) { + jwtToken, err := am.credentials.Authenticate(ctx) if err != nil { return nil, err } diff --git a/management/server/idp/azure_test.go b/management/server/idp/azure_test.go index b4dc96b23..80e85d2b1 100644 --- a/management/server/idp/azure_test.go +++ b/management/server/idp/azure_test.go @@ -1,6 +1,7 @@ package idp import ( + "context" "fmt" "testing" "time" @@ -101,7 +102,7 @@ func TestAzureAuthenticate(t *testing.T) { } creds.jwtToken.expiresInTime = testCase.inputExpireToken - _, err := creds.Authenticate() + _, err := creds.Authenticate(context.Background()) if err != nil { if testCase.expectedFuncExitErrDiff != nil { assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same") diff --git a/management/server/idp/google_workspace.go b/management/server/idp/google_workspace.go index 896fb707b..09ea8c430 100644 --- a/management/server/idp/google_workspace.go +++ b/management/server/idp/google_workspace.go @@ -39,12 +39,12 @@ type GoogleWorkspaceCredentials struct { appMetrics telemetry.AppMetrics } -func (gc *GoogleWorkspaceCredentials) Authenticate() (JWTToken, error) { +func (gc *GoogleWorkspaceCredentials) Authenticate(_ context.Context) (JWTToken, error) { return JWTToken{}, nil } // NewGoogleWorkspaceManager creates a new instance of the GoogleWorkspaceManager. -func NewGoogleWorkspaceManager(config GoogleWorkspaceClientConfig, appMetrics telemetry.AppMetrics) (*GoogleWorkspaceManager, error) { +func NewGoogleWorkspaceManager(ctx context.Context, config GoogleWorkspaceClientConfig, appMetrics telemetry.AppMetrics) (*GoogleWorkspaceManager, error) { httpTransport := http.DefaultTransport.(*http.Transport).Clone() httpTransport.MaxIdleConns = 5 @@ -66,7 +66,7 @@ func NewGoogleWorkspaceManager(config GoogleWorkspaceClientConfig, appMetrics te } // Create a new Admin SDK Directory service client - adminCredentials, err := getGoogleCredentials(config.ServiceAccountKey) + adminCredentials, err := getGoogleCredentials(ctx, config.ServiceAccountKey) if err != nil { return nil, err } @@ -90,12 +90,12 @@ func NewGoogleWorkspaceManager(config GoogleWorkspaceClientConfig, appMetrics te } // UpdateUserAppMetadata updates user app metadata based on userID and metadata map. -func (gm *GoogleWorkspaceManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error { +func (gm *GoogleWorkspaceManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error { return nil } // GetUserDataByID requests user data from Google Workspace via ID. -func (gm *GoogleWorkspaceManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) { +func (gm *GoogleWorkspaceManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) { user, err := gm.usersService.Get(userID).Do() if err != nil { return nil, err @@ -112,7 +112,7 @@ func (gm *GoogleWorkspaceManager) GetUserDataByID(userID string, appMetadata App } // GetAccount returns all the users for a given profile. -func (gm *GoogleWorkspaceManager) GetAccount(accountID string) ([]*UserData, error) { +func (gm *GoogleWorkspaceManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) { users, err := gm.getAllUsers() if err != nil { return nil, err @@ -132,7 +132,7 @@ func (gm *GoogleWorkspaceManager) GetAccount(accountID string) ([]*UserData, err // GetAllAccounts gets all registered accounts with corresponding user data. // It returns a list of users indexed by accountID. -func (gm *GoogleWorkspaceManager) GetAllAccounts() (map[string][]*UserData, error) { +func (gm *GoogleWorkspaceManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) { users, err := gm.getAllUsers() if err != nil { return nil, err @@ -177,13 +177,13 @@ func (gm *GoogleWorkspaceManager) getAllUsers() ([]*UserData, error) { } // CreateUser creates a new user in Google Workspace and sends an invitation. -func (gm *GoogleWorkspaceManager) CreateUser(_, _, _, _ string) (*UserData, error) { +func (gm *GoogleWorkspaceManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) { return nil, fmt.Errorf("method CreateUser not implemented") } // GetUserByEmail searches users with a given email. // If no users have been found, this function returns an empty list. -func (gm *GoogleWorkspaceManager) GetUserByEmail(email string) ([]*UserData, error) { +func (gm *GoogleWorkspaceManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) { user, err := gm.usersService.Get(email).Do() if err != nil { return nil, err @@ -201,12 +201,12 @@ func (gm *GoogleWorkspaceManager) GetUserByEmail(email string) ([]*UserData, err // InviteUserByID resend invitations to users who haven't activated, // their accounts prior to the expiration period. -func (gm *GoogleWorkspaceManager) InviteUserByID(_ string) error { +func (gm *GoogleWorkspaceManager) InviteUserByID(_ context.Context, _ string) error { return fmt.Errorf("method InviteUserByID not implemented") } // DeleteUser from GoogleWorkspace. -func (gm *GoogleWorkspaceManager) DeleteUser(userID string) error { +func (gm *GoogleWorkspaceManager) DeleteUser(_ context.Context, userID string) error { if err := gm.usersService.Delete(userID).Do(); err != nil { return err } @@ -222,8 +222,8 @@ func (gm *GoogleWorkspaceManager) DeleteUser(userID string) error { // It decodes the base64-encoded serviceAccountKey and attempts to obtain credentials using it. // If that fails, it falls back to using the default Google credentials path. // It returns the retrieved credentials or an error if unsuccessful. -func getGoogleCredentials(serviceAccountKey string) (*google.Credentials, error) { - log.Debug("retrieving google credentials from the base64 encoded service account key") +func getGoogleCredentials(ctx context.Context, serviceAccountKey string) (*google.Credentials, error) { + log.WithContext(ctx).Debug("retrieving google credentials from the base64 encoded service account key") decodeKey, err := base64.StdEncoding.DecodeString(serviceAccountKey) if err != nil { return nil, fmt.Errorf("failed to decode service account key: %w", err) @@ -239,8 +239,8 @@ func getGoogleCredentials(serviceAccountKey string) (*google.Credentials, error) return creds, nil } - log.Debugf("failed to retrieve Google credentials from ServiceAccountKey: %v", err) - log.Debug("falling back to default google credentials location") + log.WithContext(ctx).Debugf("failed to retrieve Google credentials from ServiceAccountKey: %v", err) + log.WithContext(ctx).Debug("falling back to default google credentials location") creds, err = google.FindDefaultCredentials( context.Background(), diff --git a/management/server/idp/idp.go b/management/server/idp/idp.go index 7adb76f40..419220942 100644 --- a/management/server/idp/idp.go +++ b/management/server/idp/idp.go @@ -1,6 +1,7 @@ package idp import ( + "context" "fmt" "net/http" "strings" @@ -16,14 +17,14 @@ const ( // Manager idp manager interface type Manager interface { - UpdateUserAppMetadata(userId string, appMetadata AppMetadata) error - GetUserDataByID(userId string, appMetadata AppMetadata) (*UserData, error) - GetAccount(accountId string) ([]*UserData, error) - GetAllAccounts() (map[string][]*UserData, error) - CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) - GetUserByEmail(email string) ([]*UserData, error) - InviteUserByID(userID string) error - DeleteUser(userID string) error + UpdateUserAppMetadata(ctx context.Context, userId string, appMetadata AppMetadata) error + GetUserDataByID(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error) + GetAccount(ctx context.Context, accountId string) ([]*UserData, error) + GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) + CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) + GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) + InviteUserByID(ctx context.Context, userID string) error + DeleteUser(ctx context.Context, userID string) error } // ClientConfig defines common client configuration for all IdP manager @@ -51,7 +52,7 @@ type Config struct { // ManagerCredentials interface that authenticates using the credential of each type of idp type ManagerCredentials interface { - Authenticate() (JWTToken, error) + Authenticate(ctx context.Context) (JWTToken, error) } // ManagerHTTPClient http client interface for API calls @@ -91,7 +92,7 @@ type JWTToken struct { } // NewManager returns a new idp manager based on the configuration that it receives -func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error) { +func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetrics) (Manager, error) { if config.ClientConfig != nil { config.ClientConfig.Issuer = strings.TrimSuffix(config.ClientConfig.Issuer, "/") } @@ -175,7 +176,7 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error) ServiceAccountKey: config.ExtraConfig["ServiceAccountKey"], CustomerID: config.ExtraConfig["CustomerId"], } - return NewGoogleWorkspaceManager(googleClientConfig, appMetrics) + return NewGoogleWorkspaceManager(ctx, googleClientConfig, appMetrics) case "jumpcloud": jumpcloudConfig := JumpCloudClientConfig{ APIToken: config.ExtraConfig["ApiToken"], diff --git a/management/server/idp/jumpcloud.go b/management/server/idp/jumpcloud.go index 0115b4049..6345e424a 100644 --- a/management/server/idp/jumpcloud.go +++ b/management/server/idp/jumpcloud.go @@ -74,7 +74,7 @@ func NewJumpCloudManager(config JumpCloudClientConfig, appMetrics telemetry.AppM } // Authenticate retrieves access token to use the JumpCloud user API. -func (jc *JumpCloudCredentials) Authenticate() (JWTToken, error) { +func (jc *JumpCloudCredentials) Authenticate(_ context.Context) (JWTToken, error) { return JWTToken{}, nil } @@ -85,12 +85,12 @@ func (jm *JumpCloudManager) authenticationContext() context.Context { } // UpdateUserAppMetadata updates user app metadata based on userID and metadata map. -func (jm *JumpCloudManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error { +func (jm *JumpCloudManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error { return nil } // GetUserDataByID requests user data from JumpCloud via ID. -func (jm *JumpCloudManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) { +func (jm *JumpCloudManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) { authCtx := jm.authenticationContext() user, resp, err := jm.client.SystemusersApi.SystemusersGet(authCtx, userID, contentType, accept, nil) if err != nil { @@ -116,7 +116,7 @@ func (jm *JumpCloudManager) GetUserDataByID(userID string, appMetadata AppMetada } // GetAccount returns all the users for a given profile. -func (jm *JumpCloudManager) GetAccount(accountID string) ([]*UserData, error) { +func (jm *JumpCloudManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) { authCtx := jm.authenticationContext() userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil) if err != nil { @@ -148,7 +148,7 @@ func (jm *JumpCloudManager) GetAccount(accountID string) ([]*UserData, error) { // GetAllAccounts gets all registered accounts with corresponding user data. // It returns a list of users indexed by accountID. -func (jm *JumpCloudManager) GetAllAccounts() (map[string][]*UserData, error) { +func (jm *JumpCloudManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) { authCtx := jm.authenticationContext() userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil) if err != nil { @@ -177,13 +177,13 @@ func (jm *JumpCloudManager) GetAllAccounts() (map[string][]*UserData, error) { } // CreateUser creates a new user in JumpCloud Idp and sends an invitation. -func (jm *JumpCloudManager) CreateUser(_, _, _, _ string) (*UserData, error) { +func (jm *JumpCloudManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) { return nil, fmt.Errorf("method CreateUser not implemented") } // GetUserByEmail searches users with a given email. // If no users have been found, this function returns an empty list. -func (jm *JumpCloudManager) GetUserByEmail(email string) ([]*UserData, error) { +func (jm *JumpCloudManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) { searchFilter := map[string]interface{}{ "searchFilter": map[string]interface{}{ "filter": []string{email}, @@ -219,12 +219,12 @@ func (jm *JumpCloudManager) GetUserByEmail(email string) ([]*UserData, error) { // InviteUserByID resend invitations to users who haven't activated, // their accounts prior to the expiration period. -func (jm *JumpCloudManager) InviteUserByID(_ string) error { +func (jm *JumpCloudManager) InviteUserByID(_ context.Context, _ string) error { return fmt.Errorf("method InviteUserByID not implemented") } // DeleteUser from jumpCloud directory -func (jm *JumpCloudManager) DeleteUser(userID string) error { +func (jm *JumpCloudManager) DeleteUser(_ context.Context, userID string) error { authCtx := jm.authenticationContext() _, resp, err := jm.client.SystemusersApi.SystemusersDelete(authCtx, userID, contentType, accept, nil) if err != nil { diff --git a/management/server/idp/keycloak.go b/management/server/idp/keycloak.go index 3a6f80d03..07d84058c 100644 --- a/management/server/idp/keycloak.go +++ b/management/server/idp/keycloak.go @@ -1,6 +1,7 @@ package idp import ( + "context" "fmt" "io" "net/http" @@ -109,7 +110,7 @@ func (kc *KeycloakCredentials) jwtStillValid() bool { } // requestJWTToken performs request to get jwt token. -func (kc *KeycloakCredentials) requestJWTToken() (*http.Response, error) { +func (kc *KeycloakCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) { data := url.Values{} data.Set("client_id", kc.clientConfig.ClientID) data.Set("client_secret", kc.clientConfig.ClientSecret) @@ -122,7 +123,7 @@ func (kc *KeycloakCredentials) requestJWTToken() (*http.Response, error) { } req.Header.Add("content-type", "application/x-www-form-urlencoded") - log.Debug("requesting new jwt token for keycloak idp manager") + log.WithContext(ctx).Debug("requesting new jwt token for keycloak idp manager") resp, err := kc.httpClient.Do(req) if err != nil { @@ -174,7 +175,7 @@ func (kc *KeycloakCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (J } // Authenticate retrieves access token to use the keycloak Management API. -func (kc *KeycloakCredentials) Authenticate() (JWTToken, error) { +func (kc *KeycloakCredentials) Authenticate(ctx context.Context) (JWTToken, error) { kc.mux.Lock() defer kc.mux.Unlock() @@ -188,7 +189,7 @@ func (kc *KeycloakCredentials) Authenticate() (JWTToken, error) { return kc.jwtToken, nil } - resp, err := kc.requestJWTToken() + resp, err := kc.requestJWTToken(ctx) if err != nil { return kc.jwtToken, err } @@ -205,18 +206,18 @@ func (kc *KeycloakCredentials) Authenticate() (JWTToken, error) { } // CreateUser creates a new user in keycloak Idp and sends an invite. -func (km *KeycloakManager) CreateUser(_, _, _, _ string) (*UserData, error) { +func (km *KeycloakManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) { return nil, fmt.Errorf("method CreateUser not implemented") } // GetUserByEmail searches users with a given email. // If no users have been found, this function returns an empty list. -func (km *KeycloakManager) GetUserByEmail(email string) ([]*UserData, error) { +func (km *KeycloakManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) { q := url.Values{} q.Add("email", email) q.Add("exact", "true") - body, err := km.get("users", q) + body, err := km.get(ctx, "users", q) if err != nil { return nil, err } @@ -240,8 +241,8 @@ func (km *KeycloakManager) GetUserByEmail(email string) ([]*UserData, error) { } // GetUserDataByID requests user data from keycloak via ID. -func (km *KeycloakManager) GetUserDataByID(userID string, _ AppMetadata) (*UserData, error) { - body, err := km.get("users/"+userID, nil) +func (km *KeycloakManager) GetUserDataByID(ctx context.Context, userID string, _ AppMetadata) (*UserData, error) { + body, err := km.get(ctx, "users/"+userID, nil) if err != nil { return nil, err } @@ -260,8 +261,8 @@ func (km *KeycloakManager) GetUserDataByID(userID string, _ AppMetadata) (*UserD } // GetAccount returns all the users for a given account profile. -func (km *KeycloakManager) GetAccount(accountID string) ([]*UserData, error) { - profiles, err := km.fetchAllUserProfiles() +func (km *KeycloakManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) { + profiles, err := km.fetchAllUserProfiles(ctx) if err != nil { return nil, err } @@ -283,8 +284,8 @@ func (km *KeycloakManager) GetAccount(accountID string) ([]*UserData, error) { // GetAllAccounts gets all registered accounts with corresponding user data. // It returns a list of users indexed by accountID. -func (km *KeycloakManager) GetAllAccounts() (map[string][]*UserData, error) { - profiles, err := km.fetchAllUserProfiles() +func (km *KeycloakManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) { + profiles, err := km.fetchAllUserProfiles(ctx) if err != nil { return nil, err } @@ -303,19 +304,19 @@ func (km *KeycloakManager) GetAllAccounts() (map[string][]*UserData, error) { } // UpdateUserAppMetadata updates user app metadata based on userID and metadata map. -func (km *KeycloakManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error { +func (km *KeycloakManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error { return nil } // InviteUserByID resend invitations to users who haven't activated, // their accounts prior to the expiration period. -func (km *KeycloakManager) InviteUserByID(_ string) error { +func (km *KeycloakManager) InviteUserByID(_ context.Context, _ string) error { return fmt.Errorf("method InviteUserByID not implemented") } // DeleteUser from Keycloak by user ID. -func (km *KeycloakManager) DeleteUser(userID string) error { - jwtToken, err := km.credentials.Authenticate() +func (km *KeycloakManager) DeleteUser(ctx context.Context, userID string) error { + jwtToken, err := km.credentials.Authenticate(ctx) if err != nil { return err } @@ -353,8 +354,8 @@ func (km *KeycloakManager) DeleteUser(userID string) error { return nil } -func (km *KeycloakManager) fetchAllUserProfiles() ([]keycloakProfile, error) { - totalUsers, err := km.totalUsersCount() +func (km *KeycloakManager) fetchAllUserProfiles(ctx context.Context) ([]keycloakProfile, error) { + totalUsers, err := km.totalUsersCount(ctx) if err != nil { return nil, err } @@ -362,7 +363,7 @@ func (km *KeycloakManager) fetchAllUserProfiles() ([]keycloakProfile, error) { q := url.Values{} q.Add("max", fmt.Sprint(*totalUsers)) - body, err := km.get("users", q) + body, err := km.get(ctx, "users", q) if err != nil { return nil, err } @@ -377,8 +378,8 @@ func (km *KeycloakManager) fetchAllUserProfiles() ([]keycloakProfile, error) { } // get perform Get requests. -func (km *KeycloakManager) get(resource string, q url.Values) ([]byte, error) { - jwtToken, err := km.credentials.Authenticate() +func (km *KeycloakManager) get(ctx context.Context, resource string, q url.Values) ([]byte, error) { + jwtToken, err := km.credentials.Authenticate(ctx) if err != nil { return nil, err } @@ -414,8 +415,8 @@ func (km *KeycloakManager) get(resource string, q url.Values) ([]byte, error) { // totalUsersCount returns the total count of all user created. // Used when fetching all registered accounts with pagination. -func (km *KeycloakManager) totalUsersCount() (*int, error) { - body, err := km.get("users/count", nil) +func (km *KeycloakManager) totalUsersCount(ctx context.Context) (*int, error) { + body, err := km.get(ctx, "users/count", nil) if err != nil { return nil, err } diff --git a/management/server/idp/keycloak_test.go b/management/server/idp/keycloak_test.go index 9b6c1d3c6..0daca0671 100644 --- a/management/server/idp/keycloak_test.go +++ b/management/server/idp/keycloak_test.go @@ -1,6 +1,7 @@ package idp import ( + "context" "fmt" "io" "strings" @@ -128,7 +129,7 @@ func TestKeycloakRequestJWTToken(t *testing.T) { helper: testCase.helper, } - resp, err := creds.requestJWTToken() + resp, err := creds.requestJWTToken(context.Background()) if err != nil { if testCase.expectedFuncExitErrDiff != nil { assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same") @@ -294,7 +295,7 @@ func TestKeycloakAuthenticate(t *testing.T) { } creds.jwtToken.expiresInTime = testCase.inputExpireToken - _, err := creds.Authenticate() + _, err := creds.Authenticate(context.Background()) if err != nil { if testCase.expectedFuncExitErrDiff != nil { assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same") diff --git a/management/server/idp/mock.go b/management/server/idp/mock.go index 7605466e7..a07e375bf 100644 --- a/management/server/idp/mock.go +++ b/management/server/idp/mock.go @@ -1,77 +1,79 @@ package idp +import "context" + // MockIDP is a mock implementation of the IDP interface type MockIDP struct { - UpdateUserAppMetadataFunc func(userId string, appMetadata AppMetadata) error - GetUserDataByIDFunc func(userId string, appMetadata AppMetadata) (*UserData, error) - GetAccountFunc func(accountId string) ([]*UserData, error) - GetAllAccountsFunc func() (map[string][]*UserData, error) - CreateUserFunc func(email, name, accountID, invitedByEmail string) (*UserData, error) - GetUserByEmailFunc func(email string) ([]*UserData, error) - InviteUserByIDFunc func(userID string) error - DeleteUserFunc func(userID string) error + UpdateUserAppMetadataFunc func(ctx context.Context, userId string, appMetadata AppMetadata) error + GetUserDataByIDFunc func(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error) + GetAccountFunc func(ctx context.Context, accountId string) ([]*UserData, error) + GetAllAccountsFunc func(ctx context.Context) (map[string][]*UserData, error) + CreateUserFunc func(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) + GetUserByEmailFunc func(ctx context.Context, email string) ([]*UserData, error) + InviteUserByIDFunc func(ctx context.Context, userID string) error + DeleteUserFunc func(ctx context.Context, userID string) error } // UpdateUserAppMetadata is a mock implementation of the IDP interface UpdateUserAppMetadata method -func (m *MockIDP) UpdateUserAppMetadata(userId string, appMetadata AppMetadata) error { +func (m *MockIDP) UpdateUserAppMetadata(ctx context.Context, userId string, appMetadata AppMetadata) error { if m.UpdateUserAppMetadataFunc != nil { - return m.UpdateUserAppMetadataFunc(userId, appMetadata) + return m.UpdateUserAppMetadataFunc(ctx, userId, appMetadata) } return nil } // GetUserDataByID is a mock implementation of the IDP interface GetUserDataByID method -func (m *MockIDP) GetUserDataByID(userId string, appMetadata AppMetadata) (*UserData, error) { +func (m *MockIDP) GetUserDataByID(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error) { if m.GetUserDataByIDFunc != nil { - return m.GetUserDataByIDFunc(userId, appMetadata) + return m.GetUserDataByIDFunc(ctx, userId, appMetadata) } return nil, nil } // GetAccount is a mock implementation of the IDP interface GetAccount method -func (m *MockIDP) GetAccount(accountId string) ([]*UserData, error) { +func (m *MockIDP) GetAccount(ctx context.Context, accountId string) ([]*UserData, error) { if m.GetAccountFunc != nil { - return m.GetAccountFunc(accountId) + return m.GetAccountFunc(ctx, accountId) } return nil, nil } // GetAllAccounts is a mock implementation of the IDP interface GetAllAccounts method -func (m *MockIDP) GetAllAccounts() (map[string][]*UserData, error) { +func (m *MockIDP) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) { if m.GetAllAccountsFunc != nil { - return m.GetAllAccountsFunc() + return m.GetAllAccountsFunc(ctx) } return nil, nil } // CreateUser is a mock implementation of the IDP interface CreateUser method -func (m *MockIDP) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) { +func (m *MockIDP) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) { if m.CreateUserFunc != nil { - return m.CreateUserFunc(email, name, accountID, invitedByEmail) + return m.CreateUserFunc(ctx, email, name, accountID, invitedByEmail) } return nil, nil } // GetUserByEmail is a mock implementation of the IDP interface GetUserByEmail method -func (m *MockIDP) GetUserByEmail(email string) ([]*UserData, error) { +func (m *MockIDP) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) { if m.GetUserByEmailFunc != nil { - return m.GetUserByEmailFunc(email) + return m.GetUserByEmailFunc(ctx, email) } return nil, nil } // InviteUserByID is a mock implementation of the IDP interface InviteUserByID method -func (m *MockIDP) InviteUserByID(userID string) error { +func (m *MockIDP) InviteUserByID(ctx context.Context, userID string) error { if m.InviteUserByIDFunc != nil { - return m.InviteUserByIDFunc(userID) + return m.InviteUserByIDFunc(ctx, userID) } return nil } // DeleteUser is a mock implementation of the IDP interface DeleteUser method -func (m *MockIDP) DeleteUser(userID string) error { +func (m *MockIDP) DeleteUser(ctx context.Context, userID string) error { if m.DeleteUserFunc != nil { - return m.DeleteUserFunc(userID) + return m.DeleteUserFunc(ctx, userID) } return nil } diff --git a/management/server/idp/okta.go b/management/server/idp/okta.go index c8d33a207..b9cd006be 100644 --- a/management/server/idp/okta.go +++ b/management/server/idp/okta.go @@ -94,17 +94,17 @@ func NewOktaManager(config OktaClientConfig, appMetrics telemetry.AppMetrics) (* } // Authenticate retrieves access token to use the okta user API. -func (oc *OktaCredentials) Authenticate() (JWTToken, error) { +func (oc *OktaCredentials) Authenticate(_ context.Context) (JWTToken, error) { return JWTToken{}, nil } // CreateUser creates a new user in okta Idp and sends an invitation. -func (om *OktaManager) CreateUser(_, _, _, _ string) (*UserData, error) { +func (om *OktaManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) { return nil, fmt.Errorf("method CreateUser not implemented") } // GetUserDataByID requests user data from keycloak via ID. -func (om *OktaManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) { +func (om *OktaManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) { user, resp, err := om.client.User.GetUser(context.Background(), userID) if err != nil { return nil, err @@ -132,7 +132,7 @@ func (om *OktaManager) GetUserDataByID(userID string, appMetadata AppMetadata) ( // GetUserByEmail searches users with a given email. // If no users have been found, this function returns an empty list. -func (om *OktaManager) GetUserByEmail(email string) ([]*UserData, error) { +func (om *OktaManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) { user, resp, err := om.client.User.GetUser(context.Background(), url.QueryEscape(email)) if err != nil { return nil, err @@ -160,7 +160,7 @@ func (om *OktaManager) GetUserByEmail(email string) ([]*UserData, error) { } // GetAccount returns all the users for a given profile. -func (om *OktaManager) GetAccount(accountID string) ([]*UserData, error) { +func (om *OktaManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) { users, err := om.getAllUsers() if err != nil { return nil, err @@ -180,7 +180,7 @@ func (om *OktaManager) GetAccount(accountID string) ([]*UserData, error) { // GetAllAccounts gets all registered accounts with corresponding user data. // It returns a list of users indexed by accountID. -func (om *OktaManager) GetAllAccounts() (map[string][]*UserData, error) { +func (om *OktaManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) { users, err := om.getAllUsers() if err != nil { return nil, err @@ -242,18 +242,18 @@ func (om *OktaManager) getAllUsers() ([]*UserData, error) { } // UpdateUserAppMetadata updates user app metadata based on userID and metadata map. -func (om *OktaManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error { +func (om *OktaManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error { return nil } // InviteUserByID resend invitations to users who haven't activated, // their accounts prior to the expiration period. -func (om *OktaManager) InviteUserByID(_ string) error { +func (om *OktaManager) InviteUserByID(_ context.Context, _ string) error { return fmt.Errorf("method InviteUserByID not implemented") } // DeleteUser from Okta -func (om *OktaManager) DeleteUser(userID string) error { +func (om *OktaManager) DeleteUser(_ context.Context, userID string) error { resp, err := om.client.User.DeactivateOrDeleteUser(context.Background(), userID, nil) if err != nil { return err diff --git a/management/server/idp/zitadel.go b/management/server/idp/zitadel.go index 9021d6752..729b49733 100644 --- a/management/server/idp/zitadel.go +++ b/management/server/idp/zitadel.go @@ -1,6 +1,7 @@ package idp import ( + "context" "fmt" "io" "net/http" @@ -149,7 +150,7 @@ func (zc *ZitadelCredentials) jwtStillValid() bool { } // requestJWTToken performs request to get jwt token. -func (zc *ZitadelCredentials) requestJWTToken() (*http.Response, error) { +func (zc *ZitadelCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) { data := url.Values{} data.Set("client_id", zc.clientConfig.ClientID) data.Set("client_secret", zc.clientConfig.ClientSecret) @@ -163,7 +164,7 @@ func (zc *ZitadelCredentials) requestJWTToken() (*http.Response, error) { } req.Header.Add("content-type", "application/x-www-form-urlencoded") - log.Debug("requesting new jwt token for zitadel idp manager") + log.WithContext(ctx).Debug("requesting new jwt token for zitadel idp manager") resp, err := zc.httpClient.Do(req) if err != nil { @@ -215,7 +216,7 @@ func (zc *ZitadelCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JW } // Authenticate retrieves access token to use the Zitadel Management API. -func (zc *ZitadelCredentials) Authenticate() (JWTToken, error) { +func (zc *ZitadelCredentials) Authenticate(ctx context.Context) (JWTToken, error) { zc.mux.Lock() defer zc.mux.Unlock() @@ -229,7 +230,7 @@ func (zc *ZitadelCredentials) Authenticate() (JWTToken, error) { return zc.jwtToken, nil } - resp, err := zc.requestJWTToken() + resp, err := zc.requestJWTToken(ctx) if err != nil { return zc.jwtToken, err } @@ -246,7 +247,7 @@ func (zc *ZitadelCredentials) Authenticate() (JWTToken, error) { } // CreateUser creates a new user in zitadel Idp and sends an invite via Zitadel. -func (zm *ZitadelManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) { +func (zm *ZitadelManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) { firstLast := strings.SplitN(name, " ", 2) var addUser = map[string]any{ @@ -269,7 +270,7 @@ func (zm *ZitadelManager) CreateUser(email, name, accountID, invitedByEmail stri return nil, err } - body, err := zm.post("users/human/_import", string(payload)) + body, err := zm.post(ctx, "users/human/_import", string(payload)) if err != nil { return nil, err } @@ -300,7 +301,7 @@ func (zm *ZitadelManager) CreateUser(email, name, accountID, invitedByEmail stri // GetUserByEmail searches users with a given email. // If no users have been found, this function returns an empty list. -func (zm *ZitadelManager) GetUserByEmail(email string) ([]*UserData, error) { +func (zm *ZitadelManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) { searchByEmail := zitadelAttributes{ "queries": { { @@ -316,7 +317,7 @@ func (zm *ZitadelManager) GetUserByEmail(email string) ([]*UserData, error) { return nil, err } - body, err := zm.post("users/_search", string(payload)) + body, err := zm.post(ctx, "users/_search", string(payload)) if err != nil { return nil, err } @@ -340,8 +341,8 @@ func (zm *ZitadelManager) GetUserByEmail(email string) ([]*UserData, error) { } // GetUserDataByID requests user data from zitadel via ID. -func (zm *ZitadelManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) { - body, err := zm.get("users/"+userID, nil) +func (zm *ZitadelManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) { + body, err := zm.get(ctx, "users/"+userID, nil) if err != nil { return nil, err } @@ -363,8 +364,8 @@ func (zm *ZitadelManager) GetUserDataByID(userID string, appMetadata AppMetadata } // GetAccount returns all the users for a given profile. -func (zm *ZitadelManager) GetAccount(accountID string) ([]*UserData, error) { - body, err := zm.post("users/_search", "") +func (zm *ZitadelManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) { + body, err := zm.post(ctx, "users/_search", "") if err != nil { return nil, err } @@ -392,8 +393,8 @@ func (zm *ZitadelManager) GetAccount(accountID string) ([]*UserData, error) { // GetAllAccounts gets all registered accounts with corresponding user data. // It returns a list of users indexed by accountID. -func (zm *ZitadelManager) GetAllAccounts() (map[string][]*UserData, error) { - body, err := zm.post("users/_search", "") +func (zm *ZitadelManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) { + body, err := zm.post(ctx, "users/_search", "") if err != nil { return nil, err } @@ -419,7 +420,7 @@ func (zm *ZitadelManager) GetAllAccounts() (map[string][]*UserData, error) { // UpdateUserAppMetadata updates user app metadata based on userID and metadata map. // Metadata values are base64 encoded. -func (zm *ZitadelManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error { +func (zm *ZitadelManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error { return nil } @@ -429,7 +430,7 @@ type inviteUserRequest struct { // InviteUserByID resend invitations to users who haven't activated, // their accounts prior to the expiration period. -func (zm *ZitadelManager) InviteUserByID(userID string) error { +func (zm *ZitadelManager) InviteUserByID(ctx context.Context, userID string) error { inviteUser := inviteUserRequest{ Email: userID, } @@ -440,14 +441,14 @@ func (zm *ZitadelManager) InviteUserByID(userID string) error { } // don't care about the body in the response - _, err = zm.post(fmt.Sprintf("users/%s/_resend_initialization", userID), string(payload)) + _, err = zm.post(ctx, fmt.Sprintf("users/%s/_resend_initialization", userID), string(payload)) return err } // DeleteUser from Zitadel -func (zm *ZitadelManager) DeleteUser(userID string) error { +func (zm *ZitadelManager) DeleteUser(ctx context.Context, userID string) error { resource := fmt.Sprintf("users/%s", userID) - if err := zm.delete(resource); err != nil { + if err := zm.delete(ctx, resource); err != nil { return err } @@ -459,8 +460,8 @@ func (zm *ZitadelManager) DeleteUser(userID string) error { } // post perform Post requests. -func (zm *ZitadelManager) post(resource string, body string) ([]byte, error) { - jwtToken, err := zm.credentials.Authenticate() +func (zm *ZitadelManager) post(ctx context.Context, resource string, body string) ([]byte, error) { + jwtToken, err := zm.credentials.Authenticate(ctx) if err != nil { return nil, err } @@ -495,8 +496,8 @@ func (zm *ZitadelManager) post(resource string, body string) ([]byte, error) { } // delete perform Delete requests. -func (zm *ZitadelManager) delete(resource string) error { - jwtToken, err := zm.credentials.Authenticate() +func (zm *ZitadelManager) delete(ctx context.Context, resource string) error { + jwtToken, err := zm.credentials.Authenticate(ctx) if err != nil { return err } @@ -531,8 +532,8 @@ func (zm *ZitadelManager) delete(resource string) error { } // get perform Get requests. -func (zm *ZitadelManager) get(resource string, q url.Values) ([]byte, error) { - jwtToken, err := zm.credentials.Authenticate() +func (zm *ZitadelManager) get(ctx context.Context, resource string, q url.Values) ([]byte, error) { + jwtToken, err := zm.credentials.Authenticate(ctx) if err != nil { return nil, err } diff --git a/management/server/idp/zitadel_test.go b/management/server/idp/zitadel_test.go index 9a771b36a..6bc612e78 100644 --- a/management/server/idp/zitadel_test.go +++ b/management/server/idp/zitadel_test.go @@ -1,6 +1,7 @@ package idp import ( + "context" "fmt" "io" "strings" @@ -108,7 +109,7 @@ func TestZitadelRequestJWTToken(t *testing.T) { helper: testCase.helper, } - resp, err := creds.requestJWTToken() + resp, err := creds.requestJWTToken(context.Background()) if err != nil { if testCase.expectedFuncExitErrDiff != nil { assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same") @@ -274,7 +275,7 @@ func TestZitadelAuthenticate(t *testing.T) { } creds.jwtToken.expiresInTime = testCase.inputExpireToken - _, err := creds.Authenticate() + _, err := creds.Authenticate(context.Background()) if err != nil { if testCase.expectedFuncExitErrDiff != nil { assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same") diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index 198f8d527..05537ada4 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -1,9 +1,10 @@ package server import ( + "context" "errors" - "github.com/google/martian/v3/log" + log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/account" ) @@ -19,22 +20,22 @@ import ( // // Returns: // - error: An error if any occurred during the process, otherwise returns nil -func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error { - ok, err := am.GroupValidation(accountID, groups) +func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error { + ok, err := am.GroupValidation(ctx, accountID, groups) if err != nil { - log.Debugf("error validating groups: %s", err.Error()) + log.WithContext(ctx).Debugf("error validating groups: %s", err.Error()) return err } if !ok { - log.Debugf("invalid groups") + log.WithContext(ctx).Debugf("invalid groups") return errors.New("invalid groups") } - unlock := am.Store.AcquireAccountWriteLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - a, err := am.Store.GetAccountByUser(userID) + a, err := am.Store.GetAccountByUser(ctx, userID) if err != nil { return err } @@ -48,14 +49,14 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(accountID strin a.Settings.Extra = extra } extra.IntegratedValidatorGroups = groups - return am.Store.SaveAccount(a) + return am.Store.SaveAccount(ctx, a) } -func (am *DefaultAccountManager) GroupValidation(accountId string, groups []string) (bool, error) { +func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) { if len(groups) == 0 { return true, nil } - accountsGroups, err := am.ListGroups(accountId) + accountsGroups, err := am.ListGroups(ctx, accountId) if err != nil { return false, err } diff --git a/management/server/integrated_validator/interface.go b/management/server/integrated_validator/interface.go index ae9698f79..6c9a3e44e 100644 --- a/management/server/integrated_validator/interface.go +++ b/management/server/integrated_validator/interface.go @@ -1,6 +1,8 @@ package integrated_validator import ( + "context" + "github.com/netbirdio/netbird/management/server/account" nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -8,12 +10,12 @@ import ( // IntegratedValidator interface exists to avoid the circle dependencies type IntegratedValidator interface { - ValidateExtraSettings(newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error - ValidatePeer(update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) - PreparePeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer - IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) + ValidateExtraSettings(ctx context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error + ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) + PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer + IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) GetValidatedPeers(accountID string, groups map[string]*nbgroup.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) - PeerDeleted(accountID, peerID string) error + PeerDeleted(ctx context.Context, accountID, peerID string) error SetPeerInvalidationListener(fn func(accountID string)) - Stop() + Stop(ctx context.Context) } diff --git a/management/server/jwtclaims/jwtValidator.go b/management/server/jwtclaims/jwtValidator.go index f218c1aa9..c3417a769 100644 --- a/management/server/jwtclaims/jwtValidator.go +++ b/management/server/jwtclaims/jwtValidator.go @@ -2,6 +2,7 @@ package jwtclaims import ( "bytes" + "context" "crypto/rsa" "crypto/x509" "encoding/base64" @@ -69,8 +70,8 @@ type JWTValidator struct { } // NewJWTValidator constructor -func NewJWTValidator(issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (*JWTValidator, error) { - keys, err := getPemKeys(keysLocation) +func NewJWTValidator(ctx context.Context, issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (*JWTValidator, error) { + keys, err := getPemKeys(ctx, keysLocation) if err != nil { return nil, err } @@ -102,19 +103,19 @@ func NewJWTValidator(issuer string, audienceList []string, keysLocation string, lock.Lock() defer lock.Unlock() - refreshedKeys, err := getPemKeys(keysLocation) + refreshedKeys, err := getPemKeys(ctx, keysLocation) if err != nil { - log.Debugf("cannot get JSONWebKey: %v, falling back to old keys", err) + log.WithContext(ctx).Debugf("cannot get JSONWebKey: %v, falling back to old keys", err) refreshedKeys = keys } - log.Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC()) + log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC()) keys = refreshedKeys } } - cert, err := getPemCert(token, keys) + cert, err := getPemCert(ctx, token, keys) if err != nil { return nil, err } @@ -136,19 +137,19 @@ func NewJWTValidator(issuer string, audienceList []string, keysLocation string, } // ValidateAndParse validates the token and returns the parsed token -func (m *JWTValidator) ValidateAndParse(token string) (*jwt.Token, error) { +func (m *JWTValidator) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) { // If the token is empty... if token == "" { // Check if it was required if m.options.CredentialsOptional { - log.Debugf("no credentials found (CredentialsOptional=true)") + log.WithContext(ctx).Debugf("no credentials found (CredentialsOptional=true)") // No error, just no token (and that is ok given that CredentialsOptional is true) return nil, nil //nolint:nilnil } // If we get here, the required token is missing errorMsg := "required authorization token not found" - log.Debugf(" Error: No credentials found (CredentialsOptional=false)") + log.WithContext(ctx).Debugf(" Error: No credentials found (CredentialsOptional=false)") return nil, fmt.Errorf(errorMsg) } @@ -157,7 +158,7 @@ func (m *JWTValidator) ValidateAndParse(token string) (*jwt.Token, error) { // Check if there was an error in parsing... if err != nil { - log.Errorf("error parsing token: %v", err) + log.WithContext(ctx).Errorf("error parsing token: %v", err) return nil, fmt.Errorf("Error parsing token: %w", err) } @@ -165,14 +166,14 @@ func (m *JWTValidator) ValidateAndParse(token string) (*jwt.Token, error) { errorMsg := fmt.Sprintf("Expected %s signing method but token specified %s", m.options.SigningMethod.Alg(), parsedToken.Header["alg"]) - log.Debugf("error validating token algorithm: %s", errorMsg) + log.WithContext(ctx).Debugf("error validating token algorithm: %s", errorMsg) return nil, fmt.Errorf("error validating token algorithm: %s", errorMsg) } // Check if the parsed token is valid... if !parsedToken.Valid { errorMsg := "token is invalid" - log.Debugf(errorMsg) + log.WithContext(ctx).Debugf(errorMsg) return nil, errors.New(errorMsg) } @@ -184,7 +185,7 @@ func (jwks *Jwks) stillValid() bool { return !jwks.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.expiresInTime) } -func getPemKeys(keysLocation string) (*Jwks, error) { +func getPemKeys(ctx context.Context, keysLocation string) (*Jwks, error) { resp, err := http.Get(keysLocation) if err != nil { return nil, err @@ -198,13 +199,13 @@ func getPemKeys(keysLocation string) (*Jwks, error) { } cacheControlHeader := resp.Header.Get("Cache-Control") - expiresIn := getMaxAgeFromCacheHeader(cacheControlHeader) + expiresIn := getMaxAgeFromCacheHeader(ctx, cacheControlHeader) jwks.expiresInTime = time.Now().Add(time.Duration(expiresIn) * time.Second) return jwks, err } -func getPemCert(token *jwt.Token, jwks *Jwks) (string, error) { +func getPemCert(ctx context.Context, token *jwt.Token, jwks *Jwks) (string, error) { // todo as we load the jkws when the server is starting, we should build a JKS map with the pem cert at the boot time cert := "" @@ -217,7 +218,7 @@ func getPemCert(token *jwt.Token, jwks *Jwks) (string, error) { cert = "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----" return cert, nil } - log.Debugf("generating validation pem from JWK") + log.WithContext(ctx).Debugf("generating validation pem from JWK") return generatePemFromJWK(jwks.Keys[k]) } @@ -284,7 +285,7 @@ func convertExponentStringToInt(stringExponent string) (int, error) { } // getMaxAgeFromCacheHeader extracts max-age directive from the Cache-Control header -func getMaxAgeFromCacheHeader(cacheControl string) int { +func getMaxAgeFromCacheHeader(ctx context.Context, cacheControl string) int { // Split into individual directives directives := strings.Split(cacheControl, ",") @@ -295,7 +296,7 @@ func getMaxAgeFromCacheHeader(cacheControl string) int { maxAgeStr := strings.TrimPrefix(directive, "max-age=") maxAge, err := strconv.Atoi(maxAgeStr) if err != nil { - log.Debugf("error parsing max-age: %v", err) + log.WithContext(ctx).Debugf("error parsing max-age: %v", err) return 0 } diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index e49ea5338..e1f7787f2 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -406,7 +406,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error) return nil, "", err } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, cleanUp, err := NewTestStoreFromJson(config.Datadir) + store, cleanUp, err := NewTestStoreFromJson(context.Background(), config.Datadir) if err != nil { return nil, "", err } @@ -414,7 +414,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error) peersUpdateManager := NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} - accountManager, err := BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", + accountManager, err := BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}) if err != nil { return nil, "", err @@ -422,7 +422,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error) turnManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) ephemeralMgr := NewEphemeralManager(store, accountManager) - mgmtServer, err := NewServer(config, accountManager, peersUpdateManager, turnManager, nil, ephemeralMgr) + mgmtServer, err := NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, ephemeralMgr) if err != nil { return nil, "", err } diff --git a/management/server/management_test.go b/management/server/management_test.go index 4e997d4d9..092567607 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -451,11 +451,11 @@ var _ = Describe("Management service", func() { type MocIntegratedValidator struct { } -func (a MocIntegratedValidator) ValidateExtraSettings(newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { +func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { return nil } -func (a MocIntegratedValidator) ValidatePeer(update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) { +func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) { return update, nil } @@ -467,15 +467,15 @@ func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[s return validatedPeers, nil } -func (MocIntegratedValidator) PreparePeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer { +func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer { return peer } -func (MocIntegratedValidator) IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) { +func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) { return false, false, nil } -func (MocIntegratedValidator) PeerDeleted(_, _ string) error { +func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error { return nil } @@ -483,7 +483,7 @@ func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string) } -func (MocIntegratedValidator) Stop() {} +func (MocIntegratedValidator) Stop(_ context.Context) {} func loginPeerWithValidSetupKey(serverPubKey wgtypes.Key, key wgtypes.Key, client mgmtProto.ManagementServiceClient) *mgmtProto.LoginResponse { defer GinkgoRecover() @@ -534,20 +534,20 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) { Expect(err).NotTo(HaveOccurred()) s := grpc.NewServer() - store, _, err := server.NewTestStoreFromJson(config.Datadir) + store, _, err := server.NewTestStoreFromJson(context.Background(), config.Datadir) if err != nil { log.Fatalf("failed creating a store: %s: %v", config.Datadir, err) } peersUpdateManager := server.NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} - accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", + accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}) if err != nil { log.Fatalf("failed creating a manager: %v", err) } turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) - mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil) + mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil) Expect(err).NotTo(HaveOccurred()) mgmtProto.RegisterManagementServiceServer(s, mgmtServer) go func() { diff --git a/management/server/metrics/selfhosted.go b/management/server/metrics/selfhosted.go index 9da1e577e..357af6e56 100644 --- a/management/server/metrics/selfhosted.go +++ b/management/server/metrics/selfhosted.go @@ -46,7 +46,7 @@ type properties map[string]interface{} // DataSource metric data source type DataSource interface { - GetAllAccounts() []*server.Account + GetAllAccounts(ctx context.Context) []*server.Account GetStoreEngine() server.StoreEngine } @@ -81,29 +81,29 @@ func NewWorker(ctx context.Context, id string, dataSource DataSource, connManage } // Run runs the metrics worker -func (w *Worker) Run() { +func (w *Worker) Run(ctx context.Context) { pushTicker := time.NewTicker(defaultPushInterval) for { select { case <-w.ctx.Done(): return case <-pushTicker.C: - err := w.sendMetrics() + err := w.sendMetrics(ctx) if err != nil { - log.Error(err) + log.WithContext(ctx).Error(err) } w.lastRun = time.Now() } } } -func (w *Worker) sendMetrics() error { +func (w *Worker) sendMetrics(ctx context.Context) error { apiKey, err := getAPIKey(w.ctx) if err != nil { return err } - payload := w.generatePayload(apiKey) + payload := w.generatePayload(ctx, apiKey) payloadString, err := buildMetricsPayload(payload) if err != nil { @@ -125,7 +125,7 @@ func (w *Worker) sendMetrics() error { defer func() { err = jobResp.Body.Close() if err != nil { - log.Errorf("error while closing update metrics response body: %v", err) + log.WithContext(ctx).Errorf("error while closing update metrics response body: %v", err) } }() @@ -133,15 +133,15 @@ func (w *Worker) sendMetrics() error { return fmt.Errorf("unable to push anonymous metrics, got statusCode %d", jobResp.StatusCode) } - log.Infof("sent anonymous metrics, next push will happen in %s. "+ + log.WithContext(ctx).Infof("sent anonymous metrics, next push will happen in %s. "+ "You can disable these metrics by running with flag --disable-anonymous-metrics,"+ " see more information at https://netbird.io/docs/FAQ/metrics-collection", defaultPushInterval) return nil } -func (w *Worker) generatePayload(apiKey string) pushPayload { - properties := w.generateProperties() +func (w *Worker) generatePayload(ctx context.Context, apiKey string) pushPayload { + properties := w.generateProperties(ctx) return pushPayload{ APIKey: apiKey, @@ -152,7 +152,7 @@ func (w *Worker) generatePayload(apiKey string) pushPayload { } } -func (w *Worker) generateProperties() properties { +func (w *Worker) generateProperties(ctx context.Context) properties { var ( uptime float64 accounts int @@ -192,7 +192,7 @@ func (w *Worker) generateProperties() properties { connections := w.connManager.GetAllConnectedPeers() version = nbversion.NetbirdVersion() - for _, account := range w.dataSource.GetAllAccounts() { + for _, account := range w.dataSource.GetAllAccounts(ctx) { accounts++ if account.Settings.PeerLoginExpirationEnabled { @@ -342,7 +342,7 @@ func getAPIKey(ctx context.Context) (string, error) { defer func() { err = response.Body.Close() if err != nil { - log.Errorf("error while closing metrics token response body: %v", err) + log.WithContext(ctx).Errorf("error while closing metrics token response body: %v", err) } }() diff --git a/management/server/metrics/selfhosted_test.go b/management/server/metrics/selfhosted_test.go index c5b18607a..2ac2d68a0 100644 --- a/management/server/metrics/selfhosted_test.go +++ b/management/server/metrics/selfhosted_test.go @@ -1,6 +1,7 @@ package metrics import ( + "context" "testing" nbdns "github.com/netbirdio/netbird/dns" @@ -21,7 +22,7 @@ func (mockDatasource) GetAllConnectedPeers() map[string]struct{} { } // GetAllAccounts returns a list of *server.Account for use in tests with predefined information -func (mockDatasource) GetAllAccounts() []*server.Account { +func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account { return []*server.Account{ { Id: "1", @@ -188,7 +189,7 @@ func TestGenerateProperties(t *testing.T) { connManager: ds, } - properties := worker.generateProperties() + properties := worker.generateProperties(context.Background()) if properties["accounts"] != 2 { t.Errorf("expected 2 accounts, got %d", properties["accounts"]) diff --git a/management/server/migration/migration.go b/management/server/migration/migration.go index 8a2d4219e..4c8baea5e 100644 --- a/management/server/migration/migration.go +++ b/management/server/migration/migration.go @@ -1,6 +1,7 @@ package migration import ( + "context" "database/sql" "encoding/gob" "encoding/json" @@ -16,7 +17,7 @@ import ( // MigrateFieldFromGobToJSON migrates a column from Gob encoding to JSON encoding. // T is the type of the model that contains the field to be migrated. // S is the type of the field to be migrated. -func MigrateFieldFromGobToJSON[T any, S any](db *gorm.DB, fieldName string) error { +func MigrateFieldFromGobToJSON[T any, S any](ctx context.Context, db *gorm.DB, fieldName string) error { oldColumnName := fieldName newColumnName := fieldName + "_tmp" @@ -24,7 +25,7 @@ func MigrateFieldFromGobToJSON[T any, S any](db *gorm.DB, fieldName string) erro var model T if !db.Migrator().HasTable(&model) { - log.Debugf("Table for %T does not exist, no migration needed", model) + log.WithContext(ctx).Debugf("Table for %T does not exist, no migration needed", model) return nil } @@ -38,7 +39,7 @@ func MigrateFieldFromGobToJSON[T any, S any](db *gorm.DB, fieldName string) erro var sqliteItem sql.NullString if err := db.Model(model).Select(oldColumnName).First(&sqliteItem).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - log.Debugf("No records in table %s, no migration needed", tableName) + log.WithContext(ctx).Debugf("No records in table %s, no migration needed", tableName) return nil } return fmt.Errorf("fetch first record: %w", err) @@ -51,7 +52,7 @@ func MigrateFieldFromGobToJSON[T any, S any](db *gorm.DB, fieldName string) erro err = json.Unmarshal([]byte(item), &js) // if the item is JSON parsable or an empty string it can not be gob encoded if err == nil || !errors.As(err, &syntaxError) || item == "" { - log.Debugf("No migration needed for %s, %s", tableName, fieldName) + log.WithContext(ctx).Debugf("No migration needed for %s, %s", tableName, fieldName) return nil } @@ -99,14 +100,14 @@ func MigrateFieldFromGobToJSON[T any, S any](db *gorm.DB, fieldName string) erro return err } - log.Infof("Migration of %s.%s from gob to json completed", tableName, fieldName) + log.WithContext(ctx).Infof("Migration of %s.%s from gob to json completed", tableName, fieldName) return nil } // MigrateNetIPFieldFromBlobToJSON migrates a Net IP column from Blob encoding to JSON encoding. // T is the type of the model that contains the field to be migrated. -func MigrateNetIPFieldFromBlobToJSON[T any](db *gorm.DB, fieldName string, indexName string) error { +func MigrateNetIPFieldFromBlobToJSON[T any](ctx context.Context, db *gorm.DB, fieldName string, indexName string) error { oldColumnName := fieldName newColumnName := fieldName + "_tmp" @@ -138,7 +139,7 @@ func MigrateNetIPFieldFromBlobToJSON[T any](db *gorm.DB, fieldName string, index var syntaxError *json.SyntaxError err = json.Unmarshal([]byte(item.String), &js) if err == nil || !errors.As(err, &syntaxError) { - log.Debugf("No migration needed for %s, %s", tableName, fieldName) + log.WithContext(ctx).Debugf("No migration needed for %s, %s", tableName, fieldName) return nil } } @@ -169,7 +170,7 @@ func MigrateNetIPFieldFromBlobToJSON[T any](db *gorm.DB, fieldName string, index columnIpValue := net.IP(blobValue) if net.ParseIP(columnIpValue.String()) == nil { - log.Debugf("failed to parse %s as ip, fallback to ipv6 loopback", oldColumnName) + log.WithContext(ctx).Debugf("failed to parse %s as ip, fallback to ipv6 loopback", oldColumnName) columnIpValue = net.IPv6loopback } diff --git a/management/server/migration/migration_test.go b/management/server/migration/migration_test.go index 45757e9d6..5a1926641 100644 --- a/management/server/migration/migration_test.go +++ b/management/server/migration/migration_test.go @@ -1,6 +1,7 @@ package migration_test import ( + "context" "encoding/gob" "net" "strings" @@ -30,7 +31,7 @@ func setupDatabase(t *testing.T) *gorm.DB { func TestMigrateFieldFromGobToJSON_EmptyDB(t *testing.T) { db := setupDatabase(t) - err := migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](db, "network_net") + err := migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](context.Background(), db, "network_net") require.NoError(t, err, "Migration should not fail for an empty database") } @@ -63,7 +64,7 @@ func TestMigrateFieldFromGobToJSON_WithGobData(t *testing.T) { err = gob.NewDecoder(strings.NewReader(gobStr)).Decode(&ipnet) require.NoError(t, err, "Failed to decode Gob data") - err = migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](db, "network_net") + err = migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](context.Background(), db, "network_net") require.NoError(t, err, "Migration should not fail with Gob data") var jsonStr string @@ -83,7 +84,7 @@ func TestMigrateFieldFromGobToJSON_WithJSONData(t *testing.T) { err = db.Save(&server.Account{Network: &server.Network{Net: *ipnet}}).Error require.NoError(t, err, "Failed to insert JSON data") - err = migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](db, "network_net") + err = migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](context.Background(), db, "network_net") require.NoError(t, err, "Migration should not fail with JSON data") var jsonStr string @@ -93,7 +94,7 @@ func TestMigrateFieldFromGobToJSON_WithJSONData(t *testing.T) { func TestMigrateNetIPFieldFromBlobToJSON_EmptyDB(t *testing.T) { db := setupDatabase(t) - err := migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "ip", "idx_peers_account_id_ip") + err := migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](context.Background(), db, "ip", "idx_peers_account_id_ip") require.NoError(t, err, "Migration should not fail for an empty database") } @@ -130,7 +131,7 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) { err = db.Model(&nbpeer.Peer{}).Select("location_connection_ip").First(&blobValue).Error assert.NoError(t, err, "Failed to fetch blob data") - err = migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "location_connection_ip", "") + err = migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](context.Background(), db, "location_connection_ip", "") require.NoError(t, err, "Migration should not fail with net.IP blob data") var jsonStr string @@ -152,7 +153,7 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) { ).Error require.NoError(t, err, "Failed to insert JSON data") - err = migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "location_connection_ip", "") + err = migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](context.Background(), db, "location_connection_ip", "") require.NoError(t, err, "Migration should not fail with net.IP JSON data") var jsonStr string diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 669fab861..177088ac5 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -1,6 +1,7 @@ package mock_server import ( + "context" "net" "net/netip" "time" @@ -21,94 +22,95 @@ import ( ) type MockAccountManager struct { - GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error) - CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, + GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*server.Account, error) + CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error) - GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error) - GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error) - GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error) - ListUsersFunc func(accountID string) ([]*server.User, error) - GetPeersFunc func(accountID, userID string) ([]*nbpeer.Peer, error) - MarkPeerConnectedFunc func(peerKey string, connected bool, realIP net.IP) error - SyncAndMarkPeerFunc func(peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) - DeletePeerFunc func(accountID, peerKey, userID string) error - GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error) - GetPeerNetworkFunc func(peerKey string) (*server.Network, error) - AddPeerFunc func(setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) - GetGroupFunc func(accountID, groupID, userID string) (*group.Group, error) - GetAllGroupsFunc func(accountID, userID string) ([]*group.Group, error) - GetGroupByNameFunc func(accountID, groupName string) (*group.Group, error) - SaveGroupFunc func(accountID, userID string, group *group.Group) error - DeleteGroupFunc func(accountID, userId, groupID string) error - ListGroupsFunc func(accountID string) ([]*group.Group, error) - GroupAddPeerFunc func(accountID, groupID, peerID string) error - GroupDeletePeerFunc func(accountID, groupID, peerID string) error - DeleteRuleFunc func(accountID, ruleID, userID string) error - GetPolicyFunc func(accountID, policyID, userID string) (*server.Policy, error) - SavePolicyFunc func(accountID, userID string, policy *server.Policy) error - DeletePolicyFunc func(accountID, policyID, userID string) error - ListPoliciesFunc func(accountID, userID string) ([]*server.Policy, error) - GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error) - GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) - MarkPATUsedFunc func(pat string) error - UpdatePeerMetaFunc func(peerID string, meta nbpeer.PeerSystemMeta) error - UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error - UpdatePeerFunc func(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) - CreateRouteFunc func(accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) - GetRouteFunc func(accountID string, routeID route.ID, userID string) (*route.Route, error) - SaveRouteFunc func(accountID string, userID string, route *route.Route) error - DeleteRouteFunc func(accountID string, routeID route.ID, userID string) error - ListRoutesFunc func(accountID, userID string) ([]*route.Route, error) - SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) - ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error) - SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error) - SaveOrAddUserFunc func(accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error) - DeleteUserFunc func(accountID string, initiatorUserID string, targetUserID string) error - CreatePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) - DeletePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) error - GetPATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error) - GetAllPATsFunc func(accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error) - GetNameServerGroupFunc func(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) - CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) - SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error - DeleteNameServerGroupFunc func(accountID, nsGroupID, userID string) error - ListNameServerGroupsFunc func(accountID string, userID string) ([]*nbdns.NameServerGroup, error) - CreateUserFunc func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) - GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) - CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error - DeleteAccountFunc func(accountID, userID string) error + GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error) + GetAccountByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (*server.Account, error) + GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) + ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error) + GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) + MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error + SyncAndMarkPeerFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) + DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error + GetNetworkMapFunc func(ctx context.Context, peerKey string) (*server.NetworkMap, error) + GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*server.Network, error) + AddPeerFunc func(ctx context.Context, setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) + GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*group.Group, error) + GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*group.Group, error) + GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*group.Group, error) + SaveGroupFunc func(ctx context.Context, accountID, userID string, group *group.Group) error + DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error + ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error) + GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error + GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error + DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error + GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error) + SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) error + DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error + ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error) + GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error) + GetAccountFromPATFunc func(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) + MarkPATUsedFunc func(ctx context.Context, pat string) error + UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error + UpdatePeerSSHKeyFunc func(ctx context.Context, peerID string, sshKey string) error + UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) + CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) + GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) + SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error + DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error + ListRoutesFunc func(ctx context.Context, accountID, userID string) ([]*route.Route, error) + SaveSetupKeyFunc func(ctx context.Context, accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) + ListSetupKeysFunc func(ctx context.Context, accountID, userID string) ([]*server.SetupKey, error) + SaveUserFunc func(ctx context.Context, accountID, userID string, user *server.User) (*server.UserInfo, error) + SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error) + DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error + CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) + DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error + GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error) + GetAllPATsFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error) + GetNameServerGroupFunc func(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) + CreateNameServerGroupFunc func(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) + SaveNameServerGroupFunc func(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error + DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error + ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) + CreateUserFunc func(ctx context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) + GetAccountFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) + CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error + DeleteAccountFunc func(ctx context.Context, accountID, userID string) error GetDNSDomainFunc func() string - StoreEventFunc func(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) - GetEventsFunc func(accountID, userID string) ([]*activity.Event, error) - GetDNSSettingsFunc func(accountID, userID string) (*server.DNSSettings, error) - SaveDNSSettingsFunc func(accountID, userID string, dnsSettingsToSave *server.DNSSettings) error - GetPeerFunc func(accountID, peerID, userID string) (*nbpeer.Peer, error) - UpdateAccountSettingsFunc func(accountID, userID string, newSettings *server.Settings) (*server.Account, error) - LoginPeerFunc func(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) - SyncPeerFunc func(sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) - InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error + StoreEventFunc func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) + GetEventsFunc func(ctx context.Context, accountID, userID string) ([]*activity.Event, error) + GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*server.DNSSettings, error) + SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *server.DNSSettings) error + GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) + UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) + LoginPeerFunc func(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) + SyncPeerFunc func(ctx context.Context, sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) + InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error GetAllConnectedPeersFunc func() (map[string]struct{}, error) HasConnectedChannelFunc func(peerID string) bool GetExternalCacheManagerFunc func() server.ExternalCacheManager - GetPostureChecksFunc func(accountID, postureChecksID, userID string) (*posture.Checks, error) - SavePostureChecksFunc func(accountID, userID string, postureChecks *posture.Checks) error - DeletePostureChecksFunc func(accountID, postureChecksID, userID string) error - ListPostureChecksFunc func(accountID, userID string) ([]*posture.Checks, error) + GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) + SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error + DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error + ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) GetIdpManagerFunc func() idp.Manager - UpdateIntegratedValidatorGroupsFunc func(accountID string, userID string, groups []string) error - GroupValidationFunc func(accountId string, groups []string) (bool, error) - SyncPeerMetaFunc func(peerPubKey string, meta nbpeer.PeerSystemMeta) error + UpdateIntegratedValidatorGroupsFunc func(ctx context.Context, accountID string, userID string, groups []string) error + GroupValidationFunc func(ctx context.Context, accountId string, groups []string) (bool, error) + SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) + GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error) } -func (am *MockAccountManager) SyncAndMarkPeer(peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { +func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { if am.SyncAndMarkPeerFunc != nil { - return am.SyncAndMarkPeerFunc(peerPubKey, meta, realIP) + return am.SyncAndMarkPeerFunc(ctx, peerPubKey, meta, realIP) } return nil, nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") } -func (am *MockAccountManager) CancelPeerRoutines(peer *nbpeer.Peer) error { +func (am *MockAccountManager) CancelPeerRoutines(_ context.Context, peer *nbpeer.Peer) error { // TODO implement me panic("implement me") } @@ -122,43 +124,43 @@ func (am *MockAccountManager) GetValidatedPeers(account *server.Account) (map[st } // GetGroup mock implementation of GetGroup from server.AccountManager interface -func (am *MockAccountManager) GetGroup(accountId, groupID, userID string) (*group.Group, error) { +func (am *MockAccountManager) GetGroup(ctx context.Context, accountId, groupID, userID string) (*group.Group, error) { if am.GetGroupFunc != nil { - return am.GetGroupFunc(accountId, groupID, userID) + return am.GetGroupFunc(ctx, accountId, groupID, userID) } return nil, status.Errorf(codes.Unimplemented, "method GetGroup is not implemented") } // GetAllGroups mock implementation of GetAllGroups from server.AccountManager interface -func (am *MockAccountManager) GetAllGroups(accountID, userID string) ([]*group.Group, error) { +func (am *MockAccountManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*group.Group, error) { if am.GetAllGroupsFunc != nil { - return am.GetAllGroupsFunc(accountID, userID) + return am.GetAllGroupsFunc(ctx, accountID, userID) } return nil, status.Errorf(codes.Unimplemented, "method GetAllGroups is not implemented") } // GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface -func (am *MockAccountManager) GetUsersFromAccount(accountID string, userID string) ([]*server.UserInfo, error) { +func (am *MockAccountManager) GetUsersFromAccount(ctx context.Context, accountID string, userID string) ([]*server.UserInfo, error) { if am.GetUsersFromAccountFunc != nil { - return am.GetUsersFromAccountFunc(accountID, userID) + return am.GetUsersFromAccountFunc(ctx, accountID, userID) } return nil, status.Errorf(codes.Unimplemented, "method GetUsersFromAccount is not implemented") } // DeletePeer mock implementation of DeletePeer from server.AccountManager interface -func (am *MockAccountManager) DeletePeer(accountID, peerID, userID string) error { +func (am *MockAccountManager) DeletePeer(ctx context.Context, accountID, peerID, userID string) error { if am.DeletePeerFunc != nil { - return am.DeletePeerFunc(accountID, peerID, userID) + return am.DeletePeerFunc(ctx, accountID, peerID, userID) } return status.Errorf(codes.Unimplemented, "method DeletePeer is not implemented") } // GetOrCreateAccountByUser mock implementation of GetOrCreateAccountByUser from server.AccountManager interface func (am *MockAccountManager) GetOrCreateAccountByUser( - userId, domain string, + ctx context.Context, userId, domain string, ) (*server.Account, error) { if am.GetOrCreateAccountByUserFunc != nil { - return am.GetOrCreateAccountByUserFunc(userId, domain) + return am.GetOrCreateAccountByUserFunc(ctx, userId, domain) } return nil, status.Errorf( codes.Unimplemented, @@ -168,6 +170,7 @@ func (am *MockAccountManager) GetOrCreateAccountByUser( // CreateSetupKey mock implementation of CreateSetupKey from server.AccountManager interface func (am *MockAccountManager) CreateSetupKey( + ctx context.Context, accountID string, keyName string, keyType server.SetupKeyType, @@ -178,17 +181,17 @@ func (am *MockAccountManager) CreateSetupKey( ephemeral bool, ) (*server.SetupKey, error) { if am.CreateSetupKeyFunc != nil { - return am.CreateSetupKeyFunc(accountID, keyName, keyType, expiresIn, autoGroups, usageLimit, userID, ephemeral) + return am.CreateSetupKeyFunc(ctx, accountID, keyName, keyType, expiresIn, autoGroups, usageLimit, userID, ephemeral) } return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented") } // GetAccountByUserOrAccountID mock implementation of GetAccountByUserOrAccountID from server.AccountManager interface func (am *MockAccountManager) GetAccountByUserOrAccountID( - userId, accountId, domain string, + ctx context.Context, userId, accountId, domain string, ) (*server.Account, error) { if am.GetAccountByUserOrAccountIdFunc != nil { - return am.GetAccountByUserOrAccountIdFunc(userId, accountId, domain) + return am.GetAccountByUserOrAccountIdFunc(ctx, userId, accountId, domain) } return nil, status.Errorf( codes.Unimplemented, @@ -197,391 +200,392 @@ func (am *MockAccountManager) GetAccountByUserOrAccountID( } // MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface -func (am *MockAccountManager) MarkPeerConnected(peerKey string, connected bool, realIP net.IP, account *server.Account) error { +func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, account *server.Account) error { if am.MarkPeerConnectedFunc != nil { - return am.MarkPeerConnectedFunc(peerKey, connected, realIP) + return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP) } return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") } // GetAccountFromPAT mock implementation of GetAccountFromPAT from server.AccountManager interface -func (am *MockAccountManager) GetAccountFromPAT(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) { +func (am *MockAccountManager) GetAccountFromPAT(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) { if am.GetAccountFromPATFunc != nil { - return am.GetAccountFromPATFunc(pat) + return am.GetAccountFromPATFunc(ctx, pat) } return nil, nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromPAT is not implemented") } // DeleteAccount mock implementation of DeleteAccount from server.AccountManager interface -func (am *MockAccountManager) DeleteAccount(accountID, userID string) error { +func (am *MockAccountManager) DeleteAccount(ctx context.Context, accountID, userID string) error { if am.DeleteAccountFunc != nil { - return am.DeleteAccountFunc(accountID, userID) + return am.DeleteAccountFunc(ctx, accountID, userID) } return status.Errorf(codes.Unimplemented, "method DeleteAccount is not implemented") } // MarkPATUsed mock implementation of MarkPATUsed from server.AccountManager interface -func (am *MockAccountManager) MarkPATUsed(pat string) error { +func (am *MockAccountManager) MarkPATUsed(ctx context.Context, pat string) error { if am.MarkPATUsedFunc != nil { - return am.MarkPATUsedFunc(pat) + return am.MarkPATUsedFunc(ctx, pat) } return status.Errorf(codes.Unimplemented, "method MarkPATUsed is not implemented") } // CreatePAT mock implementation of GetPAT from server.AccountManager interface -func (am *MockAccountManager) CreatePAT(accountID string, initiatorUserID string, targetUserID string, name string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) { +func (am *MockAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, name string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) { if am.CreatePATFunc != nil { - return am.CreatePATFunc(accountID, initiatorUserID, targetUserID, name, expiresIn) + return am.CreatePATFunc(ctx, accountID, initiatorUserID, targetUserID, name, expiresIn) } return nil, status.Errorf(codes.Unimplemented, "method CreatePAT is not implemented") } // DeletePAT mock implementation of DeletePAT from server.AccountManager interface -func (am *MockAccountManager) DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error { +func (am *MockAccountManager) DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error { if am.DeletePATFunc != nil { - return am.DeletePATFunc(accountID, initiatorUserID, targetUserID, tokenID) + return am.DeletePATFunc(ctx, accountID, initiatorUserID, targetUserID, tokenID) } return status.Errorf(codes.Unimplemented, "method DeletePAT is not implemented") } // GetPAT mock implementation of GetPAT from server.AccountManager interface -func (am *MockAccountManager) GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) { +func (am *MockAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) { if am.GetPATFunc != nil { - return am.GetPATFunc(accountID, initiatorUserID, targetUserID, tokenID) + return am.GetPATFunc(ctx, accountID, initiatorUserID, targetUserID, tokenID) } return nil, status.Errorf(codes.Unimplemented, "method GetPAT is not implemented") } // GetAllPATs mock implementation of GetAllPATs from server.AccountManager interface -func (am *MockAccountManager) GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) { +func (am *MockAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) { if am.GetAllPATsFunc != nil { - return am.GetAllPATsFunc(accountID, initiatorUserID, targetUserID) + return am.GetAllPATsFunc(ctx, accountID, initiatorUserID, targetUserID) } return nil, status.Errorf(codes.Unimplemented, "method GetAllPATs is not implemented") } // GetNetworkMap mock implementation of GetNetworkMap from server.AccountManager interface -func (am *MockAccountManager) GetNetworkMap(peerKey string) (*server.NetworkMap, error) { +func (am *MockAccountManager) GetNetworkMap(ctx context.Context, peerKey string) (*server.NetworkMap, error) { if am.GetNetworkMapFunc != nil { - return am.GetNetworkMapFunc(peerKey) + return am.GetNetworkMapFunc(ctx, peerKey) } return nil, status.Errorf(codes.Unimplemented, "method GetNetworkMap is not implemented") } // GetPeerNetwork mock implementation of GetPeerNetwork from server.AccountManager interface -func (am *MockAccountManager) GetPeerNetwork(peerKey string) (*server.Network, error) { +func (am *MockAccountManager) GetPeerNetwork(ctx context.Context, peerKey string) (*server.Network, error) { if am.GetPeerNetworkFunc != nil { - return am.GetPeerNetworkFunc(peerKey) + return am.GetPeerNetworkFunc(ctx, peerKey) } return nil, status.Errorf(codes.Unimplemented, "method GetPeerNetwork is not implemented") } // AddPeer mock implementation of AddPeer from server.AccountManager interface func (am *MockAccountManager) AddPeer( + ctx context.Context, setupKey string, userId string, peer *nbpeer.Peer, ) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { if am.AddPeerFunc != nil { - return am.AddPeerFunc(setupKey, userId, peer) + return am.AddPeerFunc(ctx, setupKey, userId, peer) } return nil, nil, nil, status.Errorf(codes.Unimplemented, "method AddPeer is not implemented") } // GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface -func (am *MockAccountManager) GetGroupByName(accountID, groupName string) (*group.Group, error) { +func (am *MockAccountManager) GetGroupByName(ctx context.Context, accountID, groupName string) (*group.Group, error) { if am.GetGroupFunc != nil { - return am.GetGroupByNameFunc(accountID, groupName) + return am.GetGroupByNameFunc(ctx, accountID, groupName) } return nil, status.Errorf(codes.Unimplemented, "method GetGroupByName is not implemented") } // SaveGroup mock implementation of SaveGroup from server.AccountManager interface -func (am *MockAccountManager) SaveGroup(accountID, userID string, group *group.Group) error { +func (am *MockAccountManager) SaveGroup(ctx context.Context, accountID, userID string, group *group.Group) error { if am.SaveGroupFunc != nil { - return am.SaveGroupFunc(accountID, userID, group) + return am.SaveGroupFunc(ctx, accountID, userID, group) } return status.Errorf(codes.Unimplemented, "method SaveGroup is not implemented") } // DeleteGroup mock implementation of DeleteGroup from server.AccountManager interface -func (am *MockAccountManager) DeleteGroup(accountId, userId, groupID string) error { +func (am *MockAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error { if am.DeleteGroupFunc != nil { - return am.DeleteGroupFunc(accountId, userId, groupID) + return am.DeleteGroupFunc(ctx, accountId, userId, groupID) } return status.Errorf(codes.Unimplemented, "method DeleteGroup is not implemented") } // ListGroups mock implementation of ListGroups from server.AccountManager interface -func (am *MockAccountManager) ListGroups(accountID string) ([]*group.Group, error) { +func (am *MockAccountManager) ListGroups(ctx context.Context, accountID string) ([]*group.Group, error) { if am.ListGroupsFunc != nil { - return am.ListGroupsFunc(accountID) + return am.ListGroupsFunc(ctx, accountID) } return nil, status.Errorf(codes.Unimplemented, "method ListGroups is not implemented") } // GroupAddPeer mock implementation of GroupAddPeer from server.AccountManager interface -func (am *MockAccountManager) GroupAddPeer(accountID, groupID, peerID string) error { +func (am *MockAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { if am.GroupAddPeerFunc != nil { - return am.GroupAddPeerFunc(accountID, groupID, peerID) + return am.GroupAddPeerFunc(ctx, accountID, groupID, peerID) } return status.Errorf(codes.Unimplemented, "method GroupAddPeer is not implemented") } // GroupDeletePeer mock implementation of GroupDeletePeer from server.AccountManager interface -func (am *MockAccountManager) GroupDeletePeer(accountID, groupID, peerID string) error { +func (am *MockAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error { if am.GroupDeletePeerFunc != nil { - return am.GroupDeletePeerFunc(accountID, groupID, peerID) + return am.GroupDeletePeerFunc(ctx, accountID, groupID, peerID) } return status.Errorf(codes.Unimplemented, "method GroupDeletePeer is not implemented") } // DeleteRule mock implementation of DeleteRule from server.AccountManager interface -func (am *MockAccountManager) DeleteRule(accountID, ruleID, userID string) error { +func (am *MockAccountManager) DeleteRule(ctx context.Context, accountID, ruleID, userID string) error { if am.DeleteRuleFunc != nil { - return am.DeleteRuleFunc(accountID, ruleID, userID) + return am.DeleteRuleFunc(ctx, accountID, ruleID, userID) } return status.Errorf(codes.Unimplemented, "method DeleteRule is not implemented") } // GetPolicy mock implementation of GetPolicy from server.AccountManager interface -func (am *MockAccountManager) GetPolicy(accountID, policyID, userID string) (*server.Policy, error) { +func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error) { if am.GetPolicyFunc != nil { - return am.GetPolicyFunc(accountID, policyID, userID) + return am.GetPolicyFunc(ctx, accountID, policyID, userID) } return nil, status.Errorf(codes.Unimplemented, "method GetPolicy is not implemented") } // SavePolicy mock implementation of SavePolicy from server.AccountManager interface -func (am *MockAccountManager) SavePolicy(accountID, userID string, policy *server.Policy) error { +func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy) error { if am.SavePolicyFunc != nil { - return am.SavePolicyFunc(accountID, userID, policy) + return am.SavePolicyFunc(ctx, accountID, userID, policy) } return status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented") } // DeletePolicy mock implementation of DeletePolicy from server.AccountManager interface -func (am *MockAccountManager) DeletePolicy(accountID, policyID, userID string) error { +func (am *MockAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error { if am.DeletePolicyFunc != nil { - return am.DeletePolicyFunc(accountID, policyID, userID) + return am.DeletePolicyFunc(ctx, accountID, policyID, userID) } return status.Errorf(codes.Unimplemented, "method DeletePolicy is not implemented") } // ListPolicies mock implementation of ListPolicies from server.AccountManager interface -func (am *MockAccountManager) ListPolicies(accountID, userID string) ([]*server.Policy, error) { +func (am *MockAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*server.Policy, error) { if am.ListPoliciesFunc != nil { - return am.ListPoliciesFunc(accountID, userID) + return am.ListPoliciesFunc(ctx, accountID, userID) } return nil, status.Errorf(codes.Unimplemented, "method ListPolicies is not implemented") } // UpdatePeerMeta mock implementation of UpdatePeerMeta from server.AccountManager interface -func (am *MockAccountManager) UpdatePeerMeta(peerID string, meta nbpeer.PeerSystemMeta) error { +func (am *MockAccountManager) UpdatePeerMeta(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error { if am.UpdatePeerMetaFunc != nil { - return am.UpdatePeerMetaFunc(peerID, meta) + return am.UpdatePeerMetaFunc(ctx, peerID, meta) } return status.Errorf(codes.Unimplemented, "method UpdatePeerMeta is not implemented") } // GetUser mock implementation of GetUser from server.AccountManager interface -func (am *MockAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) (*server.User, error) { +func (am *MockAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) { if am.GetUserFunc != nil { - return am.GetUserFunc(claims) + return am.GetUserFunc(ctx, claims) } return nil, status.Errorf(codes.Unimplemented, "method GetUser is not implemented") } -func (am *MockAccountManager) ListUsers(accountID string) ([]*server.User, error) { +func (am *MockAccountManager) ListUsers(ctx context.Context, accountID string) ([]*server.User, error) { if am.ListUsersFunc != nil { - return am.ListUsersFunc(accountID) + return am.ListUsersFunc(ctx, accountID) } return nil, status.Errorf(codes.Unimplemented, "method ListUsers is not implemented") } // UpdatePeerSSHKey mocks UpdatePeerSSHKey function of the account manager -func (am *MockAccountManager) UpdatePeerSSHKey(peerID string, sshKey string) error { +func (am *MockAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error { if am.UpdatePeerSSHKeyFunc != nil { - return am.UpdatePeerSSHKeyFunc(peerID, sshKey) + return am.UpdatePeerSSHKeyFunc(ctx, peerID, sshKey) } return status.Errorf(codes.Unimplemented, "method UpdatePeerSSHKey is not implemented") } // UpdatePeer mocks UpdatePeerFunc function of the account manager -func (am *MockAccountManager) UpdatePeer(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) { +func (am *MockAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) { if am.UpdatePeerFunc != nil { - return am.UpdatePeerFunc(accountID, userID, peer) + return am.UpdatePeerFunc(ctx, accountID, userID, peer) } return nil, status.Errorf(codes.Unimplemented, "method UpdatePeer is not implemented") } // CreateRoute mock implementation of CreateRoute from server.AccountManager interface -func (am *MockAccountManager) CreateRoute(accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) { +func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) { if am.CreateRouteFunc != nil { - return am.CreateRouteFunc(accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, enabled, userID, keepRoute) + return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, enabled, userID, keepRoute) } return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented") } // GetRoute mock implementation of GetRoute from server.AccountManager interface -func (am *MockAccountManager) GetRoute(accountID string, routeID route.ID, userID string) (*route.Route, error) { +func (am *MockAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) { if am.GetRouteFunc != nil { - return am.GetRouteFunc(accountID, routeID, userID) + return am.GetRouteFunc(ctx, accountID, routeID, userID) } return nil, status.Errorf(codes.Unimplemented, "method GetRoute is not implemented") } // SaveRoute mock implementation of SaveRoute from server.AccountManager interface -func (am *MockAccountManager) SaveRoute(accountID string, userID string, route *route.Route) error { +func (am *MockAccountManager) SaveRoute(ctx context.Context, accountID string, userID string, route *route.Route) error { if am.SaveRouteFunc != nil { - return am.SaveRouteFunc(accountID, userID, route) + return am.SaveRouteFunc(ctx, accountID, userID, route) } return status.Errorf(codes.Unimplemented, "method SaveRoute is not implemented") } // DeleteRoute mock implementation of DeleteRoute from server.AccountManager interface -func (am *MockAccountManager) DeleteRoute(accountID string, routeID route.ID, userID string) error { +func (am *MockAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error { if am.DeleteRouteFunc != nil { - return am.DeleteRouteFunc(accountID, routeID, userID) + return am.DeleteRouteFunc(ctx, accountID, routeID, userID) } return status.Errorf(codes.Unimplemented, "method DeleteRoute is not implemented") } // ListRoutes mock implementation of ListRoutes from server.AccountManager interface -func (am *MockAccountManager) ListRoutes(accountID, userID string) ([]*route.Route, error) { +func (am *MockAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) { if am.ListRoutesFunc != nil { - return am.ListRoutesFunc(accountID, userID) + return am.ListRoutesFunc(ctx, accountID, userID) } return nil, status.Errorf(codes.Unimplemented, "method ListRoutes is not implemented") } // SaveSetupKey mocks SaveSetupKey of the AccountManager interface -func (am *MockAccountManager) SaveSetupKey(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) { +func (am *MockAccountManager) SaveSetupKey(ctx context.Context, accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) { if am.SaveSetupKeyFunc != nil { - return am.SaveSetupKeyFunc(accountID, key, userID) + return am.SaveSetupKeyFunc(ctx, accountID, key, userID) } return nil, status.Errorf(codes.Unimplemented, "method SaveSetupKey is not implemented") } // GetSetupKey mocks GetSetupKey of the AccountManager interface -func (am *MockAccountManager) GetSetupKey(accountID, userID, keyID string) (*server.SetupKey, error) { +func (am *MockAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error) { if am.GetSetupKeyFunc != nil { - return am.GetSetupKeyFunc(accountID, userID, keyID) + return am.GetSetupKeyFunc(ctx, accountID, userID, keyID) } return nil, status.Errorf(codes.Unimplemented, "method GetSetupKey is not implemented") } // ListSetupKeys mocks ListSetupKeys of the AccountManager interface -func (am *MockAccountManager) ListSetupKeys(accountID, userID string) ([]*server.SetupKey, error) { +func (am *MockAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*server.SetupKey, error) { if am.ListSetupKeysFunc != nil { - return am.ListSetupKeysFunc(accountID, userID) + return am.ListSetupKeysFunc(ctx, accountID, userID) } return nil, status.Errorf(codes.Unimplemented, "method ListSetupKeys is not implemented") } // SaveUser mocks SaveUser of the AccountManager interface -func (am *MockAccountManager) SaveUser(accountID, userID string, user *server.User) (*server.UserInfo, error) { +func (am *MockAccountManager) SaveUser(ctx context.Context, accountID, userID string, user *server.User) (*server.UserInfo, error) { if am.SaveUserFunc != nil { - return am.SaveUserFunc(accountID, userID, user) + return am.SaveUserFunc(ctx, accountID, userID, user) } return nil, status.Errorf(codes.Unimplemented, "method SaveUser is not implemented") } // SaveOrAddUser mocks SaveOrAddUser of the AccountManager interface -func (am *MockAccountManager) SaveOrAddUser(accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error) { +func (am *MockAccountManager) SaveOrAddUser(ctx context.Context, accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error) { if am.SaveOrAddUserFunc != nil { - return am.SaveOrAddUserFunc(accountID, userID, user, addIfNotExists) + return am.SaveOrAddUserFunc(ctx, accountID, userID, user, addIfNotExists) } return nil, status.Errorf(codes.Unimplemented, "method SaveOrAddUser is not implemented") } // DeleteUser mocks DeleteUser of the AccountManager interface -func (am *MockAccountManager) DeleteUser(accountID string, initiatorUserID string, targetUserID string) error { +func (am *MockAccountManager) DeleteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error { if am.DeleteUserFunc != nil { - return am.DeleteUserFunc(accountID, initiatorUserID, targetUserID) + return am.DeleteUserFunc(ctx, accountID, initiatorUserID, targetUserID) } return status.Errorf(codes.Unimplemented, "method DeleteUser is not implemented") } -func (am *MockAccountManager) InviteUser(accountID string, initiatorUserID string, targetUserID string) error { +func (am *MockAccountManager) InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error { if am.InviteUserFunc != nil { - return am.InviteUserFunc(accountID, initiatorUserID, targetUserID) + return am.InviteUserFunc(ctx, accountID, initiatorUserID, targetUserID) } return status.Errorf(codes.Unimplemented, "method InviteUser is not implemented") } // GetNameServerGroup mocks GetNameServerGroup of the AccountManager interface -func (am *MockAccountManager) GetNameServerGroup(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { +func (am *MockAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { if am.GetNameServerGroupFunc != nil { - return am.GetNameServerGroupFunc(accountID, userID, nsGroupID) + return am.GetNameServerGroupFunc(ctx, accountID, userID, nsGroupID) } return nil, nil } // CreateNameServerGroup mocks CreateNameServerGroup of the AccountManager interface -func (am *MockAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) { +func (am *MockAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) { if am.CreateNameServerGroupFunc != nil { - return am.CreateNameServerGroupFunc(accountID, name, description, nameServerList, groups, primary, domains, enabled, userID, searchDomainsEnabled) + return am.CreateNameServerGroupFunc(ctx, accountID, name, description, nameServerList, groups, primary, domains, enabled, userID, searchDomainsEnabled) } return nil, nil } // SaveNameServerGroup mocks SaveNameServerGroup of the AccountManager interface -func (am *MockAccountManager) SaveNameServerGroup(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error { +func (am *MockAccountManager) SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error { if am.SaveNameServerGroupFunc != nil { - return am.SaveNameServerGroupFunc(accountID, userID, nsGroupToSave) + return am.SaveNameServerGroupFunc(ctx, accountID, userID, nsGroupToSave) } return nil } // DeleteNameServerGroup mocks DeleteNameServerGroup of the AccountManager interface -func (am *MockAccountManager) DeleteNameServerGroup(accountID, nsGroupID, userID string) error { +func (am *MockAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error { if am.DeleteNameServerGroupFunc != nil { - return am.DeleteNameServerGroupFunc(accountID, nsGroupID, userID) + return am.DeleteNameServerGroupFunc(ctx, accountID, nsGroupID, userID) } return nil } // ListNameServerGroups mocks ListNameServerGroups of the AccountManager interface -func (am *MockAccountManager) ListNameServerGroups(accountID string, userID string) ([]*nbdns.NameServerGroup, error) { +func (am *MockAccountManager) ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) { if am.ListNameServerGroupsFunc != nil { - return am.ListNameServerGroupsFunc(accountID, userID) + return am.ListNameServerGroupsFunc(ctx, accountID, userID) } return nil, nil } // CreateUser mocks CreateUser of the AccountManager interface -func (am *MockAccountManager) CreateUser(accountID, userID string, invite *server.UserInfo) (*server.UserInfo, error) { +func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID string, invite *server.UserInfo) (*server.UserInfo, error) { if am.CreateUserFunc != nil { - return am.CreateUserFunc(accountID, userID, invite) + return am.CreateUserFunc(ctx, accountID, userID, invite) } return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented") } // GetAccountFromToken mocks GetAccountFromToken of the AccountManager interface -func (am *MockAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, +func (am *MockAccountManager) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error, ) { if am.GetAccountFromTokenFunc != nil { - return am.GetAccountFromTokenFunc(claims) + return am.GetAccountFromTokenFunc(ctx, claims) } return nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented") } -func (am *MockAccountManager) CheckUserAccessByJWTGroups(claims jwtclaims.AuthorizationClaims) error { +func (am *MockAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error { if am.CheckUserAccessByJWTGroupsFunc != nil { - return am.CheckUserAccessByJWTGroupsFunc(claims) + return am.CheckUserAccessByJWTGroupsFunc(ctx, claims) } return status.Errorf(codes.Unimplemented, "method CheckUserAccessByJWTGroups is not implemented") } // GetPeers mocks GetPeers of the AccountManager interface -func (am *MockAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.Peer, error) { +func (am *MockAccountManager) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { if am.GetPeersFunc != nil { - return am.GetPeersFunc(accountID, userID) + return am.GetPeersFunc(ctx, accountID, userID) } return nil, status.Errorf(codes.Unimplemented, "method GetPeers is not implemented") } @@ -595,57 +599,57 @@ func (am *MockAccountManager) GetDNSDomain() string { } // GetEvents mocks GetEvents of the AccountManager interface -func (am *MockAccountManager) GetEvents(accountID, userID string) ([]*activity.Event, error) { +func (am *MockAccountManager) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) { if am.GetEventsFunc != nil { - return am.GetEventsFunc(accountID, userID) + return am.GetEventsFunc(ctx, accountID, userID) } return nil, status.Errorf(codes.Unimplemented, "method GetEvents is not implemented") } // GetDNSSettings mocks GetDNSSettings of the AccountManager interface -func (am *MockAccountManager) GetDNSSettings(accountID string, userID string) (*server.DNSSettings, error) { +func (am *MockAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*server.DNSSettings, error) { if am.GetDNSSettingsFunc != nil { - return am.GetDNSSettingsFunc(accountID, userID) + return am.GetDNSSettingsFunc(ctx, accountID, userID) } return nil, status.Errorf(codes.Unimplemented, "method GetDNSSettings is not implemented") } // SaveDNSSettings mocks SaveDNSSettings of the AccountManager interface -func (am *MockAccountManager) SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *server.DNSSettings) error { +func (am *MockAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *server.DNSSettings) error { if am.SaveDNSSettingsFunc != nil { - return am.SaveDNSSettingsFunc(accountID, userID, dnsSettingsToSave) + return am.SaveDNSSettingsFunc(ctx, accountID, userID, dnsSettingsToSave) } return status.Errorf(codes.Unimplemented, "method SaveDNSSettings is not implemented") } // GetPeer mocks GetPeer of the AccountManager interface -func (am *MockAccountManager) GetPeer(accountID, peerID, userID string) (*nbpeer.Peer, error) { +func (am *MockAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) { if am.GetPeerFunc != nil { - return am.GetPeerFunc(accountID, peerID, userID) + return am.GetPeerFunc(ctx, accountID, peerID, userID) } return nil, status.Errorf(codes.Unimplemented, "method GetPeer is not implemented") } // UpdateAccountSettings mocks UpdateAccountSettings of the AccountManager interface -func (am *MockAccountManager) UpdateAccountSettings(accountID, userID string, newSettings *server.Settings) (*server.Account, error) { +func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) { if am.UpdateAccountSettingsFunc != nil { - return am.UpdateAccountSettingsFunc(accountID, userID, newSettings) + return am.UpdateAccountSettingsFunc(ctx, accountID, userID, newSettings) } return nil, status.Errorf(codes.Unimplemented, "method UpdateAccountSettings is not implemented") } // LoginPeer mocks LoginPeer of the AccountManager interface -func (am *MockAccountManager) LoginPeer(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { +func (am *MockAccountManager) LoginPeer(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { if am.LoginPeerFunc != nil { - return am.LoginPeerFunc(login) + return am.LoginPeerFunc(ctx, login) } return nil, nil, nil, status.Errorf(codes.Unimplemented, "method LoginPeer is not implemented") } // SyncPeer mocks SyncPeer of the AccountManager interface -func (am *MockAccountManager) SyncPeer(sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { +func (am *MockAccountManager) SyncPeer(ctx context.Context, sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { if am.SyncPeerFunc != nil { - return am.SyncPeerFunc(sync, account) + return am.SyncPeerFunc(ctx, sync, account) } return nil, nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented") } @@ -667,9 +671,9 @@ func (am *MockAccountManager) HasConnectedChannel(peerID string) bool { } // StoreEvent mocks StoreEvent of the AccountManager interface -func (am *MockAccountManager) StoreEvent(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { +func (am *MockAccountManager) StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { if am.StoreEventFunc != nil { - am.StoreEventFunc(initiatorID, targetID, accountID, activityID, meta) + am.StoreEventFunc(ctx, initiatorID, targetID, accountID, activityID, meta) } } @@ -682,35 +686,35 @@ func (am *MockAccountManager) GetExternalCacheManager() server.ExternalCacheMana } // GetPostureChecks mocks GetPostureChecks of the AccountManager interface -func (am *MockAccountManager) GetPostureChecks(accountID, postureChecksID, userID string) (*posture.Checks, error) { +func (am *MockAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) { if am.GetPostureChecksFunc != nil { - return am.GetPostureChecksFunc(accountID, postureChecksID, userID) + return am.GetPostureChecksFunc(ctx, accountID, postureChecksID, userID) } return nil, status.Errorf(codes.Unimplemented, "method GetPostureChecks is not implemented") } // SavePostureChecks mocks SavePostureChecks of the AccountManager interface -func (am *MockAccountManager) SavePostureChecks(accountID, userID string, postureChecks *posture.Checks) error { +func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error { if am.SavePostureChecksFunc != nil { - return am.SavePostureChecksFunc(accountID, userID, postureChecks) + return am.SavePostureChecksFunc(ctx, accountID, userID, postureChecks) } return status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented") } // DeletePostureChecks mocks DeletePostureChecks of the AccountManager interface -func (am *MockAccountManager) DeletePostureChecks(accountID, postureChecksID, userID string) error { +func (am *MockAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error { if am.DeletePostureChecksFunc != nil { - return am.DeletePostureChecksFunc(accountID, postureChecksID, userID) + return am.DeletePostureChecksFunc(ctx, accountID, postureChecksID, userID) } return status.Errorf(codes.Unimplemented, "method DeletePostureChecks is not implemented") } // ListPostureChecks mocks ListPostureChecks of the AccountManager interface -func (am *MockAccountManager) ListPostureChecks(accountID, userID string) ([]*posture.Checks, error) { +func (am *MockAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) { if am.ListPostureChecksFunc != nil { - return am.ListPostureChecksFunc(accountID, userID) + return am.ListPostureChecksFunc(ctx, accountID, userID) } return nil, status.Errorf(codes.Unimplemented, "method ListPostureChecks is not implemented") } @@ -724,25 +728,25 @@ func (am *MockAccountManager) GetIdpManager() idp.Manager { } // UpdateIntegratedValidatorGroups mocks UpdateIntegratedApprovalGroups of the AccountManager interface -func (am *MockAccountManager) UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error { +func (am *MockAccountManager) UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error { if am.UpdateIntegratedValidatorGroupsFunc != nil { - return am.UpdateIntegratedValidatorGroupsFunc(accountID, userID, groups) + return am.UpdateIntegratedValidatorGroupsFunc(ctx, accountID, userID, groups) } return status.Errorf(codes.Unimplemented, "method UpdateIntegratedValidatorGroups is not implemented") } // GroupValidation mocks GroupValidation of the AccountManager interface -func (am *MockAccountManager) GroupValidation(accountId string, groups []string) (bool, error) { +func (am *MockAccountManager) GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) { if am.GroupValidationFunc != nil { - return am.GroupValidationFunc(accountId, groups) + return am.GroupValidationFunc(ctx, accountId, groups) } return false, status.Errorf(codes.Unimplemented, "method GroupValidation is not implemented") } // SyncPeerMeta mocks SyncPeerMeta of the AccountManager interface -func (am *MockAccountManager) SyncPeerMeta(peerPubKey string, meta nbpeer.PeerSystemMeta) error { +func (am *MockAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error { if am.SyncPeerMetaFunc != nil { - return am.SyncPeerMetaFunc(peerPubKey, meta) + return am.SyncPeerMetaFunc(ctx, peerPubKey, meta) } return status.Errorf(codes.Unimplemented, "method SyncPeerMeta is not implemented") } @@ -754,3 +758,11 @@ func (am *MockAccountManager) FindExistingPostureCheck(accountID string, checks } return nil, status.Errorf(codes.Unimplemented, "method FindExistingPostureCheck is not implemented") } + +// GetAccountIDForPeerKey mocks GetAccountIDForPeerKey of the AccountManager interface +func (am *MockAccountManager) GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error) { + if am.GetAccountIDForPeerKeyFunc != nil { + return am.GetAccountIDForPeerKeyFunc(ctx, peerKey) + } + return "", status.Errorf(codes.Unimplemented, "method GetAccountIDForPeerKey is not implemented") +} diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 44d231c3e..f8d644ded 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -1,6 +1,7 @@ package server import ( + "context" "errors" "regexp" "unicode/utf8" @@ -17,12 +18,12 @@ import ( const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$` // GetNameServerGroup gets a nameserver group object from account and nameserver group IDs -func (am *DefaultAccountManager) GetNameServerGroup(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { +func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -45,12 +46,12 @@ func (am *DefaultAccountManager) GetNameServerGroup(accountID, userID, nsGroupID } // CreateNameServerGroup creates and saves a new nameserver group -func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) { +func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -79,29 +80,29 @@ func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, d account.NameServerGroups[newNSGroup.ID] = newNSGroup account.Network.IncSerial() - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(ctx, account) if err != nil { return nil, err } - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) - am.StoreEvent(userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) + am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) return newNSGroup.Copy(), nil } // SaveNameServerGroup saves nameserver group -func (am *DefaultAccountManager) SaveNameServerGroup(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error { +func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error { - unlock := am.Store.AcquireAccountWriteLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() if nsGroupToSave == nil { return status.Errorf(status.InvalidArgument, "nameserver group provided is nil") } - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } @@ -114,25 +115,25 @@ func (am *DefaultAccountManager) SaveNameServerGroup(accountID, userID string, n account.NameServerGroups[nsGroupToSave.ID] = nsGroupToSave account.Network.IncSerial() - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(ctx, account) if err != nil { return err } - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) - am.StoreEvent(userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) + am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) return nil } // DeleteNameServerGroup deletes nameserver group with nsGroupID -func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID, userID string) error { +func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error { - unlock := am.Store.AcquireAccountWriteLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } @@ -144,25 +145,25 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID, use delete(account.NameServerGroups, nsGroupID) account.Network.IncSerial() - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(ctx, account) if err != nil { return err } - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) - am.StoreEvent(userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) + am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) return nil } // ListNameServerGroups returns a list of nameserver groups from account -func (am *DefaultAccountManager) ListNameServerGroups(accountID string, userID string) ([]*nbdns.NameServerGroup, error) { +func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 4e07943b3..dd7935fee 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -1,6 +1,7 @@ package server import ( + "context" "net/netip" "testing" @@ -383,6 +384,7 @@ func TestCreateNameServerGroup(t *testing.T) { } outNSGroup, err := am.CreateNameServerGroup( + context.Background(), account.Id, testCase.inputArgs.name, testCase.inputArgs.description, @@ -611,7 +613,7 @@ func TestSaveNameServerGroup(t *testing.T) { account.NameServerGroups[testCase.existingNSGroup.ID] = testCase.existingNSGroup - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(context.Background(), account) if err != nil { t.Error("account should be saved") } @@ -646,7 +648,7 @@ func TestSaveNameServerGroup(t *testing.T) { } } - err = am.SaveNameServerGroup(account.Id, userID, nsGroupToSave) + err = am.SaveNameServerGroup(context.Background(), account.Id, userID, nsGroupToSave) testCase.errFunc(t, err) @@ -654,7 +656,7 @@ func TestSaveNameServerGroup(t *testing.T) { return } - account, err = am.Store.GetAccount(account.Id) + account, err = am.Store.GetAccount(context.Background(), account.Id) if err != nil { t.Fatal(err) } @@ -705,17 +707,17 @@ func TestDeleteNameServerGroup(t *testing.T) { account.NameServerGroups[testingNSGroup.ID] = testingNSGroup - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(context.Background(), account) if err != nil { t.Error("failed to save account") } - err = am.DeleteNameServerGroup(account.Id, testingNSGroup.ID, userID) + err = am.DeleteNameServerGroup(context.Background(), account.Id, testingNSGroup.ID, userID) if err != nil { t.Error("deleting nameserver group failed with error: ", err) } - savedAccount, err := am.Store.GetAccount(account.Id) + savedAccount, err := am.Store.GetAccount(context.Background(), account.Id) if err != nil { t.Error("failed to retrieve saved account with error: ", err) } @@ -738,7 +740,7 @@ func TestGetNameServerGroup(t *testing.T) { t.Error("failed to init testing account") } - foundGroup, err := am.GetNameServerGroup(account.Id, testUserID, existingNSGroupID) + foundGroup, err := am.GetNameServerGroup(context.Background(), account.Id, testUserID, existingNSGroupID) if err != nil { t.Error("getting existing nameserver group failed with error: ", err) } @@ -747,7 +749,7 @@ func TestGetNameServerGroup(t *testing.T) { t.Error("got a nil group while getting nameserver group with ID") } - _, err = am.GetNameServerGroup(account.Id, testUserID, "not existing") + _, err = am.GetNameServerGroup(context.Background(), account.Id, testUserID, "not existing") if err == nil { t.Error("getting not existing nameserver group should return error, got nil") } @@ -760,13 +762,13 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { return nil, err } eventStore := &activity.InMemoryEventStore{} - return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}) + return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}) } func createNSStore(t *testing.T) (Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromJson(dataDir) + store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir) if err != nil { return nil, err } @@ -829,7 +831,7 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error userID := testUserID domain := "example.com" - account := newAccountWithId(accountID, userID, domain) + account := newAccountWithId(context.Background(), accountID, userID, domain) account.NameServerGroups[existingNSGroup.ID] = &existingNSGroup @@ -846,16 +848,16 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error account.Groups[newGroup1.ID] = newGroup1 account.Groups[newGroup2.ID] = newGroup2 - err := am.Store.SaveAccount(account) + err := am.Store.SaveAccount(context.Background(), account) if err != nil { return nil, err } - _, _, _, err = am.AddPeer("", userID, peer1) + _, _, _, err = am.AddPeer(context.Background(), "", userID, peer1) if err != nil { return nil, err } - _, _, _, err = am.AddPeer("", userID, peer2) + _, _, _, err = am.AddPeer(context.Background(), "", userID, peer2) if err != nil { return nil, err } diff --git a/management/server/peer.go b/management/server/peer.go index fa482eec0..b8605fbb7 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -1,6 +1,7 @@ package server import ( + "context" "fmt" "net" "strings" @@ -45,8 +46,8 @@ type PeerLogin struct { // GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if // the current user is not an admin. -func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.Peer, error) { - account, err := am.Store.GetAccount(accountID) +func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -79,7 +80,7 @@ func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.P // fetch all the peers that have access to the user's peers for _, peer := range peers { - aclPeers, _ := account.getPeerConnectionResources(peer.ID, approvedPeersMap) + aclPeers, _ := account.getPeerConnectionResources(ctx, peer.ID, approvedPeersMap) for _, p := range aclPeers { peersMap[p.ID] = p } @@ -94,7 +95,7 @@ func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.P } // MarkPeerConnected marks peer as connected (true) or disconnected (false) -func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected bool, realIP net.IP, account *Account) error { +func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, account *Account) error { peer, err := account.FindPeerByPubKey(peerPubKey) if err != nil { return err @@ -113,7 +114,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected if am.geo != nil && realIP != nil { location, err := am.geo.Lookup(realIP) if err != nil { - log.Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err) + log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err) } else { peer.Location.ConnectionIP = realIP peer.Location.CountryCode = location.Country.ISOCode @@ -121,7 +122,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected peer.Location.GeoNameID = location.City.GeonameID err = am.Store.SavePeerLocation(account.Id, peer) if err != nil { - log.Warnf("could not store location for peer %s: %s", peer.ID, err) + log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err) } } } @@ -134,24 +135,24 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected } if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled { - am.checkAndSchedulePeerLoginExpiration(account) + am.checkAndSchedulePeerLoginExpiration(ctx, account) } if oldStatus.LoginExpired { // we need to update other peers because when peer login expires all other peers are notified to disconnect from // the expired one. Here we notify them that connection is now allowed again. - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) } return nil } // UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, and Peer.LoginExpirationEnabled can be updated. -func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -161,7 +162,7 @@ func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *nb return nil, status.Errorf(status.NotFound, "peer %s not found", update.ID) } - update, err = am.integratedPeerValidator.ValidatePeer(update, peer, userID, accountID, am.GetDNSDomain(), account.GetPeerGroupsList(peer.ID), account.Settings.Extra) + update, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), account.GetPeerGroupsList(peer.ID), account.Settings.Extra) if err != nil { return nil, err } @@ -172,7 +173,7 @@ func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *nb if !update.SSHEnabled { event = activity.PeerSSHDisabled } - am.StoreEvent(userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) + am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) } if peer.Name != update.Name { @@ -187,7 +188,7 @@ func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *nb peer.DNSLabel = newLabel - am.StoreEvent(userID, peer.ID, accountID, activity.PeerRenamed, peer.EventMeta(am.GetDNSDomain())) + am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRenamed, peer.EventMeta(am.GetDNSDomain())) } if peer.LoginExpirationEnabled != update.LoginExpirationEnabled { @@ -202,27 +203,27 @@ func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *nb if !update.LoginExpirationEnabled { event = activity.PeerLoginExpirationDisabled } - am.StoreEvent(userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) + am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled { - am.checkAndSchedulePeerLoginExpiration(account) + am.checkAndSchedulePeerLoginExpiration(ctx, account) } } account.UpdatePeer(peer) - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(ctx, account) if err != nil { return nil, err } - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) return peer, nil } // deletePeers will delete all specified peers and send updates to the remote peers. Don't call without acquiring account lock -func (am *DefaultAccountManager) deletePeers(account *Account, peerIDs []string, userID string) error { +func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Account, peerIDs []string, userID string) error { // the first loop is needed to ensure all peers present under the account before modifying, otherwise // we might have some inconsistencies @@ -239,13 +240,13 @@ func (am *DefaultAccountManager) deletePeers(account *Account, peerIDs []string, // the 2nd loop performs the actual modification for _, peer := range peers { - err := am.integratedPeerValidator.PeerDeleted(account.Id, peer.ID) + err := am.integratedPeerValidator.PeerDeleted(ctx, account.Id, peer.ID) if err != nil { return err } account.DeletePeer(peer.ID) - am.peersUpdateManager.SendUpdate(peer.ID, + am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{ Update: &proto.SyncResponse{ // fill those field for backward compatibility @@ -261,41 +262,41 @@ func (am *DefaultAccountManager) deletePeers(account *Account, peerIDs []string, }, }, }) - am.peersUpdateManager.CloseChannel(peer.ID) - am.StoreEvent(userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain())) + am.peersUpdateManager.CloseChannel(ctx, peer.ID) + am.StoreEvent(ctx, userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain())) } return nil } // DeletePeer removes peer from the account by its IP -func (am *DefaultAccountManager) DeletePeer(accountID, peerID, userID string) error { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peerID, userID string) error { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } - err = am.deletePeers(account, []string{peerID}, userID) + err = am.deletePeers(ctx, account, []string{peerID}, userID) if err != nil { return err } - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(ctx, account) if err != nil { return err } - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) return nil } // GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result) -func (am *DefaultAccountManager) GetNetworkMap(peerID string) (*NetworkMap, error) { - account, err := am.Store.GetAccountByPeerID(peerID) +func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID string) (*NetworkMap, error) { + account, err := am.Store.GetAccountByPeerID(ctx, peerID) if err != nil { return nil, err } @@ -314,12 +315,12 @@ func (am *DefaultAccountManager) GetNetworkMap(peerID string) (*NetworkMap, erro if err != nil { return nil, err } - return account.GetPeerNetworkMap(peer.ID, am.dnsDomain, validatedPeers), nil + return account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, validatedPeers), nil } // GetPeerNetwork returns the Network for a given peer -func (am *DefaultAccountManager) GetPeerNetwork(peerID string) (*Network, error) { - account, err := am.Store.GetAccountByPeerID(peerID) +func (am *DefaultAccountManager) GetPeerNetwork(ctx context.Context, peerID string) (*Network, error) { + account, err := am.Store.GetAccountByPeerID(ctx, peerID) if err != nil { return nil, err } @@ -334,7 +335,7 @@ func (am *DefaultAccountManager) GetPeerNetwork(peerID string) (*Network, error) // to it. We also add the User ID to the peer metadata to identify registrant. If no userID provided, then fail with status.PermissionDenied // Each new Peer will be assigned a new next net.IP from the Account.Network and Account.Network.LastIP will be updated (IP's are not reused). // The peer property is just a placeholder for the Peer properties to pass further -func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { +func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { if setupKey == "" && userID == "" { // no auth method provided => reject access return nil, nil, nil, status.Errorf(status.Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login") @@ -348,13 +349,13 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P addedByUser = true accountID, err = am.Store.GetAccountIDByUserID(userID) } else { - accountID, err = am.Store.GetAccountIDBySetupKey(setupKey) + accountID, err = am.Store.GetAccountIDBySetupKey(ctx, setupKey) } if err != nil { return nil, nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found") } - unlock := am.Store.AcquireAccountWriteLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer func() { if unlock != nil { unlock() @@ -363,14 +364,14 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P var account *Account // ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account) - account, err = am.Store.GetAccount(accountID) + account, err = am.Store.GetAccount(ctx, accountID) if err != nil { return nil, nil, nil, err } if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" { if am.idpManager != nil { - userdata, err := am.lookupUserInCache(userID, account) + userdata, err := am.lookupUserInCache(ctx, userID, account) if err == nil && userdata != nil { peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0]) } @@ -479,7 +480,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P } } - newPeer = am.integratedPeerValidator.PreparePeer(account.Id, newPeer, account.GetPeerGroupsList(newPeer.ID), account.Settings.Extra) + newPeer = am.integratedPeerValidator.PreparePeer(ctx, account.Id, newPeer, account.GetPeerGroupsList(newPeer.ID), account.Settings.Extra) if addedByUser { user, err := account.FindUser(userID) @@ -491,7 +492,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P account.Peers[newPeer.ID] = newPeer account.Network.IncSerial() - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(ctx, account) if err != nil { return nil, nil, nil, err } @@ -506,9 +507,9 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P opEvent.Meta["setup_key_name"] = setupKeyName } - am.StoreEvent(opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) + am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) approvedPeersMap, err := am.GetValidatedPeers(account) if err != nil { @@ -516,12 +517,12 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P } postureChecks := am.getPeerPostureChecks(account, peer) - networkMap := account.GetPeerNetworkMap(newPeer.ID, am.dnsDomain, approvedPeersMap) + networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, am.dnsDomain, approvedPeersMap) return newPeer, networkMap, postureChecks, nil } // SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible -func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { +func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey) if err != nil { return nil, nil, nil, status.NewPeerNotRegisteredError() @@ -532,23 +533,23 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbp return nil, nil, nil, err } - if peerLoginExpired(peer, account.Settings) { + if peerLoginExpired(ctx, peer, account.Settings) { return nil, nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more") } peer, updated := updatePeerMeta(peer, sync.Meta, account) if updated { - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(ctx, account) if err != nil { return nil, nil, nil, err } if sync.UpdateAccountPeers { - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) } } - peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra) + peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra) if err != nil { return nil, nil, nil, err } @@ -563,7 +564,7 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbp } if isStatusChanged { - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) } validPeersMap, err := am.GetValidatedPeers(account) @@ -572,13 +573,13 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbp } postureChecks = am.getPeerPostureChecks(account, peer) - return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, validPeersMap), postureChecks, nil + return peer, account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, validPeersMap), postureChecks, nil } // LoginPeer logs in or registers a peer. // If peer doesn't exist the function checks whether a setup key or a user is present and registers a new peer if so. -func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { - accountID, err := am.Store.GetAccountIDByPeerPubKey(login.WireGuardPubKey) +func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { + accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, login.WireGuardPubKey) if err != nil { if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound { // we couldn't find this peer by its public key which can mean that peer hasn't been registered yet. @@ -591,7 +592,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw if am.geo != nil && login.ConnectionIP != nil { location, err := am.geo.Lookup(login.ConnectionIP) if err != nil { - log.Warnf("failed to get location for new peer realip: [%s]: %v", login.ConnectionIP.String(), err) + log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", login.ConnectionIP.String(), err) } else { newPeer.Location.ConnectionIP = login.ConnectionIP newPeer.Location.CountryCode = location.Country.ISOCode @@ -601,19 +602,19 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw } } - return am.AddPeer(login.SetupKey, login.UserID, newPeer) + return am.AddPeer(ctx, login.SetupKey, login.UserID, newPeer) } - log.Errorf("failed while logging in peer %s: %v", login.WireGuardPubKey, err) + log.WithContext(ctx).Errorf("failed while logging in peer %s: %v", login.WireGuardPubKey, err) return nil, nil, nil, status.Errorf(status.Internal, "failed while logging in peer") } - peer, err := am.Store.GetPeerByPeerPubKey(login.WireGuardPubKey) + peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey) if err != nil { return nil, nil, nil, status.NewPeerNotRegisteredError() } - accSettings, err := am.Store.GetAccountSettings(accountID) + accSettings, err := am.Store.GetAccountSettings(ctx, accountID) if err != nil { return nil, nil, nil, status.Errorf(status.Internal, "failed to get account settings: %s", err) } @@ -621,30 +622,30 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw var isWriteLock bool // duplicated logic from after the lock to have an early exit - expired := peerLoginExpired(peer, accSettings) + expired := peerLoginExpired(ctx, peer, accSettings) switch { case expired: - if err := checkAuth(login.UserID, peer); err != nil { + if err := checkAuth(ctx, login.UserID, peer); err != nil { return nil, nil, nil, err } isWriteLock = true - log.Debugf("peer login expired, acquiring write lock") + log.WithContext(ctx).Debugf("peer login expired, acquiring write lock") case peer.UpdateMetaIfNew(login.Meta): isWriteLock = true - log.Debugf("peer changed meta, acquiring write lock") + log.WithContext(ctx).Debugf("peer changed meta, acquiring write lock") default: isWriteLock = false - log.Debugf("peer meta is the same, acquiring read lock") + log.WithContext(ctx).Debugf("peer meta is the same, acquiring read lock") } var unlock func() if isWriteLock { - unlock = am.Store.AcquireAccountWriteLock(accountID) + unlock = am.Store.AcquireAccountWriteLock(ctx, accountID) } else { - unlock = am.Store.AcquireAccountReadLock(accountID) + unlock = am.Store.AcquireAccountReadLock(ctx, accountID) } defer func() { if unlock != nil { @@ -653,7 +654,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw }() // fetch the account from the store once more after acquiring lock to avoid concurrent updates inconsistencies - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, nil, nil, err } @@ -671,8 +672,8 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw // this flag prevents unnecessary calls to the persistent store. shouldStoreAccount := false updateRemotePeers := false - if peerLoginExpired(peer, account.Settings) { - err = checkAuth(login.UserID, peer) + if peerLoginExpired(ctx, peer, account.Settings) { + err = checkAuth(ctx, login.UserID, peer) if err != nil { return nil, nil, nil, err } @@ -689,10 +690,10 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw } user.updateLastLogin(peer.LastLogin) - am.StoreEvent(login.UserID, peer.ID, account.Id, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain())) + am.StoreEvent(ctx, login.UserID, peer.ID, account.Id, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain())) } - isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra) + isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra) if err != nil { return nil, nil, nil, err } @@ -701,17 +702,17 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw shouldStoreAccount = true } - peer, err = am.checkAndUpdatePeerSSHKey(peer, account, login.SSHKey) + peer, err = am.checkAndUpdatePeerSSHKey(ctx, peer, account, login.SSHKey) if err != nil { return nil, nil, nil, err } if shouldStoreAccount { if !isWriteLock { - log.Errorf("account %s should be stored but is not write locked", accountID) + log.WithContext(ctx).Errorf("account %s should be stored but is not write locked", accountID) return nil, nil, nil, status.Errorf(status.Internal, "account should be stored but is not write locked") } - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(ctx, account) if err != nil { return nil, nil, nil, err } @@ -720,7 +721,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw unlock = nil if updateRemotePeers || isStatusChanged { - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) } var postureChecks []*posture.Checks @@ -738,7 +739,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw } postureChecks = am.getPeerPostureChecks(account, peer) - return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap), postureChecks, nil + return peer, account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, approvedPeersMap), postureChecks, nil } func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, account *Account) error { @@ -754,23 +755,23 @@ func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, account *Account) error { return nil } -func checkAuth(loginUserID string, peer *nbpeer.Peer) error { +func checkAuth(ctx context.Context, loginUserID string, peer *nbpeer.Peer) error { if loginUserID == "" { // absence of a user ID indicates that JWT wasn't provided. return status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more") } if peer.UserID != loginUserID { - log.Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, loginUserID) + log.WithContext(ctx).Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, loginUserID) return status.Errorf(status.Unauthenticated, "can't login") } return nil } -func peerLoginExpired(peer *nbpeer.Peer, settings *Settings) bool { +func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *Settings) bool { expired, expiresIn := peer.LoginExpired(settings.PeerLoginExpiration) expired = settings.PeerLoginExpirationEnabled && expired if expired || peer.Status.LoginExpired { - log.Debugf("peer's %s login expired %v ago", peer.ID, expiresIn) + log.WithContext(ctx).Debugf("peer's %s login expired %v ago", peer.ID, expiresIn) return true } return false @@ -781,48 +782,48 @@ func updatePeerLastLogin(peer *nbpeer.Peer, account *Account) { account.UpdatePeer(peer) } -func (am *DefaultAccountManager) checkAndUpdatePeerSSHKey(peer *nbpeer.Peer, account *Account, newSSHKey string) (*nbpeer.Peer, error) { +func (am *DefaultAccountManager) checkAndUpdatePeerSSHKey(ctx context.Context, peer *nbpeer.Peer, account *Account, newSSHKey string) (*nbpeer.Peer, error) { if len(newSSHKey) == 0 { - log.Debugf("no new SSH key provided for peer %s, skipping update", peer.ID) + log.WithContext(ctx).Debugf("no new SSH key provided for peer %s, skipping update", peer.ID) return peer, nil } if peer.SSHKey == newSSHKey { - log.Debugf("same SSH key provided for peer %s, skipping update", peer.ID) + log.WithContext(ctx).Debugf("same SSH key provided for peer %s, skipping update", peer.ID) return peer, nil } peer.SSHKey = newSSHKey account.UpdatePeer(peer) - err := am.Store.SaveAccount(account) + err := am.Store.SaveAccount(ctx, account) if err != nil { return nil, err } // trigger network map update - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) return peer, nil } // UpdatePeerSSHKey updates peer's public SSH key -func (am *DefaultAccountManager) UpdatePeerSSHKey(peerID string, sshKey string) error { +func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error { if sshKey == "" { - log.Debugf("empty SSH key provided for peer %s, skipping update", peerID) + log.WithContext(ctx).Debugf("empty SSH key provided for peer %s, skipping update", peerID) return nil } - account, err := am.Store.GetAccountByPeerID(peerID) + account, err := am.Store.GetAccountByPeerID(ctx, peerID) if err != nil { return err } - unlock := am.Store.AcquireAccountWriteLock(account.Id) + unlock := am.Store.AcquireAccountWriteLock(ctx, account.Id) defer unlock() // ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account) - account, err = am.Store.GetAccount(account.Id) + account, err = am.Store.GetAccount(ctx, account.Id) if err != nil { return err } @@ -833,30 +834,30 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(peerID string, sshKey string) } if peer.SSHKey == sshKey { - log.Debugf("same SSH key provided for peer %s, skipping update", peerID) + log.WithContext(ctx).Debugf("same SSH key provided for peer %s, skipping update", peerID) return nil } peer.SSHKey = sshKey account.UpdatePeer(peer) - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(ctx, account) if err != nil { return err } // trigger network map update - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) return nil } // GetPeer for a given accountID, peerID and userID error if not found. -func (am *DefaultAccountManager) GetPeer(accountID, peerID, userID string) (*nbpeer.Peer, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -893,7 +894,7 @@ func (am *DefaultAccountManager) GetPeer(accountID, peerID, userID string) (*nbp } for _, p := range userPeers { - aclPeers, _ := account.getPeerConnectionResources(p.ID, approvedPeersMap) + aclPeers, _ := account.getPeerConnectionResources(ctx, p.ID, approvedPeersMap) for _, aclPeer := range aclPeers { if aclPeer.ID == peerID { return peer, nil @@ -914,23 +915,23 @@ func updatePeerMeta(peer *nbpeer.Peer, meta nbpeer.PeerSystemMeta, account *Acco // updateAccountPeers updates all peers that belong to an account. // Should be called when changes have to be synced to peers. -func (am *DefaultAccountManager) updateAccountPeers(account *Account) { +func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account *Account) { peers := account.GetPeers() approvedPeersMap, err := am.GetValidatedPeers(account) if err != nil { - log.Errorf("failed send out updates to peers, failed to validate peer: %v", err) + log.WithContext(ctx).Errorf("failed send out updates to peers, failed to validate peer: %v", err) return } for _, peer := range peers { if !am.peersUpdateManager.HasChannel(peer.ID) { - log.Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID) + log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID) continue } postureChecks := am.getPeerPostureChecks(account, peer) - remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap) - update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks) - am.peersUpdateManager.SendUpdate(peer.ID, &UpdateMessage{Update: update}) + remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, approvedPeersMap) + update := toSyncResponse(ctx, nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks) + am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update}) } } diff --git a/management/server/peer_test.go b/management/server/peer_test.go index c5305cf5b..407877296 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -1,6 +1,7 @@ package server import ( + "context" "testing" "time" @@ -80,7 +81,7 @@ func TestAccountManager_GetNetworkMap(t *testing.T) { t.Fatal(err) } - setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userId, false) + setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userId, false) if err != nil { t.Fatal("error creating setup key") return @@ -92,7 +93,7 @@ func TestAccountManager_GetNetworkMap(t *testing.T) { return } - peer1, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ + peer1, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ Key: peerKey1.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"}, }) @@ -106,7 +107,7 @@ func TestAccountManager_GetNetworkMap(t *testing.T) { t.Fatal(err) return } - _, _, _, err = manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ + _, _, _, err = manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ Key: peerKey2.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, }) @@ -116,7 +117,7 @@ func TestAccountManager_GetNetworkMap(t *testing.T) { return } - networkMap, err := manager.GetNetworkMap(peer1.ID) + networkMap, err := manager.GetNetworkMap(context.Background(), peer1.ID) if err != nil { t.Fatal(err) return @@ -165,7 +166,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { return } - peer1, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ + peer1, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ Key: peerKey1.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"}, }) @@ -179,7 +180,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { t.Fatal(err) return } - peer2, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ + peer2, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ Key: peerKey2.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, }) @@ -188,13 +189,13 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { return } - policies, err := manager.ListPolicies(account.Id, userID) + policies, err := manager.ListPolicies(context.Background(), account.Id, userID) if err != nil { t.Errorf("expecting to get a list of rules, got failure %v", err) return } - err = manager.DeletePolicy(account.Id, policies[0].ID, userID) + err = manager.DeletePolicy(context.Background(), account.Id, policies[0].ID, userID) if err != nil { t.Errorf("expecting to delete 1 group, got failure %v", err) return @@ -213,12 +214,12 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { group1.Peers = append(group1.Peers, peer1.ID) group2.Peers = append(group2.Peers, peer2.ID) - err = manager.SaveGroup(account.Id, userID, &group1) + err = manager.SaveGroup(context.Background(), account.Id, userID, &group1) if err != nil { t.Errorf("expecting group1 to be added, got failure %v", err) return } - err = manager.SaveGroup(account.Id, userID, &group2) + err = manager.SaveGroup(context.Background(), account.Id, userID, &group2) if err != nil { t.Errorf("expecting group2 to be added, got failure %v", err) return @@ -235,13 +236,13 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { Action: PolicyTrafficActionAccept, }, } - err = manager.SavePolicy(account.Id, userID, &policy) + err = manager.SavePolicy(context.Background(), account.Id, userID, &policy) if err != nil { t.Errorf("expecting rule to be added, got failure %v", err) return } - networkMap1, err := manager.GetNetworkMap(peer1.ID) + networkMap1, err := manager.GetNetworkMap(context.Background(), peer1.ID) if err != nil { t.Fatal(err) return @@ -264,7 +265,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { ) } - networkMap2, err := manager.GetNetworkMap(peer2.ID) + networkMap2, err := manager.GetNetworkMap(context.Background(), peer2.ID) if err != nil { t.Fatal(err) return @@ -283,13 +284,13 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { } policy.Enabled = false - err = manager.SavePolicy(account.Id, userID, &policy) + err = manager.SavePolicy(context.Background(), account.Id, userID, &policy) if err != nil { t.Errorf("expecting rule to be added, got failure %v", err) return } - networkMap1, err = manager.GetNetworkMap(peer1.ID) + networkMap1, err = manager.GetNetworkMap(context.Background(), peer1.ID) if err != nil { t.Fatal(err) return @@ -304,7 +305,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { return } - networkMap2, err = manager.GetNetworkMap(peer2.ID) + networkMap2, err = manager.GetNetworkMap(context.Background(), peer2.ID) if err != nil { t.Fatal(err) return @@ -329,7 +330,7 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) { t.Fatal(err) } - setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userId, false) + setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userId, false) if err != nil { t.Fatal("error creating setup key") return @@ -341,7 +342,7 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) { return } - peer1, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ + peer1, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ Key: peerKey1.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"}, }) @@ -355,7 +356,7 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) { t.Fatal(err) return } - _, _, _, err = manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ + _, _, _, err = manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ Key: peerKey2.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, }) @@ -365,7 +366,7 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) { return } - network, err := manager.GetPeerNetwork(peer1.ID) + network, err := manager.GetPeerNetwork(context.Background(), peer1.ID) if err != nil { t.Fatal(err) return @@ -387,21 +388,21 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { accountID := "test_account" adminUser := "account_creator" someUser := "some_user" - account := newAccountWithId(accountID, adminUser, "") + account := newAccountWithId(context.Background(), accountID, adminUser, "") account.Users[someUser] = &User{ Id: someUser, Role: UserRoleUser, } account.Settings.RegularUsersViewBlocked = false - err = manager.Store.SaveAccount(account) + err = manager.Store.SaveAccount(context.Background(), account) if err != nil { t.Fatal(err) return } // two peers one added by a regular user and one with a setup key - setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, adminUser, false) + setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, adminUser, false) if err != nil { t.Fatal("error creating setup key") return @@ -413,7 +414,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { return } - peer1, _, _, err := manager.AddPeer("", someUser, &nbpeer.Peer{ + peer1, _, _, err := manager.AddPeer(context.Background(), "", someUser, &nbpeer.Peer{ Key: peerKey1.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, }) @@ -429,7 +430,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { } // the second peer added with a setup key - peer2, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{ + peer2, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ Key: peerKey2.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, }) @@ -439,7 +440,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { } // the user can see its own peer - peer, err := manager.GetPeer(accountID, peer1.ID, someUser) + peer, err := manager.GetPeer(context.Background(), accountID, peer1.ID, someUser) if err != nil { t.Fatal(err) return @@ -447,7 +448,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { assert.NotNil(t, peer) // the user can see peer2 because peer1 of the user has access to peer2 due to the All group and the default rule 0 all-to-all access - peer, err = manager.GetPeer(accountID, peer2.ID, someUser) + peer, err = manager.GetPeer(context.Background(), accountID, peer2.ID, someUser) if err != nil { t.Fatal(err) return @@ -456,7 +457,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { // delete the all-to-all policy so that user's peer1 has no access to peer2 for _, policy := range account.Policies { - err = manager.DeletePolicy(accountID, policy.ID, adminUser) + err = manager.DeletePolicy(context.Background(), accountID, policy.ID, adminUser) if err != nil { t.Fatal(err) return @@ -464,18 +465,18 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { } // at this point the user can't see the details of peer2 - peer, err = manager.GetPeer(accountID, peer2.ID, someUser) //nolint + peer, err = manager.GetPeer(context.Background(), accountID, peer2.ID, someUser) //nolint assert.Error(t, err) // admin users can always access all the peers - peer, err = manager.GetPeer(accountID, peer1.ID, adminUser) + peer, err = manager.GetPeer(context.Background(), accountID, peer1.ID, adminUser) if err != nil { t.Fatal(err) return } assert.NotNil(t, peer) - peer, err = manager.GetPeer(accountID, peer2.ID, adminUser) + peer, err = manager.GetPeer(context.Background(), accountID, peer2.ID, adminUser) if err != nil { t.Fatal(err) return @@ -574,7 +575,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) { accountID := "test_account" adminUser := "account_creator" someUser := "some_user" - account := newAccountWithId(accountID, adminUser, "") + account := newAccountWithId(context.Background(), accountID, adminUser, "") account.Users[someUser] = &User{ Id: someUser, Role: testCase.role, @@ -583,7 +584,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) { account.Policies = []*Policy{} account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings - err = manager.Store.SaveAccount(account) + err = manager.Store.SaveAccount(context.Background(), account) if err != nil { t.Fatal(err) return @@ -601,7 +602,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) { return } - _, _, _, err = manager.AddPeer("", someUser, &nbpeer.Peer{ + _, _, _, err = manager.AddPeer(context.Background(), "", someUser, &nbpeer.Peer{ Key: peerKey1.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"}, }) @@ -610,7 +611,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) { return } - _, _, _, err = manager.AddPeer("", adminUser, &nbpeer.Peer{ + _, _, _, err = manager.AddPeer(context.Background(), "", adminUser, &nbpeer.Peer{ Key: peerKey2.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, }) @@ -619,7 +620,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) { return } - peers, err := manager.GetPeers(accountID, someUser) + peers, err := manager.GetPeers(context.Background(), accountID, someUser) if err != nil { t.Fatal(err) return diff --git a/management/server/policy.go b/management/server/policy.go index 5206df9e9..a70d7f0ed 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -1,6 +1,7 @@ package server import ( + "context" _ "embed" "strconv" "strings" @@ -211,9 +212,9 @@ type FirewallRule struct { // getPeerConnectionResources for a given peer // // This function returns the list of peers and firewall rules that are applicable to a given peer. -func (a *Account) getPeerConnectionResources(peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) { +func (a *Account) getPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) { - generateResources, getAccumulatedResources := a.connResourcesGenerator() + generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx) for _, policy := range a.Policies { if !policy.Enabled { continue @@ -224,8 +225,8 @@ func (a *Account) getPeerConnectionResources(peerID string, validatedPeersMap ma continue } - sourcePeers, peerInSources := getAllPeersFromGroups(a, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap) - destinationPeers, peerInDestinations := getAllPeersFromGroups(a, rule.Destinations, peerID, nil, validatedPeersMap) + sourcePeers, peerInSources := getAllPeersFromGroups(ctx, a, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap) + destinationPeers, peerInDestinations := getAllPeersFromGroups(ctx, a, rule.Destinations, peerID, nil, validatedPeersMap) if rule.Bidirectional { if peerInSources { @@ -254,7 +255,7 @@ func (a *Account) getPeerConnectionResources(peerID string, validatedPeersMap ma // The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer. // It safe to call the generator function multiple times for same peer and different rules no duplicates will be // generated. The accumulator function returns the result of all the generator calls. -func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) { +func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) { rulesExists := make(map[string]struct{}) peersExists := make(map[string]struct{}) rules := make([]*FirewallRule, 0) @@ -262,7 +263,7 @@ func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*nbpeer.Peer, in all, err := a.GetGroupAll() if err != nil { - log.Errorf("failed to get group all: %v", err) + log.WithContext(ctx).Errorf("failed to get group all: %v", err) all = &nbgroup.Group{} } @@ -313,11 +314,11 @@ func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*nbpeer.Peer, in } // GetPolicy from the store -func (am *DefaultAccountManager) GetPolicy(accountID, policyID, userID string) (*Policy, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -341,11 +342,11 @@ func (am *DefaultAccountManager) GetPolicy(accountID, policyID, userID string) ( } // SavePolicy in the store -func (am *DefaultAccountManager) SavePolicy(accountID, userID string, policy *Policy) error { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } @@ -353,7 +354,7 @@ func (am *DefaultAccountManager) SavePolicy(accountID, userID string, policy *Po exists := am.savePolicy(account, policy) account.Network.IncSerial() - if err = am.Store.SaveAccount(account); err != nil { + if err = am.Store.SaveAccount(ctx, account); err != nil { return err } @@ -361,19 +362,19 @@ func (am *DefaultAccountManager) SavePolicy(accountID, userID string, policy *Po if exists { action = activity.PolicyUpdated } - am.StoreEvent(userID, policy.ID, accountID, action, policy.EventMeta()) + am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) return nil } // DeletePolicy from the store -func (am *DefaultAccountManager) DeletePolicy(accountID, policyID, userID string) error { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } @@ -384,23 +385,23 @@ func (am *DefaultAccountManager) DeletePolicy(accountID, policyID, userID string } account.Network.IncSerial() - if err = am.Store.SaveAccount(account); err != nil { + if err = am.Store.SaveAccount(ctx, account); err != nil { return err } - am.StoreEvent(userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) + am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) return nil } // ListPolicies from the store -func (am *DefaultAccountManager) ListPolicies(accountID, userID string) ([]*Policy, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -490,7 +491,7 @@ func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule { // // Important: Posture checks are applicable only to source group peers, // for destination group peers, call this method with an empty list of sourcePostureChecksIDs -func getAllPeersFromGroups(account *Account, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) { +func getAllPeersFromGroups(ctx context.Context, account *Account, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) { peerInGroups := false filteredPeers := make([]*nbpeer.Peer, 0, len(groups)) for _, g := range groups { @@ -506,7 +507,7 @@ func getAllPeersFromGroups(account *Account, groups []string, peerID string, sou } // validate the peer based on policy posture checks applied - isValid := account.validatePostureChecksOnPeer(sourcePostureChecksIDs, peer.ID) + isValid := account.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID) if !isValid { continue } @@ -527,7 +528,7 @@ func getAllPeersFromGroups(account *Account, groups []string, peerID string, sou } // validatePostureChecksOnPeer validates the posture checks on a peer -func (a *Account) validatePostureChecksOnPeer(sourcePostureChecksID []string, peerID string) bool { +func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePostureChecksID []string, peerID string) bool { peer, ok := a.Peers[peerID] if !ok && peer == nil { return false @@ -540,9 +541,9 @@ func (a *Account) validatePostureChecksOnPeer(sourcePostureChecksID []string, pe } for _, check := range postureChecks.GetChecks() { - isValid, err := check.Check(*peer) + isValid, err := check.Check(ctx, *peer) if err != nil { - log.Debugf("an error occurred check %s: on peer: %s :%s", check.Name(), peer.ID, err.Error()) + log.WithContext(ctx).Debugf("an error occurred check %s: on peer: %s :%s", check.Name(), peer.ID, err.Error()) } if !isValid { return false diff --git a/management/server/policy_test.go b/management/server/policy_test.go index 1ea3bb379..bf9a53d16 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -1,6 +1,7 @@ package server import ( + "context" "fmt" "net" "testing" @@ -143,14 +144,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) { t.Run("check that all peers get map", func(t *testing.T) { for _, p := range account.Peers { - peers, firewallRules := account.getPeerConnectionResources(p.ID, validatedPeers) + peers, firewallRules := account.getPeerConnectionResources(context.Background(), p.ID, validatedPeers) assert.GreaterOrEqual(t, len(peers), 2, "minimum number peers should present") assert.GreaterOrEqual(t, len(firewallRules), 2, "minimum number of firewall rules should present") } }) t.Run("check first peer map details", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources("peerB", validatedPeers) + peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", validatedPeers) assert.Len(t, peers, 7) assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerC"]) @@ -386,7 +387,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { } t.Run("check first peer map", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources("peerB", approvedPeers) + peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers) assert.Contains(t, peers, account.Peers["peerC"]) epectedFirewallRules := []*FirewallRule{ @@ -414,7 +415,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { }) t.Run("check second peer map", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources("peerC", approvedPeers) + peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers) assert.Contains(t, peers, account.Peers["peerB"]) epectedFirewallRules := []*FirewallRule{ @@ -444,7 +445,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { account.Policies[1].Rules[0].Bidirectional = false t.Run("check first peer map directional only", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources("peerB", approvedPeers) + peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers) assert.Contains(t, peers, account.Peers["peerC"]) epectedFirewallRules := []*FirewallRule{ @@ -465,7 +466,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { }) t.Run("check second peer map directional only", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources("peerC", approvedPeers) + peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers) assert.Contains(t, peers, account.Peers["peerB"]) epectedFirewallRules := []*FirewallRule{ @@ -662,7 +663,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { t.Run("verify peer's network map with default group peer list", func(t *testing.T) { // peerB doesn't fulfill the NB posture check but is included in the destination group Swarm, // will establish a connection with all source peers satisfying the NB posture check. - peers, firewallRules := account.getPeerConnectionResources("peerB", approvedPeers) + peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -672,7 +673,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerC satisfy the NB posture check, should establish connection to all destination group peer's // We expect a single permissive firewall rule which all outgoing connections - peers, firewallRules = account.getPeerConnectionResources("peerC", approvedPeers) + peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, firewallRules, 1) expectedFirewallRules := []*FirewallRule{ @@ -688,7 +689,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.getPeerConnectionResources("peerE", approvedPeers) + peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerE", approvedPeers) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -698,7 +699,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.getPeerConnectionResources("peerI", approvedPeers) + peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerI", approvedPeers) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -713,19 +714,19 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's // no connection should be established to any peer of destination group - peers, firewallRules := account.getPeerConnectionResources("peerB", approvedPeers) + peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers) assert.Len(t, peers, 0) assert.Len(t, firewallRules, 0) // peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's // no connection should be established to any peer of destination group - peers, firewallRules = account.getPeerConnectionResources("peerI", approvedPeers) + peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerI", approvedPeers) assert.Len(t, peers, 0) assert.Len(t, firewallRules, 0) // peerC satisfy the NB posture check, should establish connection to all destination group peer's // We expect a single permissive firewall rule which all outgoing connections - peers, firewallRules = account.getPeerConnectionResources("peerC", approvedPeers) + peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers)) @@ -740,14 +741,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.getPeerConnectionResources("peerE", approvedPeers) + peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerE", approvedPeers) assert.Len(t, peers, 3) assert.Len(t, firewallRules, 3) assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerD"]) - peers, firewallRules = account.getPeerConnectionResources("peerA", approvedPeers) + peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerA", approvedPeers) assert.Len(t, peers, 5) // assert peers from Group Swarm assert.Contains(t, peers, account.Peers["peerD"]) diff --git a/management/server/posture/checks.go b/management/server/posture/checks.go index 3df5beacf..f2739dddf 100644 --- a/management/server/posture/checks.go +++ b/management/server/posture/checks.go @@ -1,6 +1,7 @@ package posture import ( + "context" "errors" "net/netip" "regexp" @@ -31,7 +32,7 @@ var ( // Check represents an interface for performing a check on a peer. type Check interface { Name() string - Check(peer nbpeer.Peer) (bool, error) + Check(ctx context.Context, peer nbpeer.Peer) (bool, error) Validate() error } diff --git a/management/server/posture/geo_location.go b/management/server/posture/geo_location.go index b51f80519..8a1f38830 100644 --- a/management/server/posture/geo_location.go +++ b/management/server/posture/geo_location.go @@ -1,6 +1,7 @@ package posture import ( + "context" "fmt" "slices" @@ -25,7 +26,7 @@ type GeoLocationCheck struct { Action string } -func (g *GeoLocationCheck) Check(peer nbpeer.Peer) (bool, error) { +func (g *GeoLocationCheck) Check(_ context.Context, peer nbpeer.Peer) (bool, error) { // deny if the peer location is not evaluated if peer.Location.CountryCode == "" && peer.Location.CityName == "" { return false, fmt.Errorf("peer's location is not set") diff --git a/management/server/posture/geo_location_test.go b/management/server/posture/geo_location_test.go index a92732c53..a64919f0b 100644 --- a/management/server/posture/geo_location_test.go +++ b/management/server/posture/geo_location_test.go @@ -1,6 +1,7 @@ package posture import ( + "context" "testing" "github.com/netbirdio/netbird/management/server/peer" @@ -226,7 +227,7 @@ func TestGeoLocationCheck_Check(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - isValid, err := tt.check.Check(tt.input) + isValid, err := tt.check.Check(context.Background(), tt.input) if tt.wantErr { assert.Error(t, err) } else { diff --git a/management/server/posture/nb_version.go b/management/server/posture/nb_version.go index 62a6e268c..f63db85b1 100644 --- a/management/server/posture/nb_version.go +++ b/management/server/posture/nb_version.go @@ -1,6 +1,7 @@ package posture import ( + "context" "fmt" "github.com/hashicorp/go-version" @@ -15,7 +16,7 @@ type NBVersionCheck struct { var _ Check = (*NBVersionCheck)(nil) -func (n *NBVersionCheck) Check(peer nbpeer.Peer) (bool, error) { +func (n *NBVersionCheck) Check(ctx context.Context, peer nbpeer.Peer) (bool, error) { peerNBVersion, err := version.NewVersion(peer.Meta.WtVersion) if err != nil { return false, err @@ -30,7 +31,7 @@ func (n *NBVersionCheck) Check(peer nbpeer.Peer) (bool, error) { return true, nil } - log.Debugf("peer %s NB version %s is older than minimum allowed version %s", + log.WithContext(ctx).Debugf("peer %s NB version %s is older than minimum allowed version %s", peer.ID, peer.Meta.WtVersion, n.MinVersion) return false, nil diff --git a/management/server/posture/nb_version_test.go b/management/server/posture/nb_version_test.go index fbe24aa16..1bf485453 100644 --- a/management/server/posture/nb_version_test.go +++ b/management/server/posture/nb_version_test.go @@ -1,6 +1,7 @@ package posture import ( + "context" "testing" "github.com/netbirdio/netbird/management/server/peer" @@ -98,7 +99,7 @@ func TestNBVersionCheck_Check(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - isValid, err := tt.check.Check(tt.input) + isValid, err := tt.check.Check(context.Background(), tt.input) if tt.wantErr { assert.Error(t, err) } else { diff --git a/management/server/posture/network.go b/management/server/posture/network.go index ed90ea91b..0fa6f6e71 100644 --- a/management/server/posture/network.go +++ b/management/server/posture/network.go @@ -1,6 +1,7 @@ package posture import ( + "context" "fmt" "net/netip" "slices" @@ -16,7 +17,7 @@ type PeerNetworkRangeCheck struct { var _ Check = (*PeerNetworkRangeCheck)(nil) -func (p *PeerNetworkRangeCheck) Check(peer nbpeer.Peer) (bool, error) { +func (p *PeerNetworkRangeCheck) Check(ctx context.Context, peer nbpeer.Peer) (bool, error) { if len(peer.Meta.NetworkAddresses) == 0 { return false, fmt.Errorf("peer's does not contain peer network range addresses") } diff --git a/management/server/posture/network_test.go b/management/server/posture/network_test.go index 6242ece99..a841bbe08 100644 --- a/management/server/posture/network_test.go +++ b/management/server/posture/network_test.go @@ -1,6 +1,7 @@ package posture import ( + "context" "net/netip" "testing" @@ -137,7 +138,7 @@ func TestPeerNetworkRangeCheck_Check(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - isValid, err := tt.check.Check(tt.peer) + isValid, err := tt.check.Check(context.Background(), tt.peer) if tt.wantErr { assert.Error(t, err) } else { diff --git a/management/server/posture/os_version.go b/management/server/posture/os_version.go index e6f8ec367..411f4c2c6 100644 --- a/management/server/posture/os_version.go +++ b/management/server/posture/os_version.go @@ -1,6 +1,7 @@ package posture import ( + "context" "fmt" "strings" @@ -28,20 +29,20 @@ type OSVersionCheck struct { var _ Check = (*OSVersionCheck)(nil) -func (c *OSVersionCheck) Check(peer nbpeer.Peer) (bool, error) { +func (c *OSVersionCheck) Check(ctx context.Context, peer nbpeer.Peer) (bool, error) { peerGoOS := peer.Meta.GoOS switch peerGoOS { case "android": - return checkMinVersion(peerGoOS, peer.Meta.OSVersion, c.Android) + return checkMinVersion(ctx, peerGoOS, peer.Meta.OSVersion, c.Android) case "darwin": - return checkMinVersion(peerGoOS, peer.Meta.OSVersion, c.Darwin) + return checkMinVersion(ctx, peerGoOS, peer.Meta.OSVersion, c.Darwin) case "ios": - return checkMinVersion(peerGoOS, peer.Meta.OSVersion, c.Ios) + return checkMinVersion(ctx, peerGoOS, peer.Meta.OSVersion, c.Ios) case "linux": kernelVersion := strings.Split(peer.Meta.KernelVersion, "-")[0] - return checkMinKernelVersion(peerGoOS, kernelVersion, c.Linux) + return checkMinKernelVersion(ctx, peerGoOS, kernelVersion, c.Linux) case "windows": - return checkMinKernelVersion(peerGoOS, peer.Meta.KernelVersion, c.Windows) + return checkMinKernelVersion(ctx, peerGoOS, peer.Meta.KernelVersion, c.Windows) } return true, nil } @@ -79,9 +80,9 @@ func (c *OSVersionCheck) Validate() error { return nil } -func checkMinVersion(peerGoOS, peerVersion string, check *MinVersionCheck) (bool, error) { +func checkMinVersion(ctx context.Context, peerGoOS, peerVersion string, check *MinVersionCheck) (bool, error) { if check == nil { - log.Debugf("peer %s OS is not allowed in the check", peerGoOS) + log.WithContext(ctx).Debugf("peer %s OS is not allowed in the check", peerGoOS) return false, nil } @@ -99,14 +100,14 @@ func checkMinVersion(peerGoOS, peerVersion string, check *MinVersionCheck) (bool return true, nil } - log.Debugf("peer %s OS version %s is older than minimum allowed version %s", peerGoOS, peerVersion, check.MinVersion) + log.WithContext(ctx).Debugf("peer %s OS version %s is older than minimum allowed version %s", peerGoOS, peerVersion, check.MinVersion) return false, nil } -func checkMinKernelVersion(peerGoOS, peerVersion string, check *MinKernelVersionCheck) (bool, error) { +func checkMinKernelVersion(ctx context.Context, peerGoOS, peerVersion string, check *MinKernelVersionCheck) (bool, error) { if check == nil { - log.Debugf("peer %s OS is not allowed in the check", peerGoOS) + log.WithContext(ctx).Debugf("peer %s OS is not allowed in the check", peerGoOS) return false, nil } @@ -124,7 +125,7 @@ func checkMinKernelVersion(peerGoOS, peerVersion string, check *MinKernelVersion return true, nil } - log.Debugf("peer %s kernel version %s is older than minimum allowed version %s", peerGoOS, peerVersion, check.MinKernelVersion) + log.WithContext(ctx).Debugf("peer %s kernel version %s is older than minimum allowed version %s", peerGoOS, peerVersion, check.MinKernelVersion) return false, nil } diff --git a/management/server/posture/os_version_test.go b/management/server/posture/os_version_test.go index 845e703cf..76343b0c4 100644 --- a/management/server/posture/os_version_test.go +++ b/management/server/posture/os_version_test.go @@ -1,6 +1,7 @@ package posture import ( + "context" "testing" "github.com/netbirdio/netbird/management/server/peer" @@ -140,7 +141,7 @@ func TestOSVersionCheck_Check(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - isValid, err := tt.check.Check(tt.input) + isValid, err := tt.check.Check(context.Background(), tt.input) if tt.wantErr { assert.Error(t, err) } else { diff --git a/management/server/posture/process.go b/management/server/posture/process.go index cd3b23fcf..911aabb52 100644 --- a/management/server/posture/process.go +++ b/management/server/posture/process.go @@ -1,6 +1,7 @@ package posture import ( + "context" "fmt" "slices" @@ -19,7 +20,7 @@ type ProcessCheck struct { var _ Check = (*ProcessCheck)(nil) -func (p *ProcessCheck) Check(peer nbpeer.Peer) (bool, error) { +func (p *ProcessCheck) Check(_ context.Context, peer nbpeer.Peer) (bool, error) { peerActiveProcesses := extractPeerActiveProcesses(peer.Meta.Files) var pathSelector func(Process) string diff --git a/management/server/posture/process_test.go b/management/server/posture/process_test.go index 0bfaf4cb9..ce43a948a 100644 --- a/management/server/posture/process_test.go +++ b/management/server/posture/process_test.go @@ -1,6 +1,7 @@ package posture import ( + "context" "testing" "github.com/stretchr/testify/assert" @@ -233,7 +234,7 @@ func TestProcessCheck_Check(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - isValid, err := tt.check.Check(tt.input) + isValid, err := tt.check.Check(context.Background(), tt.input) if tt.wantErr { assert.Error(t, err) } else { diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index d525482b7..851d4d31f 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -1,6 +1,7 @@ package server import ( + "context" "slices" "github.com/netbirdio/netbird/management/server/activity" @@ -13,11 +14,11 @@ const ( errMsgPostureAdminOnly = "only users with admin power are allowed to view posture checks" ) -func (am *DefaultAccountManager) GetPostureChecks(accountID, postureChecksID, userID string) (*posture.Checks, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -40,11 +41,11 @@ func (am *DefaultAccountManager) GetPostureChecks(accountID, postureChecksID, us return nil, status.Errorf(status.NotFound, "posture checks with ID %s not found", postureChecksID) } -func (am *DefaultAccountManager) SavePostureChecks(accountID, userID string, postureChecks *posture.Checks) error { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } @@ -75,23 +76,23 @@ func (am *DefaultAccountManager) SavePostureChecks(accountID, userID string, pos account.Network.IncSerial() } - if err = am.Store.SaveAccount(account); err != nil { + if err = am.Store.SaveAccount(ctx, account); err != nil { return err } - am.StoreEvent(userID, postureChecks.ID, accountID, action, postureChecks.EventMeta()) + am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta()) if exists { - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) } return nil } -func (am *DefaultAccountManager) DeletePostureChecks(accountID, postureChecksID, userID string) error { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } @@ -110,20 +111,20 @@ func (am *DefaultAccountManager) DeletePostureChecks(accountID, postureChecksID, return err } - if err = am.Store.SaveAccount(account); err != nil { + if err = am.Store.SaveAccount(ctx, account); err != nil { return err } - am.StoreEvent(userID, postureChecks.ID, accountID, activity.PostureCheckDeleted, postureChecks.EventMeta()) + am.StoreEvent(ctx, userID, postureChecks.ID, accountID, activity.PostureCheckDeleted, postureChecks.EventMeta()) return nil } -func (am *DefaultAccountManager) ListPostureChecks(accountID, userID string) ([]*posture.Checks, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go index dd92fe8b9..d837120f4 100644 --- a/management/server/posture_checks_test.go +++ b/management/server/posture_checks_test.go @@ -1,6 +1,7 @@ package server import ( + "context" "testing" "github.com/stretchr/testify/assert" @@ -28,15 +29,15 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { t.Run("Generic posture check flow", func(t *testing.T) { // regular users can not create checks - err := am.SavePostureChecks(account.Id, regularUserID, &posture.Checks{}) + err := am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{}) assert.Error(t, err) // regular users cannot list check - _, err = am.ListPostureChecks(account.Id, regularUserID) + _, err = am.ListPostureChecks(context.Background(), account.Id, regularUserID) assert.Error(t, err) // should be possible to create posture check with uniq name - err = am.SavePostureChecks(account.Id, adminUserID, &posture.Checks{ + err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ ID: postureCheckID, Name: postureCheckName, Checks: posture.ChecksDefinition{ @@ -48,12 +49,12 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { assert.NoError(t, err) // admin users can list check - checks, err := am.ListPostureChecks(account.Id, adminUserID) + checks, err := am.ListPostureChecks(context.Background(), account.Id, adminUserID) assert.NoError(t, err) assert.Len(t, checks, 1) // should not be possible to create posture check with non uniq name - err = am.SavePostureChecks(account.Id, adminUserID, &posture.Checks{ + err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ ID: "new-id", Name: postureCheckName, Checks: posture.ChecksDefinition{ @@ -69,7 +70,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { assert.Error(t, err) // admins can update posture checks - err = am.SavePostureChecks(account.Id, adminUserID, &posture.Checks{ + err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ ID: postureCheckID, Name: postureCheckName, Checks: posture.ChecksDefinition{ @@ -81,13 +82,13 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { assert.NoError(t, err) // users should not be able to delete posture checks - err = am.DeletePostureChecks(account.Id, postureCheckID, regularUserID) + err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, regularUserID) assert.Error(t, err) // admin should be able to delete posture checks - err = am.DeletePostureChecks(account.Id, postureCheckID, adminUserID) + err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, adminUserID) assert.NoError(t, err) - checks, err = am.ListPostureChecks(account.Id, adminUserID) + checks, err = am.ListPostureChecks(context.Background(), account.Id, adminUserID) assert.NoError(t, err) assert.Len(t, checks, 0) }) @@ -106,14 +107,14 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*Account, error) { Role: UserRoleUser, } - account := newAccountWithId(accountID, groupAdminUserID, domain) + account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain) account.Users[admin.Id] = admin account.Users[user.Id] = user - err := am.Store.SaveAccount(account) + err := am.Store.SaveAccount(context.Background(), account) if err != nil { return nil, err } - return am.Store.GetAccount(account.Id) + return am.Store.GetAccount(context.Background(), account.Id) } diff --git a/management/server/route.go b/management/server/route.go index 2fae8ab8d..6db00a255 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -1,6 +1,7 @@ package server import ( + "context" "fmt" "net/netip" "unicode/utf8" @@ -15,11 +16,11 @@ import ( ) // GetRoute gets a route object from account and route IDs -func (am *DefaultAccountManager) GetRoute(accountID string, routeID route.ID, userID string) (*route.Route, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -124,11 +125,11 @@ func getRouteDescriptor(prefix netip.Prefix, domains domain.List) string { } // CreateRoute creates and saves a new route -func (am *DefaultAccountManager) CreateRoute(accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -200,20 +201,20 @@ func (am *DefaultAccountManager) CreateRoute(accountID string, prefix netip.Pref account.Routes[newRoute.ID] = &newRoute account.Network.IncSerial() - if err = am.Store.SaveAccount(account); err != nil { + if err = am.Store.SaveAccount(ctx, account); err != nil { return nil, err } - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) - am.StoreEvent(userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) + am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) return &newRoute, nil } // SaveRoute saves route -func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave *route.Route) error { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userID string, routeToSave *route.Route) error { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() if routeToSave == nil { @@ -228,7 +229,7 @@ func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) } - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } @@ -269,23 +270,23 @@ func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave account.Routes[routeToSave.ID] = routeToSave account.Network.IncSerial() - if err = am.Store.SaveAccount(account); err != nil { + if err = am.Store.SaveAccount(ctx, account); err != nil { return err } - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) - am.StoreEvent(userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) + am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) return nil } // DeleteRoute deletes route with routeID -func (am *DefaultAccountManager) DeleteRoute(accountID string, routeID route.ID, userID string) error { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } @@ -297,23 +298,23 @@ func (am *DefaultAccountManager) DeleteRoute(accountID string, routeID route.ID, delete(account.Routes, routeID) account.Network.IncSerial() - if err = am.Store.SaveAccount(account); err != nil { + if err = am.Store.SaveAccount(ctx, account); err != nil { return err } - am.StoreEvent(userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta()) + am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta()) - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) return nil } // ListRoutes returns a list of routes from account -func (am *DefaultAccountManager) ListRoutes(accountID, userID string) ([]*route.Route, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } diff --git a/management/server/route_test.go b/management/server/route_test.go index 4fd1d7357..8b168a79f 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -1,6 +1,7 @@ package server import ( + "context" "net/netip" "testing" @@ -421,13 +422,13 @@ func TestCreateRoute(t *testing.T) { if testCase.createInitRoute { groupAll, errInit := account.GetGroupAll() require.NoError(t, errInit) - _, errInit = am.CreateRoute(account.Id, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, true, userID, false) + _, errInit = am.CreateRoute(context.Background(), account.Id, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, true, userID, false) require.NoError(t, errInit) - _, errInit = am.CreateRoute(account.Id, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, true, userID, false) + _, errInit = am.CreateRoute(context.Background(), account.Id, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, true, userID, false) require.NoError(t, errInit) } - outRoute, err := am.CreateRoute(account.Id, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute) + outRoute, err := am.CreateRoute(context.Background(), account.Id, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute) testCase.errFunc(t, err) @@ -925,7 +926,7 @@ func TestSaveRoute(t *testing.T) { account.Routes[testCase.existingRoute.ID] = testCase.existingRoute - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(context.Background(), account) if err != nil { t.Error("account should be saved") } @@ -963,7 +964,7 @@ func TestSaveRoute(t *testing.T) { } } - err = am.SaveRoute(account.Id, userID, routeToSave) + err = am.SaveRoute(context.Background(), account.Id, userID, routeToSave) testCase.errFunc(t, err) @@ -971,7 +972,7 @@ func TestSaveRoute(t *testing.T) { return } - account, err = am.Store.GetAccount(account.Id) + account, err = am.Store.GetAccount(context.Background(), account.Id) if err != nil { t.Fatal(err) } @@ -1012,17 +1013,17 @@ func TestDeleteRoute(t *testing.T) { account.Routes[testingRoute.ID] = testingRoute - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(context.Background(), account) if err != nil { t.Error("failed to save account") } - err = am.DeleteRoute(account.Id, testingRoute.ID, userID) + err = am.DeleteRoute(context.Background(), account.Id, testingRoute.ID, userID) if err != nil { t.Error("deleting route failed with error: ", err) } - savedAccount, err := am.Store.GetAccount(account.Id) + savedAccount, err := am.Store.GetAccount(context.Background(), account.Id) if err != nil { t.Error("failed to retrieve saved account with error: ", err) } @@ -1056,27 +1057,27 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { t.Error("failed to init testing account") } - newAccountRoutes, err := am.GetNetworkMap(peer1ID) + newAccountRoutes, err := am.GetNetworkMap(context.Background(), peer1ID) require.NoError(t, err) require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes") - newRoute, err := am.CreateRoute(account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.Enabled, userID, baseRoute.KeepRoute) + newRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.Enabled, userID, baseRoute.KeepRoute) require.NoError(t, err) require.Equal(t, newRoute.Enabled, true) - peer1Routes, err := am.GetNetworkMap(peer1ID) + peer1Routes, err := am.GetNetworkMap(context.Background(), peer1ID) require.NoError(t, err) assert.Len(t, peer1Routes.Routes, 1, "HA route should have 1 server route") - peer2Routes, err := am.GetNetworkMap(peer2ID) + peer2Routes, err := am.GetNetworkMap(context.Background(), peer2ID) require.NoError(t, err) assert.Len(t, peer2Routes.Routes, 1, "HA route should have 1 server route") - peer4Routes, err := am.GetNetworkMap(peer4ID) + peer4Routes, err := am.GetNetworkMap(context.Background(), peer4ID) require.NoError(t, err) assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route") - groups, err := am.ListGroups(account.Id) + groups, err := am.ListGroups(context.Background(), account.Id) require.NoError(t, err) var groupHA1, groupHA2 *nbgroup.Group for _, group := range groups { @@ -1088,35 +1089,35 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { } } - err = am.GroupDeletePeer(account.Id, groupHA1.ID, peer2ID) + err = am.GroupDeletePeer(context.Background(), account.Id, groupHA1.ID, peer2ID) require.NoError(t, err) - peer2RoutesAfterDelete, err := am.GetNetworkMap(peer2ID) + peer2RoutesAfterDelete, err := am.GetNetworkMap(context.Background(), peer2ID) require.NoError(t, err) assert.Len(t, peer2RoutesAfterDelete.Routes, 2, "after peer deletion group should have 2 client routes") - err = am.GroupDeletePeer(account.Id, groupHA2.ID, peer4ID) + err = am.GroupDeletePeer(context.Background(), account.Id, groupHA2.ID, peer4ID) require.NoError(t, err) - peer2RoutesAfterDelete, err = am.GetNetworkMap(peer2ID) + peer2RoutesAfterDelete, err = am.GetNetworkMap(context.Background(), peer2ID) require.NoError(t, err) assert.Len(t, peer2RoutesAfterDelete.Routes, 1, "after peer deletion group should have only 1 route") - err = am.GroupAddPeer(account.Id, groupHA2.ID, peer4ID) + err = am.GroupAddPeer(context.Background(), account.Id, groupHA2.ID, peer4ID) require.NoError(t, err) - peer1RoutesAfterAdd, err := am.GetNetworkMap(peer1ID) + peer1RoutesAfterAdd, err := am.GetNetworkMap(context.Background(), peer1ID) require.NoError(t, err) assert.Len(t, peer1RoutesAfterAdd.Routes, 1, "HA route should have more than 1 route") - peer2RoutesAfterAdd, err := am.GetNetworkMap(peer2ID) + peer2RoutesAfterAdd, err := am.GetNetworkMap(context.Background(), peer2ID) require.NoError(t, err) assert.Len(t, peer2RoutesAfterAdd.Routes, 2, "HA route should have 2 client routes") - err = am.DeleteRoute(account.Id, newRoute.ID, userID) + err = am.DeleteRoute(context.Background(), account.Id, newRoute.ID, userID) require.NoError(t, err) - peer1DeletedRoute, err := am.GetNetworkMap(peer1ID) + peer1DeletedRoute, err := am.GetNetworkMap(context.Background(), peer1ID) require.NoError(t, err) assert.Len(t, peer1DeletedRoute.Routes, 0, "we should receive one route for peer1") } @@ -1147,14 +1148,14 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { t.Error("failed to init testing account") } - newAccountRoutes, err := am.GetNetworkMap(peer1ID) + newAccountRoutes, err := am.GetNetworkMap(context.Background(), peer1ID) require.NoError(t, err) require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes") - createdRoute, err := am.CreateRoute(account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, false, userID, baseRoute.KeepRoute) + createdRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, false, userID, baseRoute.KeepRoute) require.NoError(t, err) - noDisabledRoutes, err := am.GetNetworkMap(peer1ID) + noDisabledRoutes, err := am.GetNetworkMap(context.Background(), peer1ID) require.NoError(t, err) require.Len(t, noDisabledRoutes.Routes, 0, "no routes for disabled routes") @@ -1165,22 +1166,22 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { expectedRoute := enabledRoute.Copy() expectedRoute.Peer = peer1Key - err = am.SaveRoute(account.Id, userID, enabledRoute) + err = am.SaveRoute(context.Background(), account.Id, userID, enabledRoute) require.NoError(t, err) - peer1Routes, err := am.GetNetworkMap(peer1ID) + peer1Routes, err := am.GetNetworkMap(context.Background(), peer1ID) require.NoError(t, err) require.Len(t, peer1Routes.Routes, 1, "we should receive one route for peer1") require.True(t, expectedRoute.IsEqual(peer1Routes.Routes[0]), "received route should be equal") - peer2Routes, err := am.GetNetworkMap(peer2ID) + peer2Routes, err := am.GetNetworkMap(context.Background(), peer2ID) require.NoError(t, err) require.Len(t, peer2Routes.Routes, 0, "no routes for peers not in the distribution group") - err = am.GroupAddPeer(account.Id, routeGroup1, peer2ID) + err = am.GroupAddPeer(context.Background(), account.Id, routeGroup1, peer2ID) require.NoError(t, err) - peer2Routes, err = am.GetNetworkMap(peer2ID) + peer2Routes, err = am.GetNetworkMap(context.Background(), peer2ID) require.NoError(t, err) require.Len(t, peer2Routes.Routes, 1, "we should receive one route") require.True(t, peer1Routes.Routes[0].IsEqual(peer2Routes.Routes[0]), "routes should be the same for peers in the same group") @@ -1190,10 +1191,10 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { Name: "peer1 group", Peers: []string{peer1ID}, } - err = am.SaveGroup(account.Id, userID, newGroup) + err = am.SaveGroup(context.Background(), account.Id, userID, newGroup) require.NoError(t, err) - rules, err := am.ListPolicies(account.Id, "testingUser") + rules, err := am.ListPolicies(context.Background(), account.Id, "testingUser") require.NoError(t, err) defaultRule := rules[0] @@ -1203,24 +1204,24 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { newPolicy.Rules[0].Sources = []string{newGroup.ID} newPolicy.Rules[0].Destinations = []string{newGroup.ID} - err = am.SavePolicy(account.Id, userID, newPolicy) + err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy) require.NoError(t, err) - err = am.DeletePolicy(account.Id, defaultRule.ID, userID) + err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID) require.NoError(t, err) - peer1GroupRoutes, err := am.GetNetworkMap(peer1ID) + peer1GroupRoutes, err := am.GetNetworkMap(context.Background(), peer1ID) require.NoError(t, err) require.Len(t, peer1GroupRoutes.Routes, 1, "we should receive one route for peer1") - peer2GroupRoutes, err := am.GetNetworkMap(peer2ID) + peer2GroupRoutes, err := am.GetNetworkMap(context.Background(), peer2ID) require.NoError(t, err) require.Len(t, peer2GroupRoutes.Routes, 0, "we should not receive routes for peer2") - err = am.DeleteRoute(account.Id, enabledRoute.ID, userID) + err = am.DeleteRoute(context.Background(), account.Id, enabledRoute.ID, userID) require.NoError(t, err) - peer1DeletedRoute, err := am.GetNetworkMap(peer1ID) + peer1DeletedRoute, err := am.GetNetworkMap(context.Background(), peer1ID) require.NoError(t, err) require.Len(t, peer1DeletedRoute.Routes, 0, "we should receive one route for peer1") } @@ -1232,13 +1233,13 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) { return nil, err } eventStore := &activity.InMemoryEventStore{} - return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}) + return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}) } func createRouterStore(t *testing.T) (Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromJson(dataDir) + store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir) if err != nil { return nil, err } @@ -1253,8 +1254,8 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er accountID := "testingAcc" domain := "example.com" - account := newAccountWithId(accountID, userID, domain) - err := am.Store.SaveAccount(account) + account := newAccountWithId(context.Background(), accountID, userID, domain) + err := am.Store.SaveAccount(context.Background(), account) if err != nil { return nil, err } @@ -1389,7 +1390,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er } account.Peers[peer5.ID] = peer5 - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(context.Background(), account) if err != nil { return nil, err } @@ -1397,19 +1398,19 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er if err != nil { return nil, err } - err = am.GroupAddPeer(accountID, groupAll.ID, peer1ID) + err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer1ID) if err != nil { return nil, err } - err = am.GroupAddPeer(accountID, groupAll.ID, peer2ID) + err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer2ID) if err != nil { return nil, err } - err = am.GroupAddPeer(accountID, groupAll.ID, peer3ID) + err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer3ID) if err != nil { return nil, err } - err = am.GroupAddPeer(accountID, groupAll.ID, peer4ID) + err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer4ID) if err != nil { return nil, err } @@ -1448,11 +1449,11 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er } for _, group := range newGroup { - err = am.SaveGroup(accountID, userID, group) + err = am.SaveGroup(context.Background(), accountID, userID, group) if err != nil { return nil, err } } - return am.Store.GetAccount(account.Id) + return am.Store.GetAccount(context.Background(), account.Id) } diff --git a/management/server/scheduler.go b/management/server/scheduler.go index 356348056..147b50fc6 100644 --- a/management/server/scheduler.go +++ b/management/server/scheduler.go @@ -1,6 +1,7 @@ package server import ( + "context" "sync" "time" @@ -9,32 +10,32 @@ import ( // Scheduler is an interface which implementations can schedule and cancel jobs type Scheduler interface { - Cancel(IDs []string) - Schedule(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) + Cancel(ctx context.Context, IDs []string) + Schedule(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) } // MockScheduler is a mock implementation of Scheduler type MockScheduler struct { - CancelFunc func(IDs []string) - ScheduleFunc func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) + CancelFunc func(ctx context.Context, IDs []string) + ScheduleFunc func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) } // Cancel mocks the Cancel function of the Scheduler interface -func (mock *MockScheduler) Cancel(IDs []string) { +func (mock *MockScheduler) Cancel(ctx context.Context, IDs []string) { if mock.CancelFunc != nil { - mock.CancelFunc(IDs) + mock.CancelFunc(ctx, IDs) return } - log.Errorf("MockScheduler doesn't have Cancel function defined ") + log.WithContext(ctx).Errorf("MockScheduler doesn't have Cancel function defined ") } // Schedule mocks the Schedule function of the Scheduler interface -func (mock *MockScheduler) Schedule(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) { +func (mock *MockScheduler) Schedule(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) { if mock.ScheduleFunc != nil { - mock.ScheduleFunc(in, ID, job) + mock.ScheduleFunc(ctx, in, ID, job) return } - log.Errorf("MockScheduler doesn't have Schedule function defined") + log.WithContext(ctx).Errorf("MockScheduler doesn't have Schedule function defined") } // DefaultScheduler is a generic structure that allows to schedule jobs (functions) to run in the future and cancel them. @@ -52,35 +53,35 @@ func NewDefaultScheduler() *DefaultScheduler { } } -func (wm *DefaultScheduler) cancel(ID string) bool { +func (wm *DefaultScheduler) cancel(ctx context.Context, ID string) bool { cancel, ok := wm.jobs[ID] if ok { delete(wm.jobs, ID) close(cancel) - log.Debugf("cancelled scheduled job %s", ID) + log.WithContext(ctx).Debugf("cancelled scheduled job %s", ID) } return ok } // Cancel cancels the scheduled job by ID if present. // If job wasn't found the function returns false. -func (wm *DefaultScheduler) Cancel(IDs []string) { +func (wm *DefaultScheduler) Cancel(ctx context.Context, IDs []string) { wm.mu.Lock() defer wm.mu.Unlock() for _, id := range IDs { - wm.cancel(id) + wm.cancel(ctx, id) } } // Schedule a job to run in some time in the future. If job returns true then it will be scheduled one more time. // If job with the provided ID already exists, a new one won't be scheduled. -func (wm *DefaultScheduler) Schedule(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) { +func (wm *DefaultScheduler) Schedule(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) { wm.mu.Lock() defer wm.mu.Unlock() cancel := make(chan struct{}) if _, ok := wm.jobs[ID]; ok { - log.Debugf("couldn't schedule a job %s because it already exists. There are %d total jobs scheduled.", + log.WithContext(ctx).Debugf("couldn't schedule a job %s because it already exists. There are %d total jobs scheduled.", ID, len(wm.jobs)) return } @@ -88,25 +89,25 @@ func (wm *DefaultScheduler) Schedule(in time.Duration, ID string, job func() (ne ticker := time.NewTicker(in) wm.jobs[ID] = cancel - log.Debugf("scheduled a job %s to run in %s. There are %d total jobs scheduled.", ID, in.String(), len(wm.jobs)) + log.WithContext(ctx).Debugf("scheduled a job %s to run in %s. There are %d total jobs scheduled.", ID, in.String(), len(wm.jobs)) go func() { for { select { case <-ticker.C: select { case <-cancel: - log.Debugf("scheduled job %s was canceled, stop timer", ID) + log.WithContext(ctx).Debugf("scheduled job %s was canceled, stop timer", ID) ticker.Stop() return default: - log.Debugf("time to do a scheduled job %s", ID) + log.WithContext(ctx).Debugf("time to do a scheduled job %s", ID) } runIn, reschedule := job() if !reschedule { wm.mu.Lock() defer wm.mu.Unlock() delete(wm.jobs, ID) - log.Debugf("job %s is not scheduled to run again", ID) + log.WithContext(ctx).Debugf("job %s is not scheduled to run again", ID) ticker.Stop() return } @@ -115,7 +116,7 @@ func (wm *DefaultScheduler) Schedule(in time.Duration, ID string, job func() (ne ticker.Reset(runIn) } case <-cancel: - log.Debugf("job %s was canceled, stopping timer", ID) + log.WithContext(ctx).Debugf("job %s was canceled, stopping timer", ID) ticker.Stop() return } diff --git a/management/server/scheduler_test.go b/management/server/scheduler_test.go index 9dd73e269..7c287a554 100644 --- a/management/server/scheduler_test.go +++ b/management/server/scheduler_test.go @@ -1,6 +1,7 @@ package server import ( + "context" "fmt" "math/rand" "runtime" @@ -20,7 +21,7 @@ func TestScheduler_Performance(t *testing.T) { minMs := 50 for i := 0; i < n; i++ { millis := time.Duration(rand.Intn(maxMs-minMs)+minMs) * time.Millisecond - go scheduler.Schedule(millis, fmt.Sprintf("test-scheduler-job-%d", i), func() (nextRunIn time.Duration, reschedule bool) { + go scheduler.Schedule(context.Background(), millis, fmt.Sprintf("test-scheduler-job-%d", i), func() (nextRunIn time.Duration, reschedule bool) { time.Sleep(millis) wg.Done() return 0, false @@ -53,19 +54,19 @@ func TestScheduler_Cancel(t *testing.T) { sleepTime = 20 * time.Millisecond } - scheduler.Schedule(scheduletime, jobID1, func() (nextRunIn time.Duration, reschedule bool) { + scheduler.Schedule(context.Background(), scheduletime, jobID1, func() (nextRunIn time.Duration, reschedule bool) { tt := p[0] <-tChan t.Logf("job %s", tt) return scheduletime, true }) - scheduler.Schedule(scheduletime, jobID2, func() (nextRunIn time.Duration, reschedule bool) { + scheduler.Schedule(context.Background(), scheduletime, jobID2, func() (nextRunIn time.Duration, reschedule bool) { return scheduletime, true }) time.Sleep(sleepTime) assert.Len(t, scheduler.jobs, 2) - scheduler.Cancel([]string{jobID1}) + scheduler.Cancel(context.Background(), []string{jobID1}) close(tChan) p = []string{} time.Sleep(sleepTime) @@ -83,7 +84,7 @@ func TestScheduler_Schedule(t *testing.T) { wg.Done() return 0, false } - scheduler.Schedule(300*time.Millisecond, jobID, job) + scheduler.Schedule(context.Background(), 300*time.Millisecond, jobID, job) failed := waitTimeout(wg, time.Second) if failed { t.Fatal("timed out while waiting for test to finish") @@ -107,12 +108,12 @@ func TestScheduler_Schedule(t *testing.T) { return 0, false } - scheduler.Schedule(300*time.Millisecond, jobID, job) + scheduler.Schedule(context.Background(), 300*time.Millisecond, jobID, job) failed = waitTimeout(wg, time.Second) if failed { t.Fatal("timed out while waiting for test to finish") return } - scheduler.cancel(jobID) + scheduler.cancel(context.Background(), jobID) } diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 40b8ac457..dcaee357c 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -1,6 +1,7 @@ package server import ( + "context" "hash/fnv" "strconv" "strings" @@ -207,9 +208,9 @@ func Hash(s string) uint32 { // CreateSetupKey generates a new setup key with a given name, type, list of groups IDs to auto-assign to peers registered with this key, // and adds it to the specified account. A list of autoGroups IDs can be empty. -func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string, keyType SetupKeyType, +func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() keyDuration := DefaultSetupKeyDuration @@ -217,7 +218,7 @@ func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string keyDuration = expiresIn } - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -230,20 +231,20 @@ func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string setupKey := GenerateSetupKey(keyName, keyType, keyDuration, autoGroups, usageLimit, ephemeral) account.SetupKeys[setupKey.Key] = setupKey - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(ctx, account) if err != nil { return nil, status.Errorf(status.Internal, "failed adding account key") } - am.StoreEvent(userID, setupKey.Id, accountID, activity.SetupKeyCreated, setupKey.EventMeta()) + am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.SetupKeyCreated, setupKey.EventMeta()) for _, g := range setupKey.AutoGroups { group := account.GetGroup(g) if group != nil { - am.StoreEvent(userID, setupKey.Id, accountID, activity.GroupAddedToSetupKey, + am.StoreEvent(ctx, userID, setupKey.Id, accountID, activity.GroupAddedToSetupKey, map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": setupKey.Name}) } else { - log.Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id) + log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id) } } @@ -254,15 +255,15 @@ func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string // Due to the unique nature of a SetupKey certain properties must not be overwritten // (e.g. the key itself, creation date, ID, etc). // These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key. -func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() if keyToSave == nil { return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil") } - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -287,12 +288,12 @@ func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *Setup account.SetupKeys[newKey.Key] = newKey - if err = am.Store.SaveAccount(account); err != nil { + if err = am.Store.SaveAccount(ctx, account); err != nil { return nil, err } if !oldKey.Revoked && newKey.Revoked { - am.StoreEvent(userID, newKey.Id, accountID, activity.SetupKeyRevoked, newKey.EventMeta()) + am.StoreEvent(ctx, userID, newKey.Id, accountID, activity.SetupKeyRevoked, newKey.EventMeta()) } defer func() { @@ -301,10 +302,10 @@ func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *Setup for _, g := range removedGroups { group := account.GetGroup(g) if group != nil { - am.StoreEvent(userID, oldKey.Id, accountID, activity.GroupRemovedFromSetupKey, + am.StoreEvent(ctx, userID, oldKey.Id, accountID, activity.GroupRemovedFromSetupKey, map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": newKey.Name}) } else { - log.Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id) + log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id) } } @@ -312,24 +313,24 @@ func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *Setup for _, g := range addedGroups { group := account.GetGroup(g) if group != nil { - am.StoreEvent(userID, oldKey.Id, accountID, activity.GroupAddedToSetupKey, + am.StoreEvent(ctx, userID, oldKey.Id, accountID, activity.GroupAddedToSetupKey, map[string]any{"group": group.Name, "group_id": group.ID, "setupkey": newKey.Name}) } else { - log.Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id) + log.WithContext(ctx).Errorf("group %s not found while saving setup key activity event of account %s", g, account.Id) } } }() - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) return newKey, nil } // ListSetupKeys returns a list of all setup keys of the account -func (am *DefaultAccountManager) ListSetupKeys(accountID, userID string) ([]*SetupKey, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -358,11 +359,11 @@ func (am *DefaultAccountManager) ListSetupKeys(accountID, userID string) ([]*Set } // GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found. -func (am *DefaultAccountManager) GetSetupKey(accountID, userID, keyID string) (*SetupKey, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index 43edabbd6..034f4e2d6 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -1,6 +1,7 @@ package server import ( + "context" "fmt" "strconv" "testing" @@ -20,12 +21,12 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { } userID := "testingUser" - account, err := manager.GetOrCreateAccountByUser(userID, "") + account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") if err != nil { t.Fatal(err) } - err = manager.SaveGroup(account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ ID: "group_1", Name: "group_name_1", Peers: []string{}, @@ -37,7 +38,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { expiresIn := time.Hour keyName := "my-test-key" - key, err := manager.CreateSetupKey(account.Id, keyName, SetupKeyReusable, expiresIn, []string{}, + key, err := manager.CreateSetupKey(context.Background(), account.Id, keyName, SetupKeyReusable, expiresIn, []string{}, SetupKeyUnlimitedUsage, userID, false) if err != nil { t.Fatal(err) @@ -46,7 +47,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { autoGroups := []string{"group_1", "group_2"} newKeyName := "my-new-test-key" revoked := true - newKey, err := manager.SaveSetupKey(account.Id, &SetupKey{ + newKey, err := manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{ Id: key.Id, Name: newKeyName, Revoked: revoked, @@ -78,12 +79,12 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { } userID := "testingUser" - account, err := manager.GetOrCreateAccountByUser(userID, "") + account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") if err != nil { t.Fatal(err) } - err = manager.SaveGroup(account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ ID: "group_1", Name: "group_name_1", Peers: []string{}, @@ -92,7 +93,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ ID: "group_2", Name: "group_name_2", Peers: []string{}, @@ -136,7 +137,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { for _, tCase := range []testCase{testCase1, testCase2} { t.Run(tCase.name, func(t *testing.T) { - key, err := manager.CreateSetupKey(account.Id, tCase.expectedKeyName, SetupKeyReusable, expiresIn, + key, err := manager.CreateSetupKey(context.Background(), account.Id, tCase.expectedKeyName, SetupKeyReusable, expiresIn, tCase.expectedGroups, SetupKeyUnlimitedUsage, userID, false) if tCase.expectedFailure { @@ -174,12 +175,12 @@ func TestGetSetupKeys(t *testing.T) { } userID := "testingUser" - account, err := manager.GetOrCreateAccountByUser(userID, "") + account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") if err != nil { t.Fatal(err) } - err = manager.SaveGroup(account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ ID: "group_1", Name: "group_name_1", Peers: []string{}, @@ -188,7 +189,7 @@ func TestGetSetupKeys(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ ID: "group_2", Name: "group_name_2", Peers: []string{}, diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 35f09d60c..b5ae82828 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1,6 +1,7 @@ package server import ( + "context" "encoding/json" "errors" "fmt" @@ -52,7 +53,7 @@ type installation struct { type migrationFunc func(*gorm.DB) error // NewSqlStore creates a new SqlStore instance. -func NewSqlStore(db *gorm.DB, storeEngine StoreEngine, metrics telemetry.AppMetrics) (*SqlStore, error) { +func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine StoreEngine, metrics telemetry.AppMetrics) (*SqlStore, error) { sql, err := db.DB() if err != nil { return nil, err @@ -60,7 +61,7 @@ func NewSqlStore(db *gorm.DB, storeEngine StoreEngine, metrics telemetry.AppMetr conns := runtime.NumCPU() sql.SetMaxOpenConns(conns) // TODO: make it configurable - if err := migrate(db); err != nil { + if err := migrate(ctx, db); err != nil { return nil, fmt.Errorf("migrate: %w", err) } err = db.AutoMigrate( @@ -76,18 +77,18 @@ func NewSqlStore(db *gorm.DB, storeEngine StoreEngine, metrics telemetry.AppMetr } // AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock -func (s *SqlStore) AcquireGlobalLock() (unlock func()) { - log.Tracef("acquiring global lock") +func (s *SqlStore) AcquireGlobalLock(ctx context.Context) (unlock func()) { + log.WithContext(ctx).Tracef("acquiring global lock") start := time.Now() s.globalAccountLock.Lock() unlock = func() { s.globalAccountLock.Unlock() - log.Tracef("released global lock in %v", time.Since(start)) + log.WithContext(ctx).Tracef("released global lock in %v", time.Since(start)) } took := time.Since(start) - log.Tracef("took %v to acquire global lock", took) + log.WithContext(ctx).Tracef("took %v to acquire global lock", took) if s.metrics != nil { s.metrics.StoreMetrics().CountGlobalLockAcquisitionDuration(took) } @@ -95,8 +96,8 @@ func (s *SqlStore) AcquireGlobalLock() (unlock func()) { return unlock } -func (s *SqlStore) AcquireAccountWriteLock(accountID string) (unlock func()) { - log.Tracef("acquiring write lock for account %s", accountID) +func (s *SqlStore) AcquireAccountWriteLock(ctx context.Context, accountID string) (unlock func()) { + log.WithContext(ctx).Tracef("acquiring write lock for account %s", accountID) start := time.Now() value, _ := s.accountLocks.LoadOrStore(accountID, &sync.RWMutex{}) @@ -105,14 +106,14 @@ func (s *SqlStore) AcquireAccountWriteLock(accountID string) (unlock func()) { unlock = func() { mtx.Unlock() - log.Tracef("released write lock for account %s in %v", accountID, time.Since(start)) + log.WithContext(ctx).Tracef("released write lock for account %s in %v", accountID, time.Since(start)) } return unlock } -func (s *SqlStore) AcquireAccountReadLock(accountID string) (unlock func()) { - log.Tracef("acquiring read lock for account %s", accountID) +func (s *SqlStore) AcquireAccountReadLock(ctx context.Context, accountID string) (unlock func()) { + log.WithContext(ctx).Tracef("acquiring read lock for account %s", accountID) start := time.Now() value, _ := s.accountLocks.LoadOrStore(accountID, &sync.RWMutex{}) @@ -121,17 +122,17 @@ func (s *SqlStore) AcquireAccountReadLock(accountID string) (unlock func()) { unlock = func() { mtx.RUnlock() - log.Tracef("released read lock for account %s in %v", accountID, time.Since(start)) + log.WithContext(ctx).Tracef("released read lock for account %s in %v", accountID, time.Since(start)) } return unlock } -func (s *SqlStore) SaveAccount(account *Account) error { +func (s *SqlStore) SaveAccount(ctx context.Context, account *Account) error { start := time.Now() // todo: remove this check after the issue is resolved - s.checkAccountDomainBeforeSave(account.Id, account.Domain) + s.checkAccountDomainBeforeSave(ctx, account.Id, account.Domain) generateAccountSQLTypes(account) @@ -165,7 +166,7 @@ func (s *SqlStore) SaveAccount(account *Account) error { if s.metrics != nil { s.metrics.StoreMetrics().CountPersistenceDuration(took) } - log.Debugf("took %d ms to persist an account to the store", took.Milliseconds()) + log.WithContext(ctx).Debugf("took %d ms to persist an account to the store", took.Milliseconds()) return err } @@ -207,22 +208,22 @@ func generateAccountSQLTypes(account *Account) { } // checkAccountDomainBeforeSave temporary method to troubleshoot an issue with domains getting blank -func (s *SqlStore) checkAccountDomainBeforeSave(accountID, newDomain string) { +func (s *SqlStore) checkAccountDomainBeforeSave(ctx context.Context, accountID, newDomain string) { var acc Account var domain string result := s.db.Model(&acc).Select("domain").Where(idQueryCondition, accountID).First(&domain) if result.Error != nil { if !errors.Is(result.Error, gorm.ErrRecordNotFound) { - log.Errorf("error when getting account %s from the store to check domain: %s", accountID, result.Error) + log.WithContext(ctx).Errorf("error when getting account %s from the store to check domain: %s", accountID, result.Error) } return } if domain != "" && newDomain == "" { - log.Warnf("saving an account with empty domain when there was a domain set. Previous domain %s, Account ID: %s, Trace: %s", domain, accountID, debug.Stack()) + log.WithContext(ctx).Warnf("saving an account with empty domain when there was a domain set. Previous domain %s, Account ID: %s, Trace: %s", domain, accountID, debug.Stack()) } } -func (s *SqlStore) DeleteAccount(account *Account) error { +func (s *SqlStore) DeleteAccount(ctx context.Context, account *Account) error { start := time.Now() err := s.db.Transaction(func(tx *gorm.DB) error { @@ -248,12 +249,12 @@ func (s *SqlStore) DeleteAccount(account *Account) error { if s.metrics != nil { s.metrics.StoreMetrics().CountPersistenceDuration(took) } - log.Debugf("took %d ms to delete an account to the store", took.Milliseconds()) + log.WithContext(ctx).Debugf("took %d ms to delete an account to the store", took.Milliseconds()) return err } -func (s *SqlStore) SaveInstallationID(ID string) error { +func (s *SqlStore) SaveInstallationID(_ context.Context, ID string) error { installation := installation{InstallationIDValue: ID} installation.ID = uint(s.installationPK) @@ -320,7 +321,7 @@ func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error { return nil } -func (s *SqlStore) GetAccountByPrivateDomain(domain string) (*Account, error) { +func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) { var account Account result := s.db.First(&account, "domain = ? and is_domain_primary_account = ? and domain_category = ?", @@ -329,22 +330,22 @@ func (s *SqlStore) GetAccountByPrivateDomain(domain string) (*Account, error) { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private") } - log.Errorf("error when getting account from the store: %s", result.Error) + log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error) return nil, status.Errorf(status.Internal, "issue getting account from store") } // TODO: rework to not call GetAccount - return s.GetAccount(account.Id) + return s.GetAccount(ctx, account.Id) } -func (s *SqlStore) GetAccountBySetupKey(setupKey string) (*Account, error) { +func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) { var key SetupKey result := s.db.Select("account_id").First(&key, "key = ?", strings.ToUpper(setupKey)) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - log.Errorf("error when getting setup key from the store: %s", result.Error) + log.WithContext(ctx).Errorf("error when getting setup key from the store: %s", result.Error) return nil, status.Errorf(status.Internal, "issue getting setup key from store") } @@ -352,31 +353,31 @@ func (s *SqlStore) GetAccountBySetupKey(setupKey string) (*Account, error) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - return s.GetAccount(key.AccountID) + return s.GetAccount(ctx, key.AccountID) } -func (s *SqlStore) GetTokenIDByHashedToken(hashedToken string) (string, error) { +func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken string) (string, error) { var token PersonalAccessToken result := s.db.First(&token, "hashed_token = ?", hashedToken) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } - log.Errorf("error when getting token from the store: %s", result.Error) + log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error) return "", status.Errorf(status.Internal, "issue getting account from store") } return token.ID, nil } -func (s *SqlStore) GetUserByTokenID(tokenID string) (*User, error) { +func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) { var token PersonalAccessToken result := s.db.First(&token, idQueryCondition, tokenID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - log.Errorf("error when getting token from the store: %s", result.Error) + log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error) return nil, status.Errorf(status.Internal, "issue getting account from store") } @@ -398,7 +399,7 @@ func (s *SqlStore) GetUserByTokenID(tokenID string) (*User, error) { return &user, nil } -func (s *SqlStore) GetAllAccounts() (all []*Account) { +func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*Account) { var accounts []Account result := s.db.Find(&accounts) if result.Error != nil { @@ -406,7 +407,7 @@ func (s *SqlStore) GetAllAccounts() (all []*Account) { } for _, account := range accounts { - if acc, err := s.GetAccount(account.Id); err == nil { + if acc, err := s.GetAccount(ctx, account.Id); err == nil { all = append(all, acc) } } @@ -414,7 +415,7 @@ func (s *SqlStore) GetAllAccounts() (all []*Account) { return all } -func (s *SqlStore) GetAccount(accountID string) (*Account, error) { +func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, error) { var account Account result := s.db.Model(&account). @@ -422,7 +423,7 @@ func (s *SqlStore) GetAccount(accountID string) (*Account, error) { Preload(clause.Associations). First(&account, idQueryCondition, accountID) if result.Error != nil { - log.Errorf("error when getting account %s from the store: %s", accountID, result.Error) + log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error) if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found") } @@ -482,7 +483,7 @@ func (s *SqlStore) GetAccount(accountID string) (*Account, error) { return &account, nil } -func (s *SqlStore) GetAccountByUser(userID string) (*Account, error) { +func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Account, error) { var user User result := s.db.Select("account_id").First(&user, idQueryCondition, userID) if result.Error != nil { @@ -496,17 +497,17 @@ func (s *SqlStore) GetAccountByUser(userID string) (*Account, error) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - return s.GetAccount(user.AccountID) + return s.GetAccount(ctx, user.AccountID) } -func (s *SqlStore) GetAccountByPeerID(peerID string) (*Account, error) { +func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) { var peer nbpeer.Peer result := s.db.Select("account_id").First(&peer, idQueryCondition, peerID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - log.Errorf("error when getting peer from the store: %s", result.Error) + log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error) return nil, status.Errorf(status.Internal, "issue getting account from store") } @@ -514,10 +515,10 @@ func (s *SqlStore) GetAccountByPeerID(peerID string) (*Account, error) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - return s.GetAccount(peer.AccountID) + return s.GetAccount(ctx, peer.AccountID) } -func (s *SqlStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) { +func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) { var peer nbpeer.Peer result := s.db.Select("account_id").First(&peer, "key = ?", peerKey) @@ -525,7 +526,7 @@ func (s *SqlStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - log.Errorf("error when getting peer from the store: %s", result.Error) + log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error) return nil, status.Errorf(status.Internal, "issue getting account from store") } @@ -533,10 +534,10 @@ func (s *SqlStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - return s.GetAccount(peer.AccountID) + return s.GetAccount(ctx, peer.AccountID) } -func (s *SqlStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) { +func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) { var peer nbpeer.Peer var accountID string result := s.db.Model(&peer).Select("account_id").Where("key = ?", peerKey).First(&accountID) @@ -544,7 +545,7 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } - log.Errorf("error when getting peer from the store: %s", result.Error) + log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error) return "", status.Errorf(status.Internal, "issue getting account from store") } @@ -565,7 +566,7 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { return accountID, nil } -func (s *SqlStore) GetAccountIDBySetupKey(setupKey string) (string, error) { +func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) { var key SetupKey var accountID string result := s.db.Model(&key).Select("account_id").Where("key = ?", strings.ToUpper(setupKey)).First(&accountID) @@ -573,34 +574,34 @@ func (s *SqlStore) GetAccountIDBySetupKey(setupKey string) (string, error) { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } - log.Errorf("error when getting setup key from the store: %s", result.Error) + log.WithContext(ctx).Errorf("error when getting setup key from the store: %s", result.Error) return "", status.Errorf(status.Internal, "issue getting setup key from store") } return accountID, nil } -func (s *SqlStore) GetPeerByPeerPubKey(peerKey string) (*nbpeer.Peer, error) { +func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, peerKey string) (*nbpeer.Peer, error) { var peer nbpeer.Peer result := s.db.First(&peer, "key = ?", peerKey) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "peer not found") } - log.Errorf("error when getting peer from the store: %s", result.Error) + log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error) return nil, status.Errorf(status.Internal, "issue getting peer from store") } return &peer, nil } -func (s *SqlStore) GetAccountSettings(accountID string) (*Settings, error) { +func (s *SqlStore) GetAccountSettings(ctx context.Context, accountID string) (*Settings, error) { var accountSettings AccountSettings if err := s.db.Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "settings not found") } - log.Errorf("error when getting settings from the store: %s", err) + log.WithContext(ctx).Errorf("error when getting settings from the store: %s", err) return nil, status.Errorf(status.Internal, "issue getting settings from store") } return accountSettings.Settings, nil @@ -639,7 +640,7 @@ func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *p } // Close closes the underlying DB connection -func (s *SqlStore) Close() error { +func (s *SqlStore) Close(_ context.Context) error { sql, err := s.db.DB() if err != nil { return fmt.Errorf("get db: %w", err) @@ -653,7 +654,7 @@ func (s *SqlStore) GetStoreEngine() StoreEngine { } // NewSqliteStore creates a new SQLite store. -func NewSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*SqlStore, error) { +func NewSqliteStore(ctx context.Context, dataDir string, metrics telemetry.AppMetrics) (*SqlStore, error) { storeStr := fmt.Sprintf("%s?cache=shared", storeSqliteFileName) if runtime.GOOS == "windows" { // Vo avoid `The process cannot access the file because it is being used by another process` on Windows @@ -670,11 +671,11 @@ func NewSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*SqlStore, er return nil, err } - return NewSqlStore(db, SqliteStoreEngine, metrics) + return NewSqlStore(ctx, db, SqliteStoreEngine, metrics) } // NewPostgresqlStore creates a new Postgres store. -func NewPostgresqlStore(dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) { +func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) { db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{ Logger: logger.Default.LogMode(logger.Silent), PrepareStmt: true, @@ -683,32 +684,32 @@ func NewPostgresqlStore(dsn string, metrics telemetry.AppMetrics) (*SqlStore, er return nil, err } - return NewSqlStore(db, PostgresStoreEngine, metrics) + return NewSqlStore(ctx, db, PostgresStoreEngine, metrics) } // newPostgresStore initializes a new Postgres store. -func newPostgresStore(metrics telemetry.AppMetrics) (Store, error) { +func newPostgresStore(ctx context.Context, metrics telemetry.AppMetrics) (Store, error) { dsn, ok := os.LookupEnv(postgresDsnEnv) if !ok { return nil, fmt.Errorf("%s is not set", postgresDsnEnv) } - return NewPostgresqlStore(dsn, metrics) + return NewPostgresqlStore(ctx, dsn, metrics) } // NewSqliteStoreFromFileStore restores a store from FileStore and stores SQLite DB in the file located in datadir. -func NewSqliteStoreFromFileStore(fileStore *FileStore, dataDir string, metrics telemetry.AppMetrics) (*SqlStore, error) { - store, err := NewSqliteStore(dataDir, metrics) +func NewSqliteStoreFromFileStore(ctx context.Context, fileStore *FileStore, dataDir string, metrics telemetry.AppMetrics) (*SqlStore, error) { + store, err := NewSqliteStore(ctx, dataDir, metrics) if err != nil { return nil, err } - err = store.SaveInstallationID(fileStore.InstallationID) + err = store.SaveInstallationID(ctx, fileStore.InstallationID) if err != nil { return nil, err } - for _, account := range fileStore.GetAllAccounts() { - err := store.SaveAccount(account) + for _, account := range fileStore.GetAllAccounts(ctx) { + err := store.SaveAccount(ctx, account) if err != nil { return nil, err } @@ -718,19 +719,19 @@ func NewSqliteStoreFromFileStore(fileStore *FileStore, dataDir string, metrics t } // NewPostgresqlStoreFromFileStore restores a store from FileStore and stores Postgres DB. -func NewPostgresqlStoreFromFileStore(fileStore *FileStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) { - store, err := NewPostgresqlStore(dsn, metrics) +func NewPostgresqlStoreFromFileStore(ctx context.Context, fileStore *FileStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) { + store, err := NewPostgresqlStore(ctx, dsn, metrics) if err != nil { return nil, err } - err = store.SaveInstallationID(fileStore.InstallationID) + err = store.SaveInstallationID(ctx, fileStore.InstallationID) if err != nil { return nil, err } - for _, account := range fileStore.GetAllAccounts() { - err := store.SaveAccount(account) + for _, account := range fileStore.GetAllAccounts(ctx) { + err := store.SaveAccount(ctx, account) if err != nil { return nil, err } diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index a41195206..e0e893052 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -1,6 +1,7 @@ package server import ( + "context" "fmt" "math/rand" "net" @@ -34,7 +35,7 @@ func TestSqlite_NewStore(t *testing.T) { store := newSqliteStore(t) - if len(store.GetAllAccounts()) != 0 { + if len(store.GetAllAccounts(context.Background())) != 0 { t.Errorf("expected to create a new empty Accounts map when creating a new FileStore") } } @@ -46,7 +47,7 @@ func TestSqlite_SaveAccount_Large(t *testing.T) { store := newSqliteStore(t) - account := newAccountWithId("account_id", "testuser", "") + account := newAccountWithId(context.Background(), "account_id", "testuser", "") groupALL, err := account.GetGroupAll() if err != nil { t.Fatal(err) @@ -117,14 +118,14 @@ func TestSqlite_SaveAccount_Large(t *testing.T) { account.SetupKeys[setupKey.Key] = setupKey } - err = store.SaveAccount(account) + err = store.SaveAccount(context.Background(), account) require.NoError(t, err) - if len(store.GetAllAccounts()) != 1 { + if len(store.GetAllAccounts(context.Background())) != 1 { t.Errorf("expecting 1 Accounts to be stored after SaveAccount()") } - a, err := store.GetAccount(account.Id) + a, err := store.GetAccount(context.Background(), account.Id) if a == nil { t.Errorf("expecting Account to be stored after SaveAccount(): %v", err) } @@ -191,7 +192,7 @@ func TestSqlite_SaveAccount(t *testing.T) { store := newSqliteStore(t) - account := newAccountWithId("account_id", "testuser", "") + account := newAccountWithId(context.Background(), "account_id", "testuser", "") setupKey := GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ @@ -203,10 +204,10 @@ func TestSqlite_SaveAccount(t *testing.T) { Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) require.NoError(t, err) - account2 := newAccountWithId("account_id2", "testuser2", "") + account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "") setupKey = GenerateDefaultSetupKey() account2.SetupKeys[setupKey.Key] = setupKey account2.Peers["testpeer2"] = &nbpeer.Peer{ @@ -218,14 +219,14 @@ func TestSqlite_SaveAccount(t *testing.T) { Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } - err = store.SaveAccount(account2) + err = store.SaveAccount(context.Background(), account2) require.NoError(t, err) - if len(store.GetAllAccounts()) != 2 { + if len(store.GetAllAccounts(context.Background())) != 2 { t.Errorf("expecting 2 Accounts to be stored after SaveAccount()") } - a, err := store.GetAccount(account.Id) + a, err := store.GetAccount(context.Background(), account.Id) if a == nil { t.Errorf("expecting Account to be stored after SaveAccount(): %v", err) } @@ -239,19 +240,19 @@ func TestSqlite_SaveAccount(t *testing.T) { return } - if a, err := store.GetAccountByPeerPubKey("peerkey"); a == nil { + if a, err := store.GetAccountByPeerPubKey(context.Background(), "peerkey"); a == nil { t.Errorf("expecting PeerKeyID2AccountID index updated after SaveAccount(): %v", err) } - if a, err := store.GetAccountByUser("testuser"); a == nil { + if a, err := store.GetAccountByUser(context.Background(), "testuser"); a == nil { t.Errorf("expecting UserID2AccountID index updated after SaveAccount(): %v", err) } - if a, err := store.GetAccountByPeerID("testpeer"); a == nil { + if a, err := store.GetAccountByPeerID(context.Background(), "testpeer"); a == nil { t.Errorf("expecting PeerID2AccountID index updated after SaveAccount(): %v", err) } - if a, err := store.GetAccountBySetupKey(setupKey.Key); a == nil { + if a, err := store.GetAccountBySetupKey(context.Background(), setupKey.Key); a == nil { t.Errorf("expecting SetupKeyID2AccountID index updated after SaveAccount(): %v", err) } } @@ -270,7 +271,7 @@ func TestSqlite_DeleteAccount(t *testing.T) { Name: "test token", }} - account := newAccountWithId("account_id", testUserID, "") + account := newAccountWithId(context.Background(), "account_id", testUserID, "") setupKey := GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ @@ -283,33 +284,33 @@ func TestSqlite_DeleteAccount(t *testing.T) { } account.Users[testUserID] = user - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) require.NoError(t, err) - if len(store.GetAllAccounts()) != 1 { + if len(store.GetAllAccounts(context.Background())) != 1 { t.Errorf("expecting 1 Accounts to be stored after SaveAccount()") } - err = store.DeleteAccount(account) + err = store.DeleteAccount(context.Background(), account) require.NoError(t, err) - if len(store.GetAllAccounts()) != 0 { + if len(store.GetAllAccounts(context.Background())) != 0 { t.Errorf("expecting 0 Accounts to be stored after DeleteAccount()") } - _, err = store.GetAccountByPeerPubKey("peerkey") + _, err = store.GetAccountByPeerPubKey(context.Background(), "peerkey") require.Error(t, err, "expecting error after removing DeleteAccount when getting account by peer public key") - _, err = store.GetAccountByUser("testuser") + _, err = store.GetAccountByUser(context.Background(), "testuser") require.Error(t, err, "expecting error after removing DeleteAccount when getting account by user") - _, err = store.GetAccountByPeerID("testpeer") + _, err = store.GetAccountByPeerID(context.Background(), "testpeer") require.Error(t, err, "expecting error after removing DeleteAccount when getting account by peer id") - _, err = store.GetAccountBySetupKey(setupKey.Key) + _, err = store.GetAccountBySetupKey(context.Background(), setupKey.Key) require.Error(t, err, "expecting error after removing DeleteAccount when getting account by setup key") - _, err = store.GetAccount(account.Id) + _, err = store.GetAccount(context.Background(), account.Id) require.Error(t, err, "expecting error after removing DeleteAccount when getting account by id") for _, policy := range account.Policies { @@ -339,11 +340,11 @@ func TestSqlite_GetAccount(t *testing.T) { id := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - account, err := store.GetAccount(id) + account, err := store.GetAccount(context.Background(), id) require.NoError(t, err) require.Equal(t, id, account.Id, "account id should match") - _, err = store.GetAccount("non-existing-account") + _, err = store.GetAccount(context.Background(), "non-existing-account") assert.Error(t, err) parsedErr, ok := status.FromError(err) require.True(t, ok) @@ -357,7 +358,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) { store := newSqliteStoreFromFile(t, "testdata/store.json") - account, err := store.GetAccount("bf1c8084-ba50-4ce7-9439-34653001fc3b") + account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") require.NoError(t, err) // save status of non-existing peer @@ -379,13 +380,13 @@ func TestSqlite_SavePeerStatus(t *testing.T) { Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()}, } - err = store.SaveAccount(account) + err = store.SaveAccount(context.Background(), account) require.NoError(t, err) err = store.SavePeerStatus(account.Id, "testpeer", newStatus) require.NoError(t, err) - account, err = store.GetAccount(account.Id) + account, err = store.GetAccount(context.Background(), account.Id) require.NoError(t, err) actual := account.Peers["testpeer"].Status @@ -398,7 +399,7 @@ func TestSqlite_SavePeerLocation(t *testing.T) { store := newSqliteStoreFromFile(t, "testdata/store.json") - account, err := store.GetAccount("bf1c8084-ba50-4ce7-9439-34653001fc3b") + account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") require.NoError(t, err) peer := &nbpeer.Peer{ @@ -417,7 +418,7 @@ func TestSqlite_SavePeerLocation(t *testing.T) { assert.Error(t, err) account.Peers[peer.ID] = peer - err = store.SaveAccount(account) + err = store.SaveAccount(context.Background(), account) require.NoError(t, err) peer.Location.ConnectionIP = net.ParseIP("35.1.1.1") @@ -428,7 +429,7 @@ func TestSqlite_SavePeerLocation(t *testing.T) { err = store.SavePeerLocation(account.Id, account.Peers[peer.ID]) assert.NoError(t, err) - account, err = store.GetAccount(account.Id) + account, err = store.GetAccount(context.Background(), account.Id) require.NoError(t, err) actual := account.Peers[peer.ID].Location @@ -451,11 +452,11 @@ func TestSqlite_TestGetAccountByPrivateDomain(t *testing.T) { existingDomain := "test.com" - account, err := store.GetAccountByPrivateDomain(existingDomain) + account, err := store.GetAccountByPrivateDomain(context.Background(), existingDomain) require.NoError(t, err, "should found account") require.Equal(t, existingDomain, account.Domain, "domains should match") - _, err = store.GetAccountByPrivateDomain("missing-domain.com") + _, err = store.GetAccountByPrivateDomain(context.Background(), "missing-domain.com") require.Error(t, err, "should return error on domain lookup") parsedErr, ok := status.FromError(err) require.True(t, ok) @@ -472,11 +473,11 @@ func TestSqlite_GetTokenIDByHashedToken(t *testing.T) { hashed := "SoMeHaShEdToKeN" id := "9dj38s35-63fb-11ec-90d6-0242ac120003" - token, err := store.GetTokenIDByHashedToken(hashed) + token, err := store.GetTokenIDByHashedToken(context.Background(), hashed) require.NoError(t, err) require.Equal(t, id, token) - _, err = store.GetTokenIDByHashedToken("non-existing-hash") + _, err = store.GetTokenIDByHashedToken(context.Background(), "non-existing-hash") require.Error(t, err) parsedErr, ok := status.FromError(err) require.True(t, ok) @@ -492,11 +493,11 @@ func TestSqlite_GetUserByTokenID(t *testing.T) { id := "9dj38s35-63fb-11ec-90d6-0242ac120003" - user, err := store.GetUserByTokenID(id) + user, err := store.GetUserByTokenID(context.Background(), id) require.NoError(t, err) require.Equal(t, id, user.PATs[id].ID) - _, err = store.GetUserByTokenID("non-existing-id") + _, err = store.GetUserByTokenID(context.Background(), "non-existing-id") require.Error(t, err) parsedErr, ok := status.FromError(err) require.True(t, ok) @@ -510,7 +511,7 @@ func TestMigrate(t *testing.T) { store := newSqliteStore(t) - err := migrate(store.db) + err := migrate(context.Background(), store.db) require.NoError(t, err, "Migration should not fail on empty db") _, ipnet, err := net.ParseCIDR("10.0.0.0/24") @@ -565,10 +566,10 @@ func TestMigrate(t *testing.T) { err = store.db.Save(rt).Error require.NoError(t, err, "Failed to insert Gob data") - err = migrate(store.db) + err = migrate(context.Background(), store.db) require.NoError(t, err, "Migration should not fail on gob populated db") - err = migrate(store.db) + err = migrate(context.Background(), store.db) require.NoError(t, err, "Migration should not fail on migrated db") err = store.db.Delete(rt).Where("id = ?", "route1").Error @@ -584,10 +585,10 @@ func TestMigrate(t *testing.T) { err = store.db.Save(nRT).Error require.NoError(t, err, "Failed to insert json nil slice data") - err = migrate(store.db) + err = migrate(context.Background(), store.db) require.NoError(t, err, "Migration should not fail on json nil slice populated db") - err = migrate(store.db) + err = migrate(context.Background(), store.db) require.NoError(t, err, "Migration should not fail on migrated db") } @@ -595,7 +596,7 @@ func TestMigrate(t *testing.T) { func newSqliteStore(t *testing.T) *SqlStore { t.Helper() - store, err := NewSqliteStore(t.TempDir(), nil) + store, err := NewSqliteStore(context.Background(), t.TempDir(), nil) require.NoError(t, err) require.NotNil(t, store) @@ -610,10 +611,10 @@ func newSqliteStoreFromFile(t *testing.T, filename string) *SqlStore { err := util.CopyFileContents(filename, filepath.Join(storeDir, "store.json")) require.NoError(t, err) - fStore, err := NewFileStore(storeDir, nil) + fStore, err := NewFileStore(context.Background(), storeDir, nil) require.NoError(t, err) - store, err := NewSqliteStoreFromFileStore(fStore, storeDir, nil) + store, err := NewSqliteStoreFromFileStore(context.Background(), fStore, storeDir, nil) require.NoError(t, err) require.NotNil(t, store) @@ -622,7 +623,7 @@ func newSqliteStoreFromFile(t *testing.T, filename string) *SqlStore { func newAccount(store Store, id int) error { str := fmt.Sprintf("%s-%d", uuid.New().String(), id) - account := newAccountWithId(str, str+"-testuser", "example.com") + account := newAccountWithId(context.Background(), str, str+"-testuser", "example.com") setupKey := GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["p"+str] = &nbpeer.Peer{ @@ -634,7 +635,7 @@ func newAccount(store Store, id int) error { Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } - return store.SaveAccount(account) + return store.SaveAccount(context.Background(), account) } func newPostgresqlStore(t *testing.T) *SqlStore { @@ -651,7 +652,7 @@ func newPostgresqlStore(t *testing.T) *SqlStore { t.Fatalf("could not initialize postgresql store: %s is not set", postgresDsnEnv) } - store, err := NewPostgresqlStore(postgresDsn, nil) + store, err := NewPostgresqlStore(context.Background(), postgresDsn, nil) if err != nil { t.Fatalf("could not initialize postgresql store: %s", err) } @@ -668,7 +669,7 @@ func newPostgresqlStoreFromFile(t *testing.T, filename string) *SqlStore { err := util.CopyFileContents(filename, filepath.Join(storeDir, "store.json")) require.NoError(t, err) - fStore, err := NewFileStore(storeDir, nil) + fStore, err := NewFileStore(context.Background(), storeDir, nil) require.NoError(t, err) cleanUp, err := testutil.CreatePGDB() @@ -682,7 +683,7 @@ func newPostgresqlStoreFromFile(t *testing.T, filename string) *SqlStore { t.Fatalf("could not initialize postgresql store: %s is not set", postgresDsnEnv) } - store, err := NewPostgresqlStoreFromFileStore(fStore, postgresDsn, nil) + store, err := NewPostgresqlStoreFromFileStore(context.Background(), fStore, postgresDsn, nil) require.NoError(t, err) require.NotNil(t, store) @@ -696,7 +697,7 @@ func TestPostgresql_NewStore(t *testing.T) { store := newPostgresqlStore(t) - if len(store.GetAllAccounts()) != 0 { + if len(store.GetAllAccounts(context.Background())) != 0 { t.Errorf("expected to create a new empty Accounts map when creating a new FileStore") } } @@ -708,7 +709,7 @@ func TestPostgresql_SaveAccount(t *testing.T) { store := newPostgresqlStore(t) - account := newAccountWithId("account_id", "testuser", "") + account := newAccountWithId(context.Background(), "account_id", "testuser", "") setupKey := GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ @@ -720,10 +721,10 @@ func TestPostgresql_SaveAccount(t *testing.T) { Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) require.NoError(t, err) - account2 := newAccountWithId("account_id2", "testuser2", "") + account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "") setupKey = GenerateDefaultSetupKey() account2.SetupKeys[setupKey.Key] = setupKey account2.Peers["testpeer2"] = &nbpeer.Peer{ @@ -735,14 +736,14 @@ func TestPostgresql_SaveAccount(t *testing.T) { Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } - err = store.SaveAccount(account2) + err = store.SaveAccount(context.Background(), account2) require.NoError(t, err) - if len(store.GetAllAccounts()) != 2 { + if len(store.GetAllAccounts(context.Background())) != 2 { t.Errorf("expecting 2 Accounts to be stored after SaveAccount()") } - a, err := store.GetAccount(account.Id) + a, err := store.GetAccount(context.Background(), account.Id) if a == nil { t.Errorf("expecting Account to be stored after SaveAccount(): %v", err) } @@ -756,19 +757,19 @@ func TestPostgresql_SaveAccount(t *testing.T) { return } - if a, err := store.GetAccountByPeerPubKey("peerkey"); a == nil { + if a, err := store.GetAccountByPeerPubKey(context.Background(), "peerkey"); a == nil { t.Errorf("expecting PeerKeyID2AccountID index updated after SaveAccount(): %v", err) } - if a, err := store.GetAccountByUser("testuser"); a == nil { + if a, err := store.GetAccountByUser(context.Background(), "testuser"); a == nil { t.Errorf("expecting UserID2AccountID index updated after SaveAccount(): %v", err) } - if a, err := store.GetAccountByPeerID("testpeer"); a == nil { + if a, err := store.GetAccountByPeerID(context.Background(), "testpeer"); a == nil { t.Errorf("expecting PeerID2AccountID index updated after SaveAccount(): %v", err) } - if a, err := store.GetAccountBySetupKey(setupKey.Key); a == nil { + if a, err := store.GetAccountBySetupKey(context.Background(), setupKey.Key); a == nil { t.Errorf("expecting SetupKeyID2AccountID index updated after SaveAccount(): %v", err) } } @@ -787,7 +788,7 @@ func TestPostgresql_DeleteAccount(t *testing.T) { Name: "test token", }} - account := newAccountWithId("account_id", testUserID, "") + account := newAccountWithId(context.Background(), "account_id", testUserID, "") setupKey := GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ @@ -800,33 +801,33 @@ func TestPostgresql_DeleteAccount(t *testing.T) { } account.Users[testUserID] = user - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) require.NoError(t, err) - if len(store.GetAllAccounts()) != 1 { + if len(store.GetAllAccounts(context.Background())) != 1 { t.Errorf("expecting 1 Accounts to be stored after SaveAccount()") } - err = store.DeleteAccount(account) + err = store.DeleteAccount(context.Background(), account) require.NoError(t, err) - if len(store.GetAllAccounts()) != 0 { + if len(store.GetAllAccounts(context.Background())) != 0 { t.Errorf("expecting 0 Accounts to be stored after DeleteAccount()") } - _, err = store.GetAccountByPeerPubKey("peerkey") + _, err = store.GetAccountByPeerPubKey(context.Background(), "peerkey") require.Error(t, err, "expecting error after removing DeleteAccount when getting account by peer public key") - _, err = store.GetAccountByUser("testuser") + _, err = store.GetAccountByUser(context.Background(), "testuser") require.Error(t, err, "expecting error after removing DeleteAccount when getting account by user") - _, err = store.GetAccountByPeerID("testpeer") + _, err = store.GetAccountByPeerID(context.Background(), "testpeer") require.Error(t, err, "expecting error after removing DeleteAccount when getting account by peer id") - _, err = store.GetAccountBySetupKey(setupKey.Key) + _, err = store.GetAccountBySetupKey(context.Background(), setupKey.Key) require.Error(t, err, "expecting error after removing DeleteAccount when getting account by setup key") - _, err = store.GetAccount(account.Id) + _, err = store.GetAccount(context.Background(), account.Id) require.Error(t, err, "expecting error after removing DeleteAccount when getting account by id") for _, policy := range account.Policies { @@ -854,7 +855,7 @@ func TestPostgresql_SavePeerStatus(t *testing.T) { store := newPostgresqlStoreFromFile(t, "testdata/store.json") - account, err := store.GetAccount("bf1c8084-ba50-4ce7-9439-34653001fc3b") + account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") require.NoError(t, err) // save status of non-existing peer @@ -873,13 +874,13 @@ func TestPostgresql_SavePeerStatus(t *testing.T) { Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()}, } - err = store.SaveAccount(account) + err = store.SaveAccount(context.Background(), account) require.NoError(t, err) err = store.SavePeerStatus(account.Id, "testpeer", newStatus) require.NoError(t, err) - account, err = store.GetAccount(account.Id) + account, err = store.GetAccount(context.Background(), account.Id) require.NoError(t, err) actual := account.Peers["testpeer"].Status @@ -895,11 +896,11 @@ func TestPostgresql_TestGetAccountByPrivateDomain(t *testing.T) { existingDomain := "test.com" - account, err := store.GetAccountByPrivateDomain(existingDomain) + account, err := store.GetAccountByPrivateDomain(context.Background(), existingDomain) require.NoError(t, err, "should found account") require.Equal(t, existingDomain, account.Domain, "domains should match") - _, err = store.GetAccountByPrivateDomain("missing-domain.com") + _, err = store.GetAccountByPrivateDomain(context.Background(), "missing-domain.com") require.Error(t, err, "should return error on domain lookup") } @@ -913,7 +914,7 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) { hashed := "SoMeHaShEdToKeN" id := "9dj38s35-63fb-11ec-90d6-0242ac120003" - token, err := store.GetTokenIDByHashedToken(hashed) + token, err := store.GetTokenIDByHashedToken(context.Background(), hashed) require.NoError(t, err) require.Equal(t, id, token) } @@ -927,7 +928,7 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) { id := "9dj38s35-63fb-11ec-90d6-0242ac120003" - user, err := store.GetUserByTokenID(id) + user, err := store.GetUserByTokenID(context.Background(), id) require.NoError(t, err) require.Equal(t, id, user.PATs[id].ID) } diff --git a/management/server/store.go b/management/server/store.go index 67ef20884..05a09b3ee 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -1,6 +1,7 @@ package server import ( + "context" "errors" "fmt" "net" @@ -11,11 +12,12 @@ import ( "strings" "time" - "github.com/netbirdio/netbird/management/server/telemetry" - "github.com/netbirdio/netbird/util" log "github.com/sirupsen/logrus" "gorm.io/gorm" + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/util" + "github.com/netbirdio/netbird/management/server/migration" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" @@ -24,41 +26,41 @@ import ( ) type Store interface { - GetAllAccounts() []*Account - GetAccount(accountID string) (*Account, error) - DeleteAccount(account *Account) error - GetAccountByUser(userID string) (*Account, error) - GetAccountByPeerPubKey(peerKey string) (*Account, error) - GetAccountIDByPeerPubKey(peerKey string) (string, error) + GetAllAccounts(ctx context.Context) []*Account + GetAccount(ctx context.Context, accountID string) (*Account, error) + DeleteAccount(ctx context.Context, account *Account) error + GetAccountByUser(ctx context.Context, userID string) (*Account, error) + GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) + GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) GetAccountIDByUserID(peerKey string) (string, error) - GetAccountIDBySetupKey(peerKey string) (string, error) - GetAccountByPeerID(peerID string) (*Account, error) - GetAccountBySetupKey(setupKey string) (*Account, error) // todo use key hash later - GetAccountByPrivateDomain(domain string) (*Account, error) - GetTokenIDByHashedToken(secret string) (string, error) - GetUserByTokenID(tokenID string) (*User, error) + GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error) + GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) + GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later + GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) + GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) + GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) - SaveAccount(account *Account) error + SaveAccount(ctx context.Context, account *Account) error DeleteHashedPAT2TokenIDIndex(hashedToken string) error DeleteTokenID2UserIDIndex(tokenID string) error GetInstallationID() string - SaveInstallationID(ID string) error + SaveInstallationID(ctx context.Context, ID string) error // AcquireAccountWriteLock should attempt to acquire account lock for write purposes and return a function that releases the lock - AcquireAccountWriteLock(accountID string) func() + AcquireAccountWriteLock(ctx context.Context, accountID string) func() // AcquireAccountReadLock should attempt to acquire account lock for read purposes and return a function that releases the lock - AcquireAccountReadLock(accountID string) func() + AcquireAccountReadLock(ctx context.Context, accountID string) func() // AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock - AcquireGlobalLock() func() + AcquireGlobalLock(ctx context.Context) func() SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error SavePeerLocation(accountID string, peer *nbpeer.Peer) error SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error // Close should close the store persisting all unsaved data. - Close() error + Close(ctx context.Context) error // GetStoreEngine should return StoreEngine of the current store implementation. // This is also a method of metrics.DataSource interface. GetStoreEngine() StoreEngine - GetPeerByPeerPubKey(peerKey string) (*nbpeer.Peer, error) - GetAccountSettings(accountID string) (*Settings, error) + GetPeerByPeerPubKey(ctx context.Context, peerKey string) (*nbpeer.Peer, error) + GetAccountSettings(ctx context.Context, accountID string) (*Settings, error) } type StoreEngine string @@ -90,7 +92,7 @@ func getStoreEngineFromEnv() StoreEngine { // If no engine is specified, it attempts to retrieve it from the environment. // If still not specified, it defaults to using SQLite. // Additionally, it handles the migration from a JSON store file to SQLite if applicable. -func getStoreEngine(dataDir string, kind StoreEngine) StoreEngine { +func getStoreEngine(ctx context.Context, dataDir string, kind StoreEngine) StoreEngine { if kind == "" { kind = getStoreEngineFromEnv() if kind == "" { @@ -101,11 +103,11 @@ func getStoreEngine(dataDir string, kind StoreEngine) StoreEngine { sqliteStoreFile := filepath.Join(dataDir, storeSqliteFileName) if util.FileExists(jsonStoreFile) && !util.FileExists(sqliteStoreFile) { - log.Warnf("unsupported store engine specified, but found %s. Automatically migrating to SQLite.", jsonStoreFile) + log.WithContext(ctx).Warnf("unsupported store engine specified, but found %s. Automatically migrating to SQLite.", jsonStoreFile) // Attempt to migrate from JSON store to SQLite - if err := MigrateFileStoreToSqlite(dataDir); err != nil { - log.Errorf("failed to migrate filestore to SQLite: %v", err) + if err := MigrateFileStoreToSqlite(ctx, dataDir); err != nil { + log.WithContext(ctx).Errorf("failed to migrate filestore to SQLite: %v", err) kind = FileStoreEngine } } @@ -116,8 +118,8 @@ func getStoreEngine(dataDir string, kind StoreEngine) StoreEngine { } // NewStore creates a new store based on the provided engine type, data directory, and telemetry metrics -func NewStore(kind StoreEngine, dataDir string, metrics telemetry.AppMetrics) (Store, error) { - kind = getStoreEngine(dataDir, kind) +func NewStore(ctx context.Context, kind StoreEngine, dataDir string, metrics telemetry.AppMetrics) (Store, error) { + kind = getStoreEngine(ctx, dataDir, kind) if err := checkFileStoreEngine(kind, dataDir); err != nil { return nil, err @@ -125,11 +127,11 @@ func NewStore(kind StoreEngine, dataDir string, metrics telemetry.AppMetrics) (S switch kind { case SqliteStoreEngine: - log.Info("using SQLite store engine") - return NewSqliteStore(dataDir, metrics) + log.WithContext(ctx).Info("using SQLite store engine") + return NewSqliteStore(ctx, dataDir, metrics) case PostgresStoreEngine: - log.Info("using Postgres store engine") - return newPostgresStore(metrics) + log.WithContext(ctx).Info("using Postgres store engine") + return newPostgresStore(ctx, metrics) default: return nil, fmt.Errorf("unsupported kind of store: %s", kind) } @@ -147,8 +149,8 @@ func checkFileStoreEngine(kind StoreEngine, dataDir string) error { } // migrate migrates the SQLite database to the latest schema -func migrate(db *gorm.DB) error { - migrations := getMigrations() +func migrate(ctx context.Context, db *gorm.DB) error { + migrations := getMigrations(ctx) for _, m := range migrations { if err := m(db); err != nil { @@ -159,29 +161,29 @@ func migrate(db *gorm.DB) error { return nil } -func getMigrations() []migrationFunc { +func getMigrations(ctx context.Context) []migrationFunc { return []migrationFunc{ func(db *gorm.DB) error { - return migration.MigrateFieldFromGobToJSON[Account, net.IPNet](db, "network_net") + return migration.MigrateFieldFromGobToJSON[Account, net.IPNet](ctx, db, "network_net") }, func(db *gorm.DB) error { - return migration.MigrateFieldFromGobToJSON[route.Route, netip.Prefix](db, "network") + return migration.MigrateFieldFromGobToJSON[route.Route, netip.Prefix](ctx, db, "network") }, func(db *gorm.DB) error { - return migration.MigrateFieldFromGobToJSON[route.Route, []string](db, "peer_groups") + return migration.MigrateFieldFromGobToJSON[route.Route, []string](ctx, db, "peer_groups") }, func(db *gorm.DB) error { - return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "location_connection_ip", "") + return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](ctx, db, "location_connection_ip", "") }, func(db *gorm.DB) error { - return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "ip", "idx_peers_account_id_ip") + return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](ctx, db, "ip", "idx_peers_account_id_ip") }, } } // NewTestStoreFromJson is only used in tests -func NewTestStoreFromJson(dataDir string) (Store, func(), error) { - fstore, err := NewFileStore(dataDir, nil) +func NewTestStoreFromJson(ctx context.Context, dataDir string) (Store, func(), error) { + fstore, err := NewFileStore(ctx, dataDir, nil) if err != nil { return nil, nil, err } @@ -208,23 +210,23 @@ func NewTestStoreFromJson(dataDir string) (Store, func(), error) { return nil, nil, fmt.Errorf("%s is not set", postgresDsnEnv) } - store, err = NewPostgresqlStoreFromFileStore(fstore, dsn, nil) + store, err = NewPostgresqlStoreFromFileStore(ctx, fstore, dsn, nil) if err != nil { return nil, nil, err } } else { - store, err = NewSqliteStoreFromFileStore(fstore, dataDir, nil) + store, err = NewSqliteStoreFromFileStore(ctx, fstore, dataDir, nil) if err != nil { return nil, nil, err } - cleanUp = func() { store.Close() } + cleanUp = func() { store.Close(ctx) } } return store, cleanUp, nil } // MigrateFileStoreToSqlite migrates the file store to the SQLite store. -func MigrateFileStoreToSqlite(dataDir string) error { +func MigrateFileStoreToSqlite(ctx context.Context, dataDir string) error { fileStorePath := path.Join(dataDir, storeFileName) if _, err := os.Stat(fileStorePath); errors.Is(err, os.ErrNotExist) { return fmt.Errorf("%s doesn't exist, couldn't continue the operation", fileStorePath) @@ -235,21 +237,21 @@ func MigrateFileStoreToSqlite(dataDir string) error { return fmt.Errorf("%s already exists, couldn't continue the operation", sqlStorePath) } - fstore, err := NewFileStore(dataDir, nil) + fstore, err := NewFileStore(ctx, dataDir, nil) if err != nil { return fmt.Errorf("failed creating file store: %s: %v", dataDir, err) } - fsStoreAccounts := len(fstore.GetAllAccounts()) - log.Infof("%d account will be migrated from file store %s to sqlite store %s", + fsStoreAccounts := len(fstore.GetAllAccounts(ctx)) + log.WithContext(ctx).Infof("%d account will be migrated from file store %s to sqlite store %s", fsStoreAccounts, fileStorePath, sqlStorePath) - store, err := NewSqliteStoreFromFileStore(fstore, dataDir, nil) + store, err := NewSqliteStoreFromFileStore(ctx, fstore, dataDir, nil) if err != nil { return fmt.Errorf("failed creating file store: %s: %v", dataDir, err) } - sqliteStoreAccounts := len(store.GetAllAccounts()) + sqliteStoreAccounts := len(store.GetAllAccounts(ctx)) if fsStoreAccounts != sqliteStoreAccounts { return fmt.Errorf("failed to migrate accounts from file to sqlite. Expected accounts: %d, got: %d", fsStoreAccounts, sqliteStoreAccounts) diff --git a/management/server/store_test.go b/management/server/store_test.go index 3f8c5d18b..40c36c9e0 100644 --- a/management/server/store_test.go +++ b/management/server/store_test.go @@ -1,6 +1,7 @@ package server import ( + "context" "fmt" "testing" @@ -15,13 +16,13 @@ type benchCase struct { var newFs = func(b *testing.B) Store { b.Helper() - store, _ := NewFileStore(b.TempDir(), nil) + store, _ := NewFileStore(context.Background(), b.TempDir(), nil) return store } var newSqlite = func(b *testing.B) Store { b.Helper() - store, _ := NewSqliteStore(b.TempDir(), nil) + store, _ := NewSqliteStore(context.Background(), b.TempDir(), nil) return store } @@ -76,13 +77,13 @@ func BenchmarkTest_StoreRead(b *testing.B) { _ = newAccount(store, i) } - accounts := store.GetAllAccounts() + accounts := store.GetAllAccounts(context.Background()) id := accounts[c.size-1].Id b.Run(name, func(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { - _, _ = store.GetAccount(id) + _, _ = store.GetAccount(context.Background(), id) } }) }) diff --git a/management/server/telemetry/app_metrics.go b/management/server/telemetry/app_metrics.go index 56f4fb9c8..d88e18d8a 100644 --- a/management/server/telemetry/app_metrics.go +++ b/management/server/telemetry/app_metrics.go @@ -22,7 +22,7 @@ const defaultEndpoint = "/metrics" type MockAppMetrics struct { GetMeterFunc func() metric2.Meter CloseFunc func() error - ExposeFunc func(port int, endpoint string) error + ExposeFunc func(ctx context.Context, port int, endpoint string) error IDPMetricsFunc func() *IDPMetrics HTTPMiddlewareFunc func() *HTTPMiddleware GRPCMetricsFunc func() *GRPCMetrics @@ -47,9 +47,9 @@ func (mock *MockAppMetrics) Close() error { } // Expose mocks the Expose function of the AppMetrics interface -func (mock *MockAppMetrics) Expose(port int, endpoint string) error { +func (mock *MockAppMetrics) Expose(ctx context.Context, port int, endpoint string) error { if mock.ExposeFunc != nil { - return mock.ExposeFunc(port, endpoint) + return mock.ExposeFunc(ctx, port, endpoint) } return fmt.Errorf("unimplemented") } @@ -98,7 +98,7 @@ func (mock *MockAppMetrics) UpdateChannelMetrics() *UpdateChannelMetrics { type AppMetrics interface { GetMeter() metric2.Meter Close() error - Expose(port int, endpoint string) error + Expose(ctx context.Context, port int, endpoint string) error IDPMetrics() *IDPMetrics HTTPMiddleware() *HTTPMiddleware GRPCMetrics() *GRPCMetrics @@ -154,7 +154,7 @@ func (appMetrics *defaultAppMetrics) Close() error { // Expose metrics on a given port and endpoint. If endpoint is empty a defaultEndpoint one will be used. // Exposes metrics in the Prometheus format https://prometheus.io/ -func (appMetrics *defaultAppMetrics) Expose(port int, endpoint string) error { +func (appMetrics *defaultAppMetrics) Expose(ctx context.Context, port int, endpoint string) error { if endpoint == "" { endpoint = defaultEndpoint } @@ -174,7 +174,7 @@ func (appMetrics *defaultAppMetrics) Expose(port int, endpoint string) error { } }() - log.Infof("enabled application metrics and exposing on http://%s", listener.Addr().String()) + log.WithContext(ctx).Infof("enabled application metrics and exposing on http://%s", listener.Addr().String()) return nil } diff --git a/management/server/telemetry/http_api_metrics.go b/management/server/telemetry/http_api_metrics.go index c29533661..a80453dca 100644 --- a/management/server/telemetry/http_api_metrics.go +++ b/management/server/telemetry/http_api_metrics.go @@ -3,14 +3,17 @@ package telemetry import ( "context" "fmt" - "hash/fnv" "net/http" "strings" "time" + "github.com/google/uuid" log "github.com/sirupsen/logrus" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" + + "github.com/netbirdio/netbird/formatter" + nbContext "github.com/netbirdio/netbird/management/server/context" ) const ( @@ -163,8 +166,15 @@ func getResponseCounterKey(endpoint, method string, status int) string { func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler { fn := func(rw http.ResponseWriter, r *http.Request) { reqStart := time.Now() - traceID := hash(fmt.Sprintf("%v", r)) - log.Tracef("HTTP request %v: %v %v", traceID, r.Method, r.URL) + + //nolint + ctx := context.WithValue(r.Context(), formatter.ExecutionContextKey, formatter.HTTPSource) + + reqID := uuid.New().String() + //nolint + ctx = context.WithValue(ctx, nbContext.RequestIDKey, reqID) + + log.WithContext(ctx).Tracef("HTTP request %v: %v %v", reqID, r.Method, r.URL) metricKey := getRequestCounterKey(r.URL.Path, r.Method) @@ -175,12 +185,12 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler { w := WrapResponseWriter(rw) - h.ServeHTTP(w, r) + h.ServeHTTP(w, r.WithContext(ctx)) if w.Status() > 399 { - log.Errorf("HTTP response %v: %v %v status %v", traceID, r.Method, r.URL, w.Status()) + log.WithContext(ctx).Errorf("HTTP response %v: %v %v status %v", reqID, r.Method, r.URL, w.Status()) } else { - log.Tracef("HTTP response %v: %v %v status %v", traceID, r.Method, r.URL, w.Status()) + log.WithContext(ctx).Tracef("HTTP response %v: %v %v status %v", reqID, r.Method, r.URL, w.Status()) } metricKey = getResponseCounterKey(r.URL.Path, r.Method, w.Status()) @@ -198,7 +208,7 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler { if c, ok := m.httpRequestDurations[durationKey]; ok { c.Record(m.ctx, reqTook.Milliseconds()) } - log.Debugf("request %s %s took %d ms and finished with status %d", r.Method, r.URL.Path, reqTook.Milliseconds(), w.Status()) + log.WithContext(ctx).Debugf("request %s %s took %d ms and finished with status %d", r.Method, r.URL.Path, reqTook.Milliseconds(), w.Status()) if w.Status() == 200 && (r.Method == http.MethodPut || r.Method == http.MethodPost || r.Method == http.MethodDelete) { opts := metric.WithAttributeSet(attribute.NewSet(attribute.String("type", "write"))) @@ -212,12 +222,3 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler { return http.HandlerFunc(fn) } - -func hash(s string) uint32 { - h := fnv.New32a() - _, err := h.Write([]byte(s)) - if err != nil { - panic(err) - } - return h.Sum32() -} diff --git a/management/server/testutil/store.go b/management/server/testutil/store.go index 8db95bd2c..156a762fb 100644 --- a/management/server/testutil/store.go +++ b/management/server/testutil/store.go @@ -33,7 +33,7 @@ func CreatePGDB() (func(), error) { timeout := 10 * time.Second err = c.Stop(ctx, &timeout) if err != nil { - log.Warnf("failed to stop container: %s", err) + log.WithContext(ctx).Warnf("failed to stop container: %s", err) } } diff --git a/management/server/turncredentials.go b/management/server/turncredentials.go index aedcf2ee1..79f42e882 100644 --- a/management/server/turncredentials.go +++ b/management/server/turncredentials.go @@ -1,6 +1,7 @@ package server import ( + "context" "crypto/hmac" "crypto/sha1" "encoding/base64" @@ -16,7 +17,7 @@ import ( // TURNCredentialsManager used to manage TURN credentials type TURNCredentialsManager interface { GenerateCredentials() TURNCredentials - SetupRefresh(peerKey string) + SetupRefresh(ctx context.Context, peerKey string) CancelRefresh(peerKey string) } @@ -81,13 +82,13 @@ func (m *TimeBasedAuthSecretsManager) CancelRefresh(peerID string) { // SetupRefresh starts peer credentials refresh. Since credentials are expiring (TTL) it is necessary to always generate them and send to the peer. // A goroutine is created and put into TimeBasedAuthSecretsManager.cancelMap. This routine should be cancelled if peer is gone. -func (m *TimeBasedAuthSecretsManager) SetupRefresh(peerID string) { +func (m *TimeBasedAuthSecretsManager) SetupRefresh(ctx context.Context, peerID string) { m.mux.Lock() defer m.mux.Unlock() m.cancel(peerID) cancel := make(chan struct{}, 1) m.cancelMap[peerID] = cancel - log.Debugf("starting turn refresh for %s", peerID) + log.WithContext(ctx).Debugf("starting turn refresh for %s", peerID) go func() { // we don't want to regenerate credentials right on expiration, so we do it slightly before (at 3/4 of TTL) @@ -96,7 +97,7 @@ func (m *TimeBasedAuthSecretsManager) SetupRefresh(peerID string) { for { select { case <-cancel: - log.Debugf("stopping turn refresh for %s", peerID) + log.WithContext(ctx).Debugf("stopping turn refresh for %s", peerID) return case <-ticker.C: c := m.GenerateCredentials() @@ -117,8 +118,8 @@ func (m *TimeBasedAuthSecretsManager) SetupRefresh(peerID string) { Turns: turns, }, } - log.Debugf("sending new TURN credentials to peer %s", peerID) - m.updateManager.SendUpdate(peerID, &UpdateMessage{Update: update}) + log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID) + m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update}) } } }() diff --git a/management/server/turncredentials_test.go b/management/server/turncredentials_test.go index 5066fdbe9..667dccbb5 100644 --- a/management/server/turncredentials_test.go +++ b/management/server/turncredentials_test.go @@ -1,6 +1,7 @@ package server import ( + "context" "crypto/hmac" "crypto/sha1" "encoding/base64" @@ -46,7 +47,7 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) { secret := "some_secret" peersManager := NewPeersUpdateManager(nil) peer := "some_peer" - updateChannel := peersManager.CreateChannel(peer) + updateChannel := peersManager.CreateChannel(context.Background(), peer) tested := NewTimeBasedAuthSecretsManager(peersManager, &TURNConfig{ CredentialsTTL: ttl, @@ -54,7 +55,7 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) { Turns: []*Host{TurnTestHost}, }) - tested.SetupRefresh(peer) + tested.SetupRefresh(context.Background(), peer) if _, ok := tested.cancelMap[peer]; !ok { t.Errorf("expecting peer to be present in a cancel map, got not present") @@ -102,7 +103,7 @@ func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) { Turns: []*Host{TurnTestHost}, }) - tested.SetupRefresh(peer) + tested.SetupRefresh(context.Background(), peer) if _, ok := tested.cancelMap[peer]; !ok { t.Errorf("expecting peer to be present in a cancel map, got not present") } diff --git a/management/server/updatechannel.go b/management/server/updatechannel.go index f760c5a75..c11225dbc 100644 --- a/management/server/updatechannel.go +++ b/management/server/updatechannel.go @@ -1,6 +1,7 @@ package server import ( + "context" "sync" "time" @@ -35,7 +36,7 @@ func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager { } // SendUpdate sends update message to the peer's channel -func (p *PeersUpdateManager) SendUpdate(peerID string, update *UpdateMessage) { +func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, update *UpdateMessage) { start := time.Now() var found, dropped bool @@ -51,18 +52,18 @@ func (p *PeersUpdateManager) SendUpdate(peerID string, update *UpdateMessage) { found = true select { case channel <- update: - log.Debugf("update was sent to channel for peer %s", peerID) + log.WithContext(ctx).Debugf("update was sent to channel for peer %s", peerID) default: dropped = true - log.Warnf("channel for peer %s is %d full", peerID, len(channel)) + log.WithContext(ctx).Warnf("channel for peer %s is %d full", peerID, len(channel)) } } else { - log.Debugf("peer %s has no channel", peerID) + log.WithContext(ctx).Debugf("peer %s has no channel", peerID) } } // CreateChannel creates a go channel for a given peer used to deliver updates relevant to the peer. -func (p *PeersUpdateManager) CreateChannel(peerID string) chan *UpdateMessage { +func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) chan *UpdateMessage { start := time.Now() closed := false @@ -84,22 +85,22 @@ func (p *PeersUpdateManager) CreateChannel(peerID string) chan *UpdateMessage { channel := make(chan *UpdateMessage, channelBufferSize) p.peerChannels[peerID] = channel - log.Debugf("opened updates channel for a peer %s", peerID) + log.WithContext(ctx).Debugf("opened updates channel for a peer %s", peerID) return channel } -func (p *PeersUpdateManager) closeChannel(peerID string) { +func (p *PeersUpdateManager) closeChannel(ctx context.Context, peerID string) { if channel, ok := p.peerChannels[peerID]; ok { delete(p.peerChannels, peerID) close(channel) } - log.Debugf("closed updates channel of a peer %s", peerID) + log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID) } // CloseChannels closes updates channel for each given peer -func (p *PeersUpdateManager) CloseChannels(peerIDs []string) { +func (p *PeersUpdateManager) CloseChannels(ctx context.Context, peerIDs []string) { start := time.Now() p.channelsMux.Lock() @@ -111,12 +112,12 @@ func (p *PeersUpdateManager) CloseChannels(peerIDs []string) { }() for _, id := range peerIDs { - p.closeChannel(id) + p.closeChannel(ctx, id) } } // CloseChannel closes updates channel of a given peer -func (p *PeersUpdateManager) CloseChannel(peerID string) { +func (p *PeersUpdateManager) CloseChannel(ctx context.Context, peerID string) { start := time.Now() p.channelsMux.Lock() @@ -127,7 +128,7 @@ func (p *PeersUpdateManager) CloseChannel(peerID string) { } }() - p.closeChannel(peerID) + p.closeChannel(ctx, peerID) } // GetAllConnectedPeers returns a copy of the connected peers map diff --git a/management/server/updatechannel_test.go b/management/server/updatechannel_test.go index 187e404c5..69f5b895c 100644 --- a/management/server/updatechannel_test.go +++ b/management/server/updatechannel_test.go @@ -1,20 +1,21 @@ package server import ( + "context" "testing" "time" "github.com/netbirdio/netbird/management/proto" ) -//var peersUpdater *PeersUpdateManager +// var peersUpdater *PeersUpdateManager func TestCreateChannel(t *testing.T) { peer := "test-create" peersUpdater := NewPeersUpdateManager(nil) - defer peersUpdater.CloseChannel(peer) + defer peersUpdater.CloseChannel(context.Background(), peer) - _ = peersUpdater.CreateChannel(peer) + _ = peersUpdater.CreateChannel(context.Background(), peer) if _, ok := peersUpdater.peerChannels[peer]; !ok { t.Error("Error creating the channel") } @@ -28,11 +29,11 @@ func TestSendUpdate(t *testing.T) { Serial: 0, }, }} - _ = peersUpdater.CreateChannel(peer) + _ = peersUpdater.CreateChannel(context.Background(), peer) if _, ok := peersUpdater.peerChannels[peer]; !ok { t.Error("Error creating the channel") } - peersUpdater.SendUpdate(peer, update1) + peersUpdater.SendUpdate(context.Background(), peer, update1) select { case <-peersUpdater.peerChannels[peer]: default: @@ -40,7 +41,7 @@ func TestSendUpdate(t *testing.T) { } for range [channelBufferSize]int{} { - peersUpdater.SendUpdate(peer, update1) + peersUpdater.SendUpdate(context.Background(), peer, update1) } update2 := &UpdateMessage{Update: &proto.SyncResponse{ @@ -49,7 +50,7 @@ func TestSendUpdate(t *testing.T) { }, }} - peersUpdater.SendUpdate(peer, update2) + peersUpdater.SendUpdate(context.Background(), peer, update2) timeout := time.After(5 * time.Second) for range [channelBufferSize]int{} { select { @@ -67,11 +68,11 @@ func TestSendUpdate(t *testing.T) { func TestCloseChannel(t *testing.T) { peer := "test-close" peersUpdater := NewPeersUpdateManager(nil) - _ = peersUpdater.CreateChannel(peer) + _ = peersUpdater.CreateChannel(context.Background(), peer) if _, ok := peersUpdater.peerChannels[peer]; !ok { t.Error("Error creating the channel") } - peersUpdater.CloseChannel(peer) + peersUpdater.CloseChannel(context.Background(), peer) if _, ok := peersUpdater.peerChannels[peer]; ok { t.Error("Error closing the channel") } diff --git a/management/server/user.go b/management/server/user.go index 2be73fa07..302cfccaa 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -1,6 +1,7 @@ package server import ( + "context" "fmt" "strings" "time" @@ -209,11 +210,11 @@ func NewOwnerUser(id string) *User { } // createServiceUser creates a new service user under the given account. -func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUserID string, role UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountID string, initiatorUserID string, role UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, status.Errorf(status.NotFound, "account %s doesn't exist", accountID) } @@ -232,16 +233,16 @@ func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUs newUserID := uuid.New().String() newUser := NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, UserIssuedAPI) - log.Debugf("New User: %v", newUser) + log.WithContext(ctx).Debugf("New User: %v", newUser) account.Users[newUserID] = newUser - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(ctx, account) if err != nil { return nil, err } meta := map[string]any{"name": newUser.ServiceUserName} - am.StoreEvent(initiatorUserID, newUser.Id, accountID, activity.ServiceUserCreated, meta) + am.StoreEvent(ctx, initiatorUserID, newUser.Id, accountID, activity.ServiceUserCreated, meta) return &UserInfo{ ID: newUser.Id, @@ -257,16 +258,16 @@ func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUs } // CreateUser creates a new user under the given account. Effectively this is a user invite. -func (am *DefaultAccountManager) CreateUser(accountID, userID string, user *UserInfo) (*UserInfo, error) { +func (am *DefaultAccountManager) CreateUser(ctx context.Context, accountID, userID string, user *UserInfo) (*UserInfo, error) { if user.IsServiceUser { - return am.createServiceUser(accountID, userID, StrRoleToUserRole(user.Role), user.Name, user.NonDeletable, user.AutoGroups) + return am.createServiceUser(ctx, accountID, userID, StrRoleToUserRole(user.Role), user.Name, user.NonDeletable, user.AutoGroups) } - return am.inviteNewUser(accountID, userID, user) + return am.inviteNewUser(ctx, accountID, userID, user) } // inviteNewUser Invites a USer to a given account and creates reference in datastore -func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite *UserInfo) (*UserInfo, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, userID string, invite *UserInfo) (*UserInfo, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() if am.idpManager == nil { @@ -289,7 +290,7 @@ func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite default: } - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, status.Errorf(status.NotFound, "account %s doesn't exist", accountID) } @@ -305,13 +306,13 @@ func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite } // inviterUser is the one who is inviting the new user - inviterUser, err := am.lookupUserInCache(inviterID, account) + inviterUser, err := am.lookupUserInCache(ctx, inviterID, account) if err != nil || inviterUser == nil { return nil, status.Errorf(status.NotFound, "inviter user with ID %s doesn't exist in IdP", inviterID) } // check if the user is already registered with this email => reject - user, err := am.lookupUserInCacheByEmail(invite.Email, accountID) + user, err := am.lookupUserInCacheByEmail(ctx, invite.Email, accountID) if err != nil { return nil, err } @@ -320,7 +321,7 @@ func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite return nil, status.Errorf(status.UserAlreadyExists, "can't invite a user with an existing NetBird account") } - users, err := am.idpManager.GetUserByEmail(invite.Email) + users, err := am.idpManager.GetUserByEmail(ctx, invite.Email) if err != nil { return nil, err } @@ -329,7 +330,7 @@ func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite return nil, status.Errorf(status.UserAlreadyExists, "can't invite a user with an existing NetBird account") } - idpUser, err := am.idpManager.CreateUser(invite.Email, invite.Name, accountID, inviterUser.Email) + idpUser, err := am.idpManager.CreateUser(ctx, invite.Email, invite.Name, accountID, inviterUser.Email) if err != nil { return nil, err } @@ -344,33 +345,33 @@ func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite } account.Users[idpUser.ID] = newUser - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(ctx, account) if err != nil { return nil, err } - _, err = am.refreshCache(account.Id) + _, err = am.refreshCache(ctx, account.Id) if err != nil { return nil, err } - am.StoreEvent(userID, newUser.Id, accountID, activity.UserInvited, nil) + am.StoreEvent(ctx, userID, newUser.Id, accountID, activity.UserInvited, nil) return newUser.ToUserInfo(idpUser, account.Settings) } // GetUser looks up a user by provided authorization claims. // It will also create an account if didn't exist for this user before. -func (am *DefaultAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) (*User, error) { - account, _, err := am.GetAccountFromToken(claims) +func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) { + account, _, err := am.GetAccountFromToken(ctx, claims) if err != nil { return nil, fmt.Errorf("failed to get account with token claims %v", err) } - unlock := am.Store.AcquireAccountWriteLock(account.Id) + unlock := am.Store.AcquireAccountWriteLock(ctx, account.Id) defer unlock() - account, err = am.Store.GetAccount(account.Id) + account, err = am.Store.GetAccount(ctx, account.Id) if err != nil { return nil, fmt.Errorf("failed to get an account from store %v", err) } @@ -386,12 +387,12 @@ func (am *DefaultAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) ( err = am.Store.SaveUserLastLogin(account.Id, claims.UserId, claims.LastLogin) if err != nil { - log.Errorf("failed saving user last login: %v", err) + log.WithContext(ctx).Errorf("failed saving user last login: %v", err) } if newLogin { meta := map[string]any{"timestamp": claims.LastLogin} - am.StoreEvent(claims.UserId, claims.UserId, account.Id, activity.DashboardLogin, meta) + am.StoreEvent(ctx, claims.UserId, claims.UserId, account.Id, activity.DashboardLogin, meta) } return user, nil @@ -399,11 +400,11 @@ func (am *DefaultAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) ( // ListUsers returns lists of all users under the account. // It doesn't populate user information such as email or name. -func (am *DefaultAccountManager) ListUsers(accountID string) ([]*User, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string) ([]*User, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -416,21 +417,21 @@ func (am *DefaultAccountManager) ListUsers(accountID string) ([]*User, error) { return users, nil } -func (am *DefaultAccountManager) deleteServiceUser(account *Account, initiatorUserID string, targetUser *User) { +func (am *DefaultAccountManager) deleteServiceUser(ctx context.Context, account *Account, initiatorUserID string, targetUser *User) { meta := map[string]any{"name": targetUser.ServiceUserName, "created_at": targetUser.CreatedAt} - am.StoreEvent(initiatorUserID, targetUser.Id, account.Id, activity.ServiceUserDeleted, meta) + am.StoreEvent(ctx, initiatorUserID, targetUser.Id, account.Id, activity.ServiceUserDeleted, meta) delete(account.Users, targetUser.Id) } // DeleteUser deletes a user from the given account. -func (am *DefaultAccountManager) DeleteUser(accountID, initiatorUserID string, targetUserID string) error { +func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error { if initiatorUserID == targetUserID { return status.Errorf(status.InvalidArgument, "self deletion is not allowed") } - unlock := am.Store.AcquireAccountWriteLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err } @@ -463,43 +464,43 @@ func (am *DefaultAccountManager) DeleteUser(accountID, initiatorUserID string, t return status.Errorf(status.PermissionDenied, "service user is marked as non-deletable") } - am.deleteServiceUser(account, initiatorUserID, targetUser) - return am.Store.SaveAccount(account) + am.deleteServiceUser(ctx, account, initiatorUserID, targetUser) + return am.Store.SaveAccount(ctx, account) } - return am.deleteRegularUser(account, initiatorUserID, targetUserID) + return am.deleteRegularUser(ctx, account, initiatorUserID, targetUserID) } -func (am *DefaultAccountManager) deleteRegularUser(account *Account, initiatorUserID, targetUserID string) error { - tuEmail, tuName, err := am.getEmailAndNameOfTargetUser(account.Id, initiatorUserID, targetUserID) +func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account *Account, initiatorUserID, targetUserID string) error { + tuEmail, tuName, err := am.getEmailAndNameOfTargetUser(ctx, account.Id, initiatorUserID, targetUserID) if err != nil { - log.Errorf("failed to resolve email address: %s", err) + log.WithContext(ctx).Errorf("failed to resolve email address: %s", err) return err } if !isNil(am.idpManager) { // Delete if the user already exists in the IdP.Necessary in cases where a user account // was created where a user account was provisioned but the user did not sign in - _, err = am.idpManager.GetUserDataByID(targetUserID, idp.AppMetadata{WTAccountID: account.Id}) + _, err = am.idpManager.GetUserDataByID(ctx, targetUserID, idp.AppMetadata{WTAccountID: account.Id}) if err == nil { - err = am.deleteUserFromIDP(targetUserID, account.Id) + err = am.deleteUserFromIDP(ctx, targetUserID, account.Id) if err != nil { - log.Debugf("failed to delete user from IDP: %s", targetUserID) + log.WithContext(ctx).Debugf("failed to delete user from IDP: %s", targetUserID) return err } } else { - log.Debugf("skipped deleting user %s from IDP, error: %v", targetUserID, err) + log.WithContext(ctx).Debugf("skipped deleting user %s from IDP, error: %v", targetUserID, err) } } - err = am.deleteUserPeers(initiatorUserID, targetUserID, account) + err = am.deleteUserPeers(ctx, initiatorUserID, targetUserID, account) if err != nil { return err } u, err := account.FindUser(targetUserID) if err != nil { - log.Errorf("failed to find user %s for deletion, this should never happen: %s", targetUserID, err) + log.WithContext(ctx).Errorf("failed to find user %s for deletion, this should never happen: %s", targetUserID, err) } var tuCreatedAt time.Time @@ -508,20 +509,20 @@ func (am *DefaultAccountManager) deleteRegularUser(account *Account, initiatorUs } delete(account.Users, targetUserID) - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(ctx, account) if err != nil { return err } meta := map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt} - am.StoreEvent(initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta) + am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta) - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) return nil } -func (am *DefaultAccountManager) deleteUserPeers(initiatorUserID string, targetUserID string, account *Account) error { +func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorUserID string, targetUserID string, account *Account) error { peers, err := account.FindUserPeers(targetUserID) if err != nil { return status.Errorf(status.Internal, "failed to find user peers") @@ -532,25 +533,25 @@ func (am *DefaultAccountManager) deleteUserPeers(initiatorUserID string, targetU peerIDs = append(peerIDs, peer.ID) } - return am.deletePeers(account, peerIDs, initiatorUserID) + return am.deletePeers(ctx, account, peerIDs, initiatorUserID) } // InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period. -func (am *DefaultAccountManager) InviteUser(accountID string, initiatorUserID string, targetUserID string) error { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() if am.idpManager == nil { return status.Errorf(status.PreconditionFailed, "IdP manager must be enabled to send user invites") } - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return status.Errorf(status.NotFound, "account %s doesn't exist", accountID) } // check if the user is already registered with this ID - user, err := am.lookupUserInCache(targetUserID, account) + user, err := am.lookupUserInCache(ctx, targetUserID, account) if err != nil { return err } @@ -565,19 +566,19 @@ func (am *DefaultAccountManager) InviteUser(accountID string, initiatorUserID st return status.Errorf(status.PreconditionFailed, "can't invite a user with an activated NetBird account") } - err = am.idpManager.InviteUserByID(user.ID) + err = am.idpManager.InviteUserByID(ctx, user.ID) if err != nil { return err } - am.StoreEvent(initiatorUserID, user.ID, accountID, activity.UserInvited, nil) + am.StoreEvent(ctx, initiatorUserID, user.ID, accountID, activity.UserInvited, nil) return nil } // CreatePAT creates a new PAT for the given user -func (am *DefaultAccountManager) CreatePAT(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() if tokenName == "" { @@ -588,7 +589,7 @@ func (am *DefaultAccountManager) CreatePAT(accountID string, initiatorUserID str return nil, status.Errorf(status.InvalidArgument, "expiration has to be between 1 and 365") } - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -614,23 +615,23 @@ func (am *DefaultAccountManager) CreatePAT(accountID string, initiatorUserID str targetUser.PATs[pat.ID] = &pat.PersonalAccessToken - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(ctx, account) if err != nil { return nil, status.Errorf(status.Internal, "failed to save account: %v", err) } meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName} - am.StoreEvent(initiatorUserID, targetUserID, accountID, activity.PersonalAccessTokenCreated, meta) + am.StoreEvent(ctx, initiatorUserID, targetUserID, accountID, activity.PersonalAccessTokenCreated, meta) return pat, nil } // DeletePAT deletes a specific PAT from a user -func (am *DefaultAccountManager) DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return status.Errorf(status.NotFound, "account not found: %s", err) } @@ -664,11 +665,11 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, initiatorUserID str } meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName} - am.StoreEvent(initiatorUserID, targetUserID, accountID, activity.PersonalAccessTokenDeleted, meta) + am.StoreEvent(ctx, initiatorUserID, targetUserID, accountID, activity.PersonalAccessTokenDeleted, meta) delete(targetUser.PATs, tokenID) - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(ctx, account) if err != nil { return status.Errorf(status.Internal, "Failed to save account: %s", err) } @@ -676,11 +677,11 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, initiatorUserID str } // GetPAT returns a specific PAT from a user -func (am *DefaultAccountManager) GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, status.Errorf(status.NotFound, "account not found: %s", err) } @@ -708,11 +709,11 @@ func (am *DefaultAccountManager) GetPAT(accountID string, initiatorUserID string } // GetAllPATs returns all PATs for a user -func (am *DefaultAccountManager) GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, status.Errorf(status.NotFound, "account not found: %s", err) } @@ -740,21 +741,21 @@ func (am *DefaultAccountManager) GetAllPATs(accountID string, initiatorUserID st } // SaveUser saves updates to the given user. If the user doesn't exit it will throw status.NotFound error. -func (am *DefaultAccountManager) SaveUser(accountID, initiatorUserID string, update *User) (*UserInfo, error) { - return am.SaveOrAddUser(accountID, initiatorUserID, update, false) // false means do not create user and throw status.NotFound +func (am *DefaultAccountManager) SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error) { + return am.SaveOrAddUser(ctx, accountID, initiatorUserID, update, false) // false means do not create user and throw status.NotFound } // SaveOrAddUser updates the given user. If addIfNotExists is set to true it will add user when no exist // Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now. -func (am *DefaultAccountManager) SaveOrAddUser(accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) { - unlock := am.Store.AcquireAccountWriteLock(accountID) +func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) { + unlock := am.Store.AcquireAccountWriteLock(ctx, accountID) defer unlock() if update == nil { return nil, status.Errorf(status.InvalidArgument, "provided user update is nil") } - account, err := am.Store.GetAccount(accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -834,8 +835,8 @@ func (am *DefaultAccountManager) SaveOrAddUser(accountID, initiatorUserID string return nil, err } - if err := am.expireAndUpdatePeers(account, blockedPeers); err != nil { - log.Errorf("failed update expired peers: %s", err) + if err := am.expireAndUpdatePeers(ctx, account, blockedPeers); err != nil { + log.WithContext(ctx).Errorf("failed update expired peers: %s", err) return nil, err } } @@ -847,13 +848,13 @@ func (am *DefaultAccountManager) SaveOrAddUser(accountID, initiatorUserID string account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...) account.Network.IncSerial() - if err = am.Store.SaveAccount(account); err != nil { + if err = am.Store.SaveAccount(ctx, account); err != nil { return nil, err } - am.updateAccountPeers(account) + am.updateAccountPeers(ctx, account) } else { - if err = am.Store.SaveAccount(account); err != nil { + if err = am.Store.SaveAccount(ctx, account); err != nil { return nil, err } } @@ -861,17 +862,17 @@ func (am *DefaultAccountManager) SaveOrAddUser(accountID, initiatorUserID string defer func() { if oldUser.IsBlocked() != update.IsBlocked() { if update.IsBlocked() { - am.StoreEvent(initiatorUserID, oldUser.Id, accountID, activity.UserBlocked, nil) + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.UserBlocked, nil) } else { - am.StoreEvent(initiatorUserID, oldUser.Id, accountID, activity.UserUnblocked, nil) + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.UserUnblocked, nil) } } switch { case transferedOwnerRole: - am.StoreEvent(initiatorUserID, oldUser.Id, accountID, activity.TransferredOwnerRole, nil) + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.TransferredOwnerRole, nil) case oldUser.Role != newUser.Role: - am.StoreEvent(initiatorUserID, oldUser.Id, accountID, activity.UserRoleUpdated, map[string]any{"role": newUser.Role}) + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.UserRoleUpdated, map[string]any{"role": newUser.Role}) default: } @@ -881,17 +882,17 @@ func (am *DefaultAccountManager) SaveOrAddUser(accountID, initiatorUserID string for _, g := range removedGroups { group := account.GetGroup(g) if group != nil { - am.StoreEvent(initiatorUserID, oldUser.Id, accountID, activity.GroupRemovedFromUser, + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.GroupRemovedFromUser, map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName}) } else { - log.Errorf("group %s not found while saving user activity event of account %s", g, account.Id) + log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, account.Id) } } for _, g := range addedGroups { group := account.GetGroup(g) if group != nil { - am.StoreEvent(initiatorUserID, oldUser.Id, accountID, activity.GroupAddedToUser, + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.GroupAddedToUser, map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName}) } } @@ -899,7 +900,7 @@ func (am *DefaultAccountManager) SaveOrAddUser(accountID, initiatorUserID string }() if !isNil(am.idpManager) && !newUser.IsServiceUser { - userData, err := am.lookupUserInCache(newUser.Id, account) + userData, err := am.lookupUserInCache(ctx, newUser.Id, account) if err != nil { return nil, err } @@ -909,22 +910,22 @@ func (am *DefaultAccountManager) SaveOrAddUser(accountID, initiatorUserID string } // GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist -func (am *DefaultAccountManager) GetOrCreateAccountByUser(userID, domain string) (*Account, error) { +func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, userID, domain string) (*Account, error) { start := time.Now() - unlock := am.Store.AcquireGlobalLock() + unlock := am.Store.AcquireGlobalLock(ctx) defer unlock() - log.Debugf("Acquired global lock in %s for user %s", time.Since(start), userID) + log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), userID) lowerDomain := strings.ToLower(domain) - account, err := am.Store.GetAccountByUser(userID) + account, err := am.Store.GetAccountByUser(ctx, userID) if err != nil { if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { - account, err = am.newAccount(userID, lowerDomain) + account, err = am.newAccount(ctx, userID, lowerDomain) if err != nil { return nil, err } - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(ctx, account) if err != nil { return nil, err } @@ -938,7 +939,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userID, domain string) if account.Domain != lowerDomain && userObj.Role == UserRoleOwner { account.Domain = lowerDomain - err = am.Store.SaveAccount(account) + err = am.Store.SaveAccount(ctx, account) if err != nil { return nil, status.Errorf(status.Internal, "failed updating account with domain") } @@ -949,8 +950,8 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userID, domain string) // GetUsersFromAccount performs a batched request for users from IDP by account ID apply filter on what data to return // based on provided user role. -func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) { - account, err := am.Store.GetAccount(accountID) +func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error) { + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err } @@ -969,7 +970,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ( key := user.IntegrationReference.CacheKey(accountID, user.Id) info, err := am.externalCacheManager.Get(am.ctx, key) if err != nil { - log.Infof("Get ExternalCache for key: %s, error: %s", key, err) + log.WithContext(ctx).Infof("Get ExternalCache for key: %s, error: %s", key, err) users[user.Id] = true continue } @@ -980,12 +981,12 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ( users[user.Id] = userLoggedInOnce(!user.LastLogin.IsZero()) } } - queriedUsers, err = am.lookupCache(users, accountID) + queriedUsers, err = am.lookupCache(ctx, users, accountID) if err != nil { return nil, err } - log.Debugf("Got %d users from ExternalCache for account %s", len(usersFromIntegration), accountID) - log.Debugf("Got %d users from InternalCache for account %s", len(queriedUsers), accountID) + log.WithContext(ctx).Debugf("Got %d users from ExternalCache for account %s", len(usersFromIntegration), accountID) + log.WithContext(ctx).Debugf("Got %d users from InternalCache for account %s", len(queriedUsers), accountID) queriedUsers = append(queriedUsers, usersFromIntegration...) } @@ -1052,7 +1053,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ( } // expireAndUpdatePeers expires all peers of the given user and updates them in the account -func (am *DefaultAccountManager) expireAndUpdatePeers(account *Account, peers []*nbpeer.Peer) error { +func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, account *Account, peers []*nbpeer.Peer) error { var peerIDs []string for _, peer := range peers { if peer.Status.LoginExpired { @@ -1065,6 +1066,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(account *Account, peers [] return err } am.StoreEvent( + ctx, peer.UserID, peer.ID, account.Id, activity.PeerLoginExpired, peer.EventMeta(am.GetDNSDomain()), ) @@ -1072,34 +1074,34 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(account *Account, peers [] if len(peerIDs) != 0 { // this will trigger peer disconnect from the management service - am.peersUpdateManager.CloseChannels(peerIDs) - am.updateAccountPeers(account) + am.peersUpdateManager.CloseChannels(ctx, peerIDs) + am.updateAccountPeers(ctx, account) } return nil } -func (am *DefaultAccountManager) deleteUserFromIDP(targetUserID, accountID string) error { +func (am *DefaultAccountManager) deleteUserFromIDP(ctx context.Context, targetUserID, accountID string) error { if am.userDeleteFromIDPEnabled { - log.Debugf("user %s deleted from IdP", targetUserID) - err := am.idpManager.DeleteUser(targetUserID) + log.WithContext(ctx).Debugf("user %s deleted from IdP", targetUserID) + err := am.idpManager.DeleteUser(ctx, targetUserID) if err != nil { return fmt.Errorf("failed to delete user %s from IdP: %s", targetUserID, err) } } else { - err := am.idpManager.UpdateUserAppMetadata(targetUserID, idp.AppMetadata{}) + err := am.idpManager.UpdateUserAppMetadata(ctx, targetUserID, idp.AppMetadata{}) if err != nil { return fmt.Errorf("failed to remove user %s app metadata in IdP: %s", targetUserID, err) } } - err := am.removeUserFromCache(accountID, targetUserID) + err := am.removeUserFromCache(ctx, accountID, targetUserID) if err != nil { - log.Errorf("remove user from account (%q) cache failed with error: %v", accountID, err) + log.WithContext(ctx).Errorf("remove user from account (%q) cache failed with error: %v", accountID, err) } return nil } -func (am *DefaultAccountManager) getEmailAndNameOfTargetUser(accountId, initiatorId, targetId string) (string, string, error) { - userInfos, err := am.GetUsersFromAccount(accountId, initiatorId) +func (am *DefaultAccountManager) getEmailAndNameOfTargetUser(ctx context.Context, accountId, initiatorId, targetId string) (string, string, error) { + userInfos, err := am.GetUsersFromAccount(ctx, accountId, initiatorId) if err != nil { return "", "", err } diff --git a/management/server/user_test.go b/management/server/user_test.go index 5edb811c6..99d2792df 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -39,10 +39,10 @@ const ( func TestUser_CreatePAT_ForSameUser(t *testing.T) { store := newStore(t) - defer store.Close() - account := newAccountWithId(mockAccountID, mockUserID, "") + defer store.Close(context.Background()) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -52,7 +52,7 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - pat, err := am.CreatePAT(mockAccountID, mockUserID, mockUserID, mockTokenName, mockExpiresIn) + pat, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockExpiresIn) if err != nil { t.Fatalf("Error when adding PAT to user: %s", err) } @@ -77,13 +77,13 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) { func TestUser_CreatePAT_ForDifferentUser(t *testing.T) { store := newStore(t) - defer store.Close() - account := newAccountWithId(mockAccountID, mockUserID, "") + defer store.Close(context.Background()) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") account.Users[mockTargetUserId] = &User{ Id: mockTargetUserId, IsServiceUser: false, } - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -93,19 +93,19 @@ func TestUser_CreatePAT_ForDifferentUser(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - _, err = am.CreatePAT(mockAccountID, mockUserID, mockTargetUserId, mockTokenName, mockExpiresIn) + _, err = am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockTargetUserId, mockTokenName, mockExpiresIn) assert.Errorf(t, err, "Creating PAT for different user should thorw error") } func TestUser_CreatePAT_ForServiceUser(t *testing.T) { store := newStore(t) - defer store.Close() - account := newAccountWithId(mockAccountID, mockUserID, "") + defer store.Close(context.Background()) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") account.Users[mockTargetUserId] = &User{ Id: mockTargetUserId, IsServiceUser: true, } - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -115,7 +115,7 @@ func TestUser_CreatePAT_ForServiceUser(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - pat, err := am.CreatePAT(mockAccountID, mockUserID, mockTargetUserId, mockTokenName, mockExpiresIn) + pat, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockTargetUserId, mockTokenName, mockExpiresIn) if err != nil { t.Fatalf("Error when adding PAT to user: %s", err) } @@ -125,10 +125,10 @@ func TestUser_CreatePAT_ForServiceUser(t *testing.T) { func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) { store := newStore(t) - defer store.Close() - account := newAccountWithId(mockAccountID, mockUserID, "") + defer store.Close(context.Background()) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -138,16 +138,16 @@ func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - _, err = am.CreatePAT(mockAccountID, mockUserID, mockUserID, mockTokenName, mockWrongExpiresIn) + _, err = am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockWrongExpiresIn) assert.Errorf(t, err, "Wrong expiration should thorw error") } func TestUser_CreatePAT_WithEmptyName(t *testing.T) { store := newStore(t) - defer store.Close() - account := newAccountWithId(mockAccountID, mockUserID, "") + defer store.Close(context.Background()) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -157,14 +157,14 @@ func TestUser_CreatePAT_WithEmptyName(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - _, err = am.CreatePAT(mockAccountID, mockUserID, mockUserID, mockEmptyTokenName, mockExpiresIn) + _, err = am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockEmptyTokenName, mockExpiresIn) assert.Errorf(t, err, "Wrong expiration should thorw error") } func TestUser_DeletePAT(t *testing.T) { store := newStore(t) - defer store.Close() - account := newAccountWithId(mockAccountID, mockUserID, "") + defer store.Close(context.Background()) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") account.Users[mockUserID] = &User{ Id: mockUserID, PATs: map[string]*PersonalAccessToken{ @@ -174,7 +174,7 @@ func TestUser_DeletePAT(t *testing.T) { }, }, } - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -184,7 +184,7 @@ func TestUser_DeletePAT(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - err = am.DeletePAT(mockAccountID, mockUserID, mockUserID, mockTokenID1) + err = am.DeletePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenID1) if err != nil { t.Fatalf("Error when adding PAT to user: %s", err) } @@ -196,8 +196,8 @@ func TestUser_DeletePAT(t *testing.T) { func TestUser_GetPAT(t *testing.T) { store := newStore(t) - defer store.Close() - account := newAccountWithId(mockAccountID, mockUserID, "") + defer store.Close(context.Background()) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") account.Users[mockUserID] = &User{ Id: mockUserID, PATs: map[string]*PersonalAccessToken{ @@ -207,7 +207,7 @@ func TestUser_GetPAT(t *testing.T) { }, }, } - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -217,7 +217,7 @@ func TestUser_GetPAT(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - pat, err := am.GetPAT(mockAccountID, mockUserID, mockUserID, mockTokenID1) + pat, err := am.GetPAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenID1) if err != nil { t.Fatalf("Error when adding PAT to user: %s", err) } @@ -228,8 +228,8 @@ func TestUser_GetPAT(t *testing.T) { func TestUser_GetAllPATs(t *testing.T) { store := newStore(t) - defer store.Close() - account := newAccountWithId(mockAccountID, mockUserID, "") + defer store.Close(context.Background()) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") account.Users[mockUserID] = &User{ Id: mockUserID, PATs: map[string]*PersonalAccessToken{ @@ -243,7 +243,7 @@ func TestUser_GetAllPATs(t *testing.T) { }, }, } - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -253,7 +253,7 @@ func TestUser_GetAllPATs(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - pats, err := am.GetAllPATs(mockAccountID, mockUserID, mockUserID) + pats, err := am.GetAllPATs(context.Background(), mockAccountID, mockUserID, mockUserID) if err != nil { t.Fatalf("Error when adding PAT to user: %s", err) } @@ -330,10 +330,10 @@ func validateStruct(s interface{}) (err error) { func TestUser_CreateServiceUser(t *testing.T) { store := newStore(t) - defer store.Close() - account := newAccountWithId(mockAccountID, mockUserID, "") + defer store.Close(context.Background()) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -343,7 +343,7 @@ func TestUser_CreateServiceUser(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - user, err := am.createServiceUser(mockAccountID, mockUserID, mockRole, mockServiceUserName, false, []string{"group1", "group2"}) + user, err := am.createServiceUser(context.Background(), mockAccountID, mockUserID, mockRole, mockServiceUserName, false, []string{"group1", "group2"}) if err != nil { t.Fatalf("Error when creating service user: %s", err) } @@ -360,7 +360,7 @@ func TestUser_CreateServiceUser(t *testing.T) { assert.True(t, user.IsServiceUser) assert.Equal(t, "active", user.Status) - _, err = am.createServiceUser(mockAccountID, mockUserID, UserRoleOwner, mockServiceUserName, false, nil) + _, err = am.createServiceUser(context.Background(), mockAccountID, mockUserID, UserRoleOwner, mockServiceUserName, false, nil) if err == nil { t.Fatal("should return error when creating service user with owner role") } @@ -368,10 +368,10 @@ func TestUser_CreateServiceUser(t *testing.T) { func TestUser_CreateUser_ServiceUser(t *testing.T) { store := newStore(t) - defer store.Close() - account := newAccountWithId(mockAccountID, mockUserID, "") + defer store.Close(context.Background()) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -381,7 +381,7 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - user, err := am.CreateUser(mockAccountID, mockUserID, &UserInfo{ + user, err := am.CreateUser(context.Background(), mockAccountID, mockUserID, &UserInfo{ Name: mockServiceUserName, Role: mockRole, IsServiceUser: true, @@ -407,10 +407,10 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) { func TestUser_CreateUser_RegularUser(t *testing.T) { store := newStore(t) - defer store.Close() - account := newAccountWithId(mockAccountID, mockUserID, "") + defer store.Close(context.Background()) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -420,7 +420,7 @@ func TestUser_CreateUser_RegularUser(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - _, err = am.CreateUser(mockAccountID, mockUserID, &UserInfo{ + _, err = am.CreateUser(context.Background(), mockAccountID, mockUserID, &UserInfo{ Name: mockServiceUserName, Role: mockRole, IsServiceUser: false, @@ -432,10 +432,10 @@ func TestUser_CreateUser_RegularUser(t *testing.T) { func TestUser_InviteNewUser(t *testing.T) { store := newStore(t) - defer store.Close() - account := newAccountWithId(mockAccountID, mockUserID, "") + defer store.Close(context.Background()) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -459,7 +459,7 @@ func TestUser_InviteNewUser(t *testing.T) { } idpMock := idp.MockIDP{ - CreateUserFunc: func(email, name, accountID, invitedByEmail string) (*idp.UserData, error) { + CreateUserFunc: func(_ context.Context, email, name, accountID, invitedByEmail string) (*idp.UserData, error) { newData := &idp.UserData{ Email: email, Name: name, @@ -470,7 +470,7 @@ func TestUser_InviteNewUser(t *testing.T) { return newData, nil }, - GetAccountFunc: func(accountId string) ([]*idp.UserData, error) { + GetAccountFunc: func(_ context.Context, accountId string) ([]*idp.UserData, error) { return mockData, nil }, } @@ -478,7 +478,7 @@ func TestUser_InviteNewUser(t *testing.T) { am.idpManager = &idpMock // test if new invite with regular role works - _, err = am.inviteNewUser(mockAccountID, mockUserID, &UserInfo{ + _, err = am.inviteNewUser(context.Background(), mockAccountID, mockUserID, &UserInfo{ Name: mockServiceUserName, Role: mockRole, Email: "test@teste.com", @@ -489,7 +489,7 @@ func TestUser_InviteNewUser(t *testing.T) { assert.NoErrorf(t, err, "Invite user should not throw error") // test if new invite with owner role fails - _, err = am.inviteNewUser(mockAccountID, mockUserID, &UserInfo{ + _, err = am.inviteNewUser(context.Background(), mockAccountID, mockUserID, &UserInfo{ Name: mockServiceUserName, Role: string(UserRoleOwner), Email: "test2@teste.com", @@ -532,10 +532,10 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { store := newStore(t) - account := newAccountWithId(mockAccountID, mockUserID, "") + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") account.Users[mockServiceUserID] = tt.serviceUser - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -545,7 +545,7 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - err = am.DeleteUser(mockAccountID, mockUserID, mockServiceUserID) + err = am.DeleteUser(context.Background(), mockAccountID, mockUserID, mockServiceUserID) tt.assertErrFunc(t, err, tt.assertErrMessage) if err != nil { @@ -561,10 +561,10 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) { func TestUser_DeleteUser_SelfDelete(t *testing.T) { store := newStore(t) - defer store.Close() - account := newAccountWithId(mockAccountID, mockUserID, "") + defer store.Close(context.Background()) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -574,7 +574,7 @@ func TestUser_DeleteUser_SelfDelete(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - err = am.DeleteUser(mockAccountID, mockUserID, mockUserID) + err = am.DeleteUser(context.Background(), mockAccountID, mockUserID, mockUserID) if err == nil { t.Fatalf("failed to prevent self deletion") } @@ -582,8 +582,8 @@ func TestUser_DeleteUser_SelfDelete(t *testing.T) { func TestUser_DeleteUser_regularUser(t *testing.T) { store := newStore(t) - defer store.Close() - account := newAccountWithId(mockAccountID, mockUserID, "") + defer store.Close(context.Background()) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") targetId := "user2" account.Users[targetId] = &User{ @@ -612,7 +612,7 @@ func TestUser_DeleteUser_regularUser(t *testing.T) { Role: UserRoleOwner, } - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -655,7 +655,7 @@ func TestUser_DeleteUser_regularUser(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - err = am.DeleteUser(mockAccountID, mockUserID, testCase.userID) + err = am.DeleteUser(context.Background(), mockAccountID, mockUserID, testCase.userID) testCase.assertErrFunc(t, err, testCase.assertErrMessage) }) } @@ -664,10 +664,10 @@ func TestUser_DeleteUser_regularUser(t *testing.T) { func TestDefaultAccountManager_GetUser(t *testing.T) { store := newStore(t) - defer store.Close() - account := newAccountWithId(mockAccountID, mockUserID, "") + defer store.Close(context.Background()) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -681,7 +681,7 @@ func TestDefaultAccountManager_GetUser(t *testing.T) { UserId: mockUserID, } - user, err := am.GetUser(claims) + user, err := am.GetUser(context.Background(), claims) if err != nil { t.Fatalf("Error when checking user role: %s", err) } @@ -693,12 +693,12 @@ func TestDefaultAccountManager_GetUser(t *testing.T) { func TestDefaultAccountManager_ListUsers(t *testing.T) { store := newStore(t) - defer store.Close() - account := newAccountWithId(mockAccountID, mockUserID, "") + defer store.Close(context.Background()) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") account.Users["normal_user1"] = NewRegularUser("normal_user1") account.Users["normal_user2"] = NewRegularUser("normal_user2") - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -708,7 +708,7 @@ func TestDefaultAccountManager_ListUsers(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - users, err := am.ListUsers(mockAccountID) + users, err := am.ListUsers(context.Background(), mockAccountID) if err != nil { t.Fatalf("Error when checking user role: %s", err) } @@ -775,12 +775,12 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { store := newStore(t) - account := newAccountWithId(mockAccountID, mockUserID, "") + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") account.Users["normal_user1"] = NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI) account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings delete(account.Users, mockUserID) - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -790,7 +790,7 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - users, err := am.ListUsers(mockAccountID) + users, err := am.ListUsers(context.Background(), mockAccountID) if err != nil { t.Fatalf("Error when checking user role: %s", err) } @@ -806,8 +806,8 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) { func TestDefaultAccountManager_ExternalCache(t *testing.T) { store := newStore(t) - defer store.Close() - account := newAccountWithId(mockAccountID, mockUserID, "") + defer store.Close(context.Background()) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") externalUser := &User{ Id: "externalUser", Role: UserRoleUser, @@ -819,7 +819,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) { } account.Users[externalUser.Id] = externalUser - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -846,7 +846,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) { err = cacheManager.Set(context.Background(), cacheKey, &idp.UserData{ID: externalUser.Id, Name: "Test User", Email: "user@example.com"}) assert.NoError(t, err) - infos, err := am.GetUsersFromAccount(mockAccountID, mockUserID) + infos, err := am.GetUsersFromAccount(context.Background(), mockAccountID, mockUserID) assert.NoError(t, err) assert.Equal(t, 2, len(infos)) var user *UserInfo @@ -870,15 +870,15 @@ func TestUser_IsAdmin(t *testing.T) { func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) { store := newStore(t) - defer store.Close() - account := newAccountWithId(mockAccountID, mockUserID, "") + defer store.Close(context.Background()) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") account.Users[mockServiceUserID] = &User{ Id: mockServiceUserID, Role: "user", IsServiceUser: true, } - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -888,7 +888,7 @@ func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - users, err := am.GetUsersFromAccount(mockAccountID, mockUserID) + users, err := am.GetUsersFromAccount(context.Background(), mockAccountID, mockUserID) if err != nil { t.Fatalf("Error when getting users from account: %s", err) } @@ -898,16 +898,16 @@ func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) { func TestUser_GetUsersFromAccount_ForUser(t *testing.T) { store := newStore(t) - defer store.Close() + defer store.Close(context.Background()) - account := newAccountWithId(mockAccountID, mockUserID, "") + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") account.Users[mockServiceUserID] = &User{ Id: mockServiceUserID, Role: "user", IsServiceUser: true, } - err := store.SaveAccount(account) + err := store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -917,7 +917,7 @@ func TestUser_GetUsersFromAccount_ForUser(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - users, err := am.GetUsersFromAccount(mockAccountID, mockServiceUserID) + users, err := am.GetUsersFromAccount(context.Background(), mockAccountID, mockServiceUserID) if err != nil { t.Fatalf("Error when getting users from account: %s", err) } @@ -1069,7 +1069,7 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // create an account and an admin user - account, err := manager.GetOrCreateAccountByUser(ownerUserID, "netbird.io") + account, err := manager.GetOrCreateAccountByUser(context.Background(), ownerUserID, "netbird.io") if err != nil { t.Fatal(err) } @@ -1078,12 +1078,12 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { account.Users[regularUserID] = NewRegularUser(regularUserID) account.Users[adminUserID] = NewAdminUser(adminUserID) account.Users[serviceUserID] = &User{IsServiceUser: true, Id: serviceUserID, Role: UserRoleAdmin, ServiceUserName: "service"} - err = manager.Store.SaveAccount(account) + err = manager.Store.SaveAccount(context.Background(), account) if err != nil { t.Fatal(err) } - updated, err := manager.SaveUser(account.Id, tc.initiatorID, tc.update) + updated, err := manager.SaveUser(context.Background(), account.Id, tc.initiatorID, tc.update) if tc.expectedErr { require.Errorf(t, err, "expecting SaveUser to throw an error") } else { diff --git a/util/file.go b/util/file.go index 0cbfa37ab..2a6182556 100644 --- a/util/file.go +++ b/util/file.go @@ -1,6 +1,7 @@ package util import ( + "context" "encoding/json" "io" "os" @@ -57,7 +58,7 @@ func WriteJson(file string, obj interface{}) error { } // DirectWriteJson writes JSON config object to a file creating parent directories if required without creating a temporary file -func DirectWriteJson(file string, obj interface{}) error { +func DirectWriteJson(ctx context.Context, file string, obj interface{}) error { _, _, err := prepareConfigFileDir(file) if err != nil { diff --git a/util/grpc/dialer.go b/util/grpc/dialer.go index 63c56de17..3fba0c84e 100644 --- a/util/grpc/dialer.go +++ b/util/grpc/dialer.go @@ -27,7 +27,6 @@ func WithCustomDialer() grpc.DialOption { } } - conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr) if err != nil { log.Errorf("Failed to dial: %s", err) diff --git a/util/log.go b/util/log.go index fda15a541..90ccea48f 100644 --- a/util/log.go +++ b/util/log.go @@ -2,6 +2,7 @@ package util import ( "io" + "os" "path/filepath" log "github.com/sirupsen/logrus" @@ -30,7 +31,11 @@ func InitLog(logLevel string, logPath string) error { log.SetOutput(io.Writer(lumberjackLogger)) } - formatter.SetTextFormatter(log.StandardLogger()) + if os.Getenv("NB_LOG_FORMAT") == "json" { + formatter.SetJSONFormatter(log.StandardLogger()) + } else { + formatter.SetTextFormatter(log.StandardLogger()) + } log.SetLevel(level) return nil }