Add context to throughout the project and update logging (#2209)

propagate context from all the API calls and log request ID, account ID and peer ID

---------

Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com>
This commit is contained in:
pascal-fischer 2024-07-03 11:33:02 +02:00 committed by GitHub
parent 7cb81f1d70
commit 765aba2c1c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
127 changed files with 2936 additions and 2642 deletions

View File

@ -76,7 +76,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
t.Fatal(err)
}
s := grpc.NewServer()
store, cleanUp, err := mgmt.NewTestStoreFromJson(config.Datadir)
store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir)
if err != nil {
t.Fatal(err)
}
@ -87,13 +87,13 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
if err != nil {
return nil, nil
}
iv, _ := integrations.NewIntegratedValidator(eventStore)
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv)
iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv)
if err != nil {
t.Fatal(err)
}
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil {
t.Fatal(err)
}

View File

@ -174,7 +174,7 @@ func TestEngine_SSH(t *testing.T) {
t.Fatal(err)
}
//time.Sleep(250 * time.Millisecond)
// time.Sleep(250 * time.Millisecond)
assert.NotNil(t, engine.sshServer)
assert.Contains(t, sshPeersRemoved, "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=")
@ -1057,7 +1057,7 @@ func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error)
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, cleanUp, err := server.NewTestStoreFromJson(config.Datadir)
store, cleanUp, err := server.NewTestStoreFromJson(context.Background(), config.Datadir)
if err != nil {
return nil, "", err
}
@ -1068,13 +1068,13 @@ func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error)
if err != nil {
return nil, "", err
}
ia, _ := integrations.NewIntegratedValidator(eventStore)
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
if err != nil {
return nil, "", err
}
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil {
return nil, "", err
}

View File

@ -108,7 +108,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
return nil, "", err
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, cleanUp, err := server.NewTestStoreFromJson(config.Datadir)
store, cleanUp, err := server.NewTestStoreFromJson(context.Background(), config.Datadir)
if err != nil {
return nil, "", err
}
@ -119,13 +119,13 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
if err != nil {
return nil, "", err
}
ia, _ := integrations.NewIntegratedValidator(eventStore)
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
if err != nil {
return nil, "", err
}
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil {
return nil, "", err
}

View File

@ -7,6 +7,18 @@ import (
"strings"
"github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/context"
)
type ExecutionContext string
const (
ExecutionContextKey = "executionContext"
HTTPSource ExecutionContext = "HTTP"
GRPCSource ExecutionContext = "GRPC"
SystemSource ExecutionContext = "SYSTEM"
)
// ContextHook is a custom hook for add the source information for the entry
@ -30,6 +42,27 @@ func (hook ContextHook) Levels() []logrus.Level {
func (hook ContextHook) Fire(entry *logrus.Entry) error {
src := hook.parseSrc(entry.Caller.File)
entry.Data["source"] = fmt.Sprintf("%s:%v", src, entry.Caller.Line)
if entry.Context == nil {
return nil
}
source, ok := entry.Context.Value(ExecutionContextKey).(ExecutionContext)
if !ok {
return nil
}
entry.Data["context"] = source
switch source {
case HTTPSource:
addHTTPFields(entry)
case GRPCSource:
addGRPCFields(entry)
case SystemSource:
addSystemFields(entry)
}
return nil
}
@ -59,3 +92,42 @@ func (hook ContextHook) parseSrc(filePath string) string {
file := path.Base(filePath)
return fmt.Sprintf("%s/%s", pkg, file)
}
func addHTTPFields(entry *logrus.Entry) {
if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok {
entry.Data[context.RequestIDKey] = ctxReqID
}
if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok {
entry.Data[context.AccountIDKey] = ctxAccountID
}
if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok {
entry.Data[context.UserIDKey] = ctxInitiatorID
}
}
func addGRPCFields(entry *logrus.Entry) {
if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok {
entry.Data[context.RequestIDKey] = ctxReqID
}
if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok {
entry.Data[context.AccountIDKey] = ctxAccountID
}
if ctxDeviceID, ok := entry.Context.Value(context.PeerIDKey).(string); ok {
entry.Data[context.PeerIDKey] = ctxDeviceID
}
}
func addSystemFields(entry *logrus.Entry) {
if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok {
entry.Data[context.RequestIDKey] = ctxReqID
}
if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok {
entry.Data[context.UserIDKey] = ctxInitiatorID
}
if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok {
entry.Data[context.AccountIDKey] = ctxAccountID
}
if ctxDeviceID, ok := entry.Context.Value(context.PeerIDKey).(string); ok {
entry.Data[context.PeerIDKey] = ctxDeviceID
}
}

View File

@ -1,6 +1,8 @@
package formatter
import "github.com/sirupsen/logrus"
import (
"github.com/sirupsen/logrus"
)
// SetTextFormatter set the text formatter for given logger.
func SetTextFormatter(logger *logrus.Logger) {
@ -9,6 +11,13 @@ func SetTextFormatter(logger *logrus.Logger) {
logger.AddHook(NewContextHook())
}
// SetJSONFormatter set the JSON formatter for given logger.
func SetJSONFormatter(logger *logrus.Logger) {
logger.Formatter = &logrus.JSONFormatter{}
logger.ReportCaller = true
logger.AddHook(NewContextHook())
}
// SetLogcatFormatter set the logcat formatter for given logger.
func SetLogcatFormatter(logger *logrus.Logger) {
logger.Formatter = NewLogcatFormatter()

3
go.mod
View File

@ -44,7 +44,6 @@ require (
github.com/golang/mock v1.6.0
github.com/google/go-cmp v0.6.0
github.com/google/gopacket v1.1.19
github.com/google/martian/v3 v3.0.0
github.com/google/nftables v0.0.0-20220808154552-2eca00135732
github.com/gopacket/gopacket v1.1.1
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357
@ -58,7 +57,7 @@ require (
github.com/miekg/dns v1.1.43
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20240524104853-69c6d89826cd
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0
github.com/patrickmn/go-cache v2.1.0+incompatible

7
go.sum
View File

@ -209,8 +209,6 @@ github.com/google/gofuzz v0.0.0-20161122191042-44d81051d367/go.mod h1:HP5RmnzzSN
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
github.com/google/martian/v3 v3.0.0 h1:pMen7vLs8nvgEYhywH3KDWJIJTeEr2ULsVWHWYHQyBs=
github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
github.com/google/nftables v0.0.0-20220808154552-2eca00135732 h1:csc7dT82JiSLvq4aMyQMIQDL7986NH6Wxf/QrvOj55A=
github.com/google/nftables v0.0.0-20220808154552-2eca00135732/go.mod h1:b97ulCCFipUC+kSin+zygkvUVpx0vyIAwxXFdY3PlNc=
github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o=
@ -335,8 +333,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240524104853-69c6d89826cd h1:IzGGIJMpz07aPs3R6/4sxZv63JoCMddftLpVodUK+Ec=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240524104853-69c6d89826cd/go.mod h1:kxks50DrZnhW+oRTdHOkVOJbcTcyo766am8RBugo+Yc=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e h1:LYxhAmiEzSldLELHSMVoUnRPq3ztTNQImrD27frrGsI=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g=
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM=
@ -565,7 +563,6 @@ golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20190628185345-da137c7871d7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20191004110552-13f9640d40b9/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20191105084925-a882066a44e0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=

View File

@ -62,7 +62,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
t.Fatal(err)
}
s := grpc.NewServer()
store, cleanUp, err := mgmt.NewTestStoreFromJson(config.Datadir)
store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir)
if err != nil {
t.Fatal(err)
}
@ -70,13 +70,13 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
ia, _ := integrations.NewIntegratedValidator(eventStore)
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
if err != nil {
t.Fatal(err)
}
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
if err != nil {
t.Fatal(err)
}

View File

@ -20,6 +20,7 @@ import (
"time"
"github.com/google/uuid"
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
@ -35,8 +36,10 @@ import (
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/formatter"
mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server"
nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
httpapi "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/idp"
@ -77,6 +80,10 @@ var (
Short: "start NetBird Management Server",
PreRunE: func(cmd *cobra.Command, args []string) error {
flag.Parse()
//nolint
ctx := context.WithValue(cmd.Context(), formatter.ExecutionContextKey, formatter.SystemSource)
err := util.InitLog(logLevel, logFile)
if err != nil {
return fmt.Errorf("failed initializing log %v", err)
@ -85,7 +92,7 @@ var (
// detect whether user specified a port
userPort := cmd.Flag("port").Changed
config, err = loadMgmtConfig(mgmtConfig)
config, err = loadMgmtConfig(ctx, mgmtConfig)
if err != nil {
return fmt.Errorf("failed reading provided config file: %s: %v", mgmtConfig, err)
}
@ -116,6 +123,11 @@ var (
return nil
},
RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
//nolint
ctx = context.WithValue(ctx, formatter.ExecutionContextKey, formatter.SystemSource)
err := handleRebrand(cmd)
if err != nil {
return fmt.Errorf("failed to migrate files %v", err)
@ -131,11 +143,11 @@ var (
if err != nil {
return err
}
err = appMetrics.Expose(mgmtMetricsPort, "/metrics")
err = appMetrics.Expose(ctx, mgmtMetricsPort, "/metrics")
if err != nil {
return err
}
store, err := server.NewStore(config.StoreConfig.Engine, config.Datadir, appMetrics)
store, err := server.NewStore(ctx, config.StoreConfig.Engine, config.Datadir, appMetrics)
if err != nil {
return fmt.Errorf("failed creating Store: %s: %v", config.Datadir, err)
}
@ -143,7 +155,7 @@ var (
var idpManager idp.Manager
if config.IdpManagerConfig != nil {
idpManager, err = idp.NewManager(*config.IdpManagerConfig, appMetrics)
idpManager, err = idp.NewManager(ctx, *config.IdpManagerConfig, appMetrics)
if err != nil {
return fmt.Errorf("failed retrieving a new idp manager with err: %v", err)
}
@ -152,32 +164,32 @@ var (
if disableSingleAccMode {
mgmtSingleAccModeDomain = ""
}
eventStore, key, err := integrations.InitEventStore(config.Datadir, config.DataStoreEncryptionKey)
eventStore, key, err := integrations.InitEventStore(ctx, config.Datadir, config.DataStoreEncryptionKey)
if err != nil {
return fmt.Errorf("failed to initialize database: %s", err)
}
if config.DataStoreEncryptionKey != key {
log.Infof("update config with activity store key")
log.WithContext(ctx).Infof("update config with activity store key")
config.DataStoreEncryptionKey = key
err := updateMgmtConfig(mgmtConfig, config)
err := updateMgmtConfig(ctx, mgmtConfig, config)
if err != nil {
return fmt.Errorf("failed to write out store encryption key: %s", err)
}
}
geo, err := geolocation.NewGeolocation(config.Datadir)
geo, err := geolocation.NewGeolocation(ctx, config.Datadir)
if err != nil {
log.Warnf("could not initialize geo location service: %v, we proceed without geo support", err)
log.WithContext(ctx).Warnf("could not initialize geo location service: %v, we proceed without geo support", err)
} else {
log.Infof("geo location service has been initialized from %s", config.Datadir)
log.WithContext(ctx).Infof("geo location service has been initialized from %s", config.Datadir)
}
integratedPeerValidator, err := integrations.NewIntegratedValidator(eventStore)
integratedPeerValidator, err := integrations.NewIntegratedValidator(ctx, eventStore)
if err != nil {
return fmt.Errorf("failed to initialize integrated peer validator: %v", err)
}
accountManager, err := server.BuildManager(store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
accountManager, err := server.BuildManager(ctx, store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator)
if err != nil {
return fmt.Errorf("failed to build default manager: %v", err)
@ -188,13 +200,13 @@ var (
trustedPeers := config.ReverseProxy.TrustedPeers
defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")}
if len(trustedPeers) == 0 || slices.Equal[[]netip.Prefix](trustedPeers, defaultTrustedPeers) {
log.Warn("TrustedPeers are configured to default value '0.0.0.0/0', '::/0'. This allows connection IP spoofing.")
log.WithContext(ctx).Warn("TrustedPeers are configured to default value '0.0.0.0/0', '::/0'. This allows connection IP spoofing.")
trustedPeers = defaultTrustedPeers
}
trustedHTTPProxies := config.ReverseProxy.TrustedHTTPProxies
trustedProxiesCount := config.ReverseProxy.TrustedHTTPProxiesCount
if len(trustedHTTPProxies) > 0 && trustedProxiesCount > 0 {
log.Warn("TrustedHTTPProxies and TrustedHTTPProxiesCount both are configured. " +
log.WithContext(ctx).Warn("TrustedHTTPProxies and TrustedHTTPProxiesCount both are configured. " +
"This is not recommended way to extract X-Forwarded-For. Consider using one of these options.")
}
realipOpts := []realip.Option{
@ -206,8 +218,8 @@ var (
gRPCOpts := []grpc.ServerOption{
grpc.KeepaliveEnforcementPolicy(kaep),
grpc.KeepaliveParams(kasp),
grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...)),
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...)),
grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...), unaryInterceptor),
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor),
}
var certManager *autocert.Manager
@ -224,7 +236,7 @@ var (
} else if config.HttpConfig.CertFile != "" && config.HttpConfig.CertKey != "" {
tlsConfig, err = loadTLSConfig(config.HttpConfig.CertFile, config.HttpConfig.CertKey)
if err != nil {
log.Errorf("cannot load TLS credentials: %v", err)
log.WithContext(ctx).Errorf("cannot load TLS credentials: %v", err)
return err
}
transportCredentials := credentials.NewTLS(tlsConfig)
@ -233,6 +245,7 @@ var (
}
jwtValidator, err := jwtclaims.NewJWTValidator(
ctx,
config.HttpConfig.AuthIssuer,
config.GetAuthAudiences(),
config.HttpConfig.AuthKeysLocation,
@ -249,26 +262,24 @@ var (
KeysLocation: config.HttpConfig.AuthKeysLocation,
}
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator)
if err != nil {
return fmt.Errorf("failed creating HTTP API handler: %v", err)
}
ephemeralManager := server.NewEphemeralManager(store, accountManager)
ephemeralManager.LoadInitialPeers()
ephemeralManager.LoadInitialPeers(ctx)
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
srv, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, appMetrics, ephemeralManager)
srv, err := server.NewServer(ctx, config, accountManager, peersUpdateManager, turnManager, appMetrics, ephemeralManager)
if err != nil {
return fmt.Errorf("failed creating gRPC API handler: %v", err)
}
mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv)
installationID, err := getInstallationID(store)
installationID, err := getInstallationID(ctx, store)
if err != nil {
log.Errorf("cannot load TLS credentials: %v", err)
log.WithContext(ctx).Errorf("cannot load TLS credentials: %v", err)
return err
}
@ -278,18 +289,18 @@ var (
idpManager = config.IdpManagerConfig.ManagerType
}
metricsWorker := metrics.NewWorker(ctx, installationID, store, peersUpdateManager, idpManager)
go metricsWorker.Run()
go metricsWorker.Run(ctx)
}
var compatListener net.Listener
if mgmtPort != ManagementLegacyPort {
// The Management gRPC server was running on port 33073 previously. Old agents that are already connected to it
// are using port 33073. For compatibility purposes we keep running a 2nd gRPC server on port 33073.
compatListener, err = serveGRPC(gRPCAPIHandler, ManagementLegacyPort)
compatListener, err = serveGRPC(ctx, gRPCAPIHandler, ManagementLegacyPort)
if err != nil {
return err
}
log.Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String())
log.WithContext(ctx).Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String())
}
rootHandler := handlerFunc(gRPCAPIHandler, httpAPIHandler)
@ -306,8 +317,8 @@ var (
if err != nil {
return fmt.Errorf("failed creating TLS listener on port %d: %v", mgmtPort, err)
}
log.Infof("running HTTP server (LetsEncrypt challenge handler): %s", cml.Addr().String())
serveHTTP(cml, certManager.HTTPHandler(nil))
log.WithContext(ctx).Infof("running HTTP server (LetsEncrypt challenge handler): %s", cml.Addr().String())
serveHTTP(ctx, cml, certManager.HTTPHandler(nil))
}
} else if tlsConfig != nil {
listener, err = tls.Listen("tcp", fmt.Sprintf(":%d", mgmtPort), tlsConfig)
@ -321,14 +332,14 @@ var (
}
}
log.Infof("management server version %s", version.NetbirdVersion())
log.Infof("running HTTP server and gRPC server on the same port: %s", listener.Addr().String())
serveGRPCWithHTTP(listener, rootHandler, tlsEnabled)
log.WithContext(ctx).Infof("management server version %s", version.NetbirdVersion())
log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", listener.Addr().String())
serveGRPCWithHTTP(ctx, listener, rootHandler, tlsEnabled)
SetupCloseHandler()
<-stopCh
integratedPeerValidator.Stop()
integratedPeerValidator.Stop(ctx)
if geo != nil {
_ = geo.Stop()
}
@ -339,39 +350,68 @@ var (
_ = certManager.Listener().Close()
}
gRPCAPIHandler.Stop()
_ = store.Close()
_ = eventStore.Close()
log.Infof("stopped Management Service")
_ = store.Close(ctx)
_ = eventStore.Close(ctx)
log.WithContext(ctx).Infof("stopped Management Service")
return nil
},
}
)
func notifyStop(msg string) {
func unaryInterceptor(
ctx context.Context,
req interface{},
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
reqID := uuid.New().String()
//nolint
ctx = context.WithValue(ctx, formatter.ExecutionContextKey, formatter.GRPCSource)
//nolint
ctx = context.WithValue(ctx, nbContext.RequestIDKey, reqID)
return handler(ctx, req)
}
func streamInterceptor(
srv interface{},
ss grpc.ServerStream,
info *grpc.StreamServerInfo,
handler grpc.StreamHandler,
) error {
reqID := uuid.New().String()
wrapped := grpcMiddleware.WrapServerStream(ss)
//nolint
ctx := context.WithValue(ss.Context(), formatter.ExecutionContextKey, formatter.GRPCSource)
//nolint
wrapped.WrappedContext = context.WithValue(ctx, nbContext.RequestIDKey, reqID)
return handler(srv, wrapped)
}
func notifyStop(ctx context.Context, msg string) {
select {
case stopCh <- 1:
log.Error(msg)
log.WithContext(ctx).Error(msg)
default:
// stop has been already called, nothing to report
}
}
func getInstallationID(store server.Store) (string, error) {
func getInstallationID(ctx context.Context, store server.Store) (string, error) {
installationID := store.GetInstallationID()
if installationID != "" {
return installationID, nil
}
installationID = strings.ToUpper(uuid.New().String())
err := store.SaveInstallationID(installationID)
err := store.SaveInstallationID(ctx, installationID)
if err != nil {
return "", err
}
return installationID, nil
}
func serveGRPC(grpcServer *grpc.Server, port int) (net.Listener, error) {
func serveGRPC(ctx context.Context, grpcServer *grpc.Server, port int) (net.Listener, error) {
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil {
return nil, err
@ -379,22 +419,22 @@ func serveGRPC(grpcServer *grpc.Server, port int) (net.Listener, error) {
go func() {
err := grpcServer.Serve(listener)
if err != nil {
notifyStop(fmt.Sprintf("failed running gRPC server on port %d: %v", port, err))
notifyStop(ctx, fmt.Sprintf("failed running gRPC server on port %d: %v", port, err))
}
}()
return listener, nil
}
func serveHTTP(httpListener net.Listener, handler http.Handler) {
func serveHTTP(ctx context.Context, httpListener net.Listener, handler http.Handler) {
go func() {
err := http.Serve(httpListener, handler)
if err != nil {
notifyStop(fmt.Sprintf("failed running HTTP server: %v", err))
notifyStop(ctx, fmt.Sprintf("failed running HTTP server: %v", err))
}
}()
}
func serveGRPCWithHTTP(listener net.Listener, handler http.Handler, tlsEnabled bool) {
func serveGRPCWithHTTP(ctx context.Context, listener net.Listener, handler http.Handler, tlsEnabled bool) {
go func() {
var err error
if tlsEnabled {
@ -411,7 +451,7 @@ func serveGRPCWithHTTP(listener net.Listener, handler http.Handler, tlsEnabled b
if err != nil {
select {
case stopCh <- 1:
log.Errorf("failed to serve HTTP and gRPC server: %v", err)
log.WithContext(ctx).Errorf("failed to serve HTTP and gRPC server: %v", err)
default:
// stop has been already called, nothing to report
}
@ -431,7 +471,7 @@ func handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler) http.Handle
})
}
func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) {
func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config, error) {
loadedConfig := &server.Config{}
_, err := util.ReadJson(mgmtConfigPath, loadedConfig)
if err != nil {
@ -452,26 +492,26 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) {
oidcEndpoint := loadedConfig.HttpConfig.OIDCConfigEndpoint
if oidcEndpoint != "" {
// if OIDCConfigEndpoint is specified, we can load DeviceAuthEndpoint and TokenEndpoint automatically
log.Infof("loading OIDC configuration from the provided IDP configuration endpoint %s", oidcEndpoint)
oidcConfig, err := fetchOIDCConfig(oidcEndpoint)
log.WithContext(ctx).Infof("loading OIDC configuration from the provided IDP configuration endpoint %s", oidcEndpoint)
oidcConfig, err := fetchOIDCConfig(ctx, oidcEndpoint)
if err != nil {
return nil, err
}
log.Infof("loaded OIDC configuration from the provided IDP configuration endpoint: %s", oidcEndpoint)
log.WithContext(ctx).Infof("loaded OIDC configuration from the provided IDP configuration endpoint: %s", oidcEndpoint)
log.Infof("overriding HttpConfig.AuthIssuer with a new value %s, previously configured value: %s",
log.WithContext(ctx).Infof("overriding HttpConfig.AuthIssuer with a new value %s, previously configured value: %s",
oidcConfig.Issuer, loadedConfig.HttpConfig.AuthIssuer)
loadedConfig.HttpConfig.AuthIssuer = oidcConfig.Issuer
log.Infof("overriding HttpConfig.AuthKeysLocation (JWT certs) with a new value %s, previously configured value: %s",
log.WithContext(ctx).Infof("overriding HttpConfig.AuthKeysLocation (JWT certs) with a new value %s, previously configured value: %s",
oidcConfig.JwksURI, loadedConfig.HttpConfig.AuthKeysLocation)
loadedConfig.HttpConfig.AuthKeysLocation = oidcConfig.JwksURI
if !(loadedConfig.DeviceAuthorizationFlow == nil || strings.ToLower(loadedConfig.DeviceAuthorizationFlow.Provider) == string(server.NONE)) {
log.Infof("overriding DeviceAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s",
log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s",
oidcConfig.TokenEndpoint, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint)
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint = oidcConfig.TokenEndpoint
log.Infof("overriding DeviceAuthorizationFlow.DeviceAuthEndpoint with a new value: %s, previously configured value: %s",
log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.DeviceAuthEndpoint with a new value: %s, previously configured value: %s",
oidcConfig.DeviceAuthEndpoint, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint)
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint = oidcConfig.DeviceAuthEndpoint
@ -479,7 +519,7 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) {
if err != nil {
return nil, err
}
log.Infof("overriding DeviceAuthorizationFlow.ProviderConfig.Domain with a new value: %s, previously configured value: %s",
log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.ProviderConfig.Domain with a new value: %s, previously configured value: %s",
u.Host, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Domain)
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Domain = u.Host
@ -489,10 +529,10 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) {
}
if loadedConfig.PKCEAuthorizationFlow != nil {
log.Infof("overriding PKCEAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s",
log.WithContext(ctx).Infof("overriding PKCEAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s",
oidcConfig.TokenEndpoint, loadedConfig.PKCEAuthorizationFlow.ProviderConfig.TokenEndpoint)
loadedConfig.PKCEAuthorizationFlow.ProviderConfig.TokenEndpoint = oidcConfig.TokenEndpoint
log.Infof("overriding PKCEAuthorizationFlow.AuthorizationEndpoint with a new value: %s, previously configured value: %s",
log.WithContext(ctx).Infof("overriding PKCEAuthorizationFlow.AuthorizationEndpoint with a new value: %s, previously configured value: %s",
oidcConfig.AuthorizationEndpoint, loadedConfig.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint)
loadedConfig.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint = oidcConfig.AuthorizationEndpoint
}
@ -501,8 +541,8 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) {
return loadedConfig, err
}
func updateMgmtConfig(path string, config *server.Config) error {
return util.DirectWriteJson(path, config)
func updateMgmtConfig(ctx context.Context, path string, config *server.Config) error {
return util.DirectWriteJson(ctx, path, config)
}
// OIDCConfigResponse used for parsing OIDC config response
@ -515,7 +555,7 @@ type OIDCConfigResponse struct {
}
// fetchOIDCConfig fetches OIDC configuration from the IDP
func fetchOIDCConfig(oidcEndpoint string) (OIDCConfigResponse, error) {
func fetchOIDCConfig(ctx context.Context, oidcEndpoint string) (OIDCConfigResponse, error) {
res, err := http.Get(oidcEndpoint)
if err != nil {
return OIDCConfigResponse{}, fmt.Errorf("failed fetching OIDC configuration from endpoint %s %v", oidcEndpoint, err)
@ -524,7 +564,7 @@ func fetchOIDCConfig(oidcEndpoint string) (OIDCConfigResponse, error) {
defer func() {
err := res.Body.Close()
if err != nil {
log.Debugf("failed closing response body %v", err)
log.WithContext(ctx).Debugf("failed closing response body %v", err)
}
}()

View File

@ -1,13 +1,16 @@
package cmd
import (
"context"
"flag"
"fmt"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/util"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/util"
)
var shortUp = "Migrate JSON file store to SQLite store. Please make a backup of the JSON file before running this command."
@ -26,10 +29,13 @@ var upCmd = &cobra.Command{
return fmt.Errorf("failed initializing log %v", err)
}
if err := server.MigrateFileStoreToSqlite(mgmtDataDir); err != nil {
//nolint
ctx := context.WithValue(cmd.Context(), formatter.ExecutionContextKey, formatter.SystemSource)
if err := server.MigrateFileStoreToSqlite(ctx, mgmtDataDir); err != nil {
return err
}
log.Info("Migration finished successfully")
log.WithContext(ctx).Info("Migration finished successfully")
return nil
},

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,7 @@
package server
import (
"context"
"crypto/sha256"
b64 "encoding/base64"
"encoding/json"
@ -29,11 +30,11 @@ import (
type MocIntegratedValidator struct {
}
func (a MocIntegratedValidator) ValidateExtraSettings(newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
return nil
}
func (a MocIntegratedValidator) ValidatePeer(update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) {
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) {
return update, nil
}
func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) {
@ -44,15 +45,15 @@ func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[s
return validatedPeers, nil
}
func (MocIntegratedValidator) PreparePeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer {
func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer {
return peer
}
func (MocIntegratedValidator) IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) {
func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) {
return false, false, nil
}
func (MocIntegratedValidator) PeerDeleted(_, _ string) error {
func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error {
return nil
}
@ -60,7 +61,7 @@ func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)
}
func (MocIntegratedValidator) Stop() {
func (MocIntegratedValidator) Stop(_ context.Context) {
}
func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *Account, userID string) {
@ -85,7 +86,7 @@ func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *Ac
setupKey = key.Key
}
_, _, _, err := manager.AddPeer(setupKey, userID, peer)
_, _, _, err := manager.AddPeer(context.Background(), setupKey, userID, peer)
if err != nil {
t.Error("expected to add new peer successfully after creating new account, but failed", err)
}
@ -395,7 +396,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
}
for _, testCase := range tt {
account := newAccountWithId("account-1", userID, "netbird.io")
account := newAccountWithId(context.Background(), "account-1", userID, "netbird.io")
account.UpdateSettings(&testCase.accountSettings)
account.Network = network
account.Peers = testCase.peers
@ -409,7 +410,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
validatedPeers[p] = struct{}{}
}
networkMap := account.GetPeerNetworkMap(testCase.peerID, "netbird.io", validatedPeers)
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, "netbird.io", validatedPeers)
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
}
@ -419,7 +420,7 @@ func TestNewAccount(t *testing.T) {
domain := "netbird.io"
userId := "account_creator"
accountID := "account_id"
account := newAccountWithId(accountID, userId, domain)
account := newAccountWithId(context.Background(), accountID, userId, domain)
verifyNewAccountHasDefaultFields(t, account, userId, domain, []string{userId})
}
@ -430,7 +431,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
return
}
account, err := manager.GetOrCreateAccountByUser(userID, "")
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
if err != nil {
t.Fatal(err)
}
@ -439,7 +440,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
return
}
account, err = manager.Store.GetAccountByUser(userID)
account, err = manager.Store.GetAccountByUser(context.Background(), userID)
if err != nil {
t.Errorf("expected to get existing account after creation, no account was found for a user %s", userID)
return
@ -630,11 +631,11 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
initAccount, err := manager.GetAccountByUserOrAccountID(testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
initAccount, err := manager.GetAccountByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
require.NoError(t, err, "create init user failed")
if testCase.inputUpdateAttrs {
err = manager.updateAccountDomainAttributes(initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
err = manager.updateAccountDomainAttributes(context.Background(), initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
require.NoError(t, err, "update init user failed")
}
@ -642,7 +643,7 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
testCase.inputClaims.AccountId = initAccount.Id
}
account, _, err := manager.GetAccountFromToken(testCase.inputClaims)
account, _, err := manager.GetAccountFromToken(context.Background(), testCase.inputClaims)
require.NoError(t, err, "support function failed")
verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers)
verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy)
@ -661,12 +662,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
userId := "user-id"
domain := "test.domain"
initAccount := newAccountWithId("", userId, domain)
initAccount := newAccountWithId(context.Background(), "", userId, domain)
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID := initAccount.Id
acc, err := manager.GetAccountByUserOrAccountID(userId, accountID, domain)
acc, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, accountID, domain)
require.NoError(t, err, "create init user failed")
// as initAccount was created without account id we have to take the id after account initialization
// that happens inside the GetAccountByUserOrAccountID where the id is getting generated
@ -682,18 +683,18 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
}
t.Run("JWT groups disabled", func(t *testing.T) {
account, _, err := manager.GetAccountFromToken(claims)
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed")
require.Len(t, account.Groups, 1, "only ALL group should exists")
})
t.Run("JWT groups enabled without claim name", func(t *testing.T) {
initAccount.Settings.JWTGroupsEnabled = true
err := manager.Store.SaveAccount(initAccount)
err := manager.Store.SaveAccount(context.Background(), initAccount)
require.NoError(t, err, "save account failed")
require.Len(t, manager.Store.GetAllAccounts(), 1, "only one account should exist")
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
account, _, err := manager.GetAccountFromToken(claims)
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed")
require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT")
})
@ -701,11 +702,11 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
t.Run("JWT groups enabled", func(t *testing.T) {
initAccount.Settings.JWTGroupsEnabled = true
initAccount.Settings.JWTGroupsClaimName = "idp-groups"
err := manager.Store.SaveAccount(initAccount)
err := manager.Store.SaveAccount(context.Background(), initAccount)
require.NoError(t, err, "save account failed")
require.Len(t, manager.Store.GetAllAccounts(), 1, "only one account should exist")
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
account, _, err := manager.GetAccountFromToken(claims)
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed")
require.Len(t, account.Groups, 3, "groups should be added to the account")
@ -728,7 +729,7 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
func TestAccountManager_GetAccountFromPAT(t *testing.T) {
store := newStore(t)
account := newAccountWithId("account_id", "testuser", "")
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
hashedToken := sha256.Sum256([]byte(token))
@ -742,7 +743,7 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) {
},
},
}
err := store.SaveAccount(account)
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
@ -751,7 +752,7 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) {
Store: store,
}
account, user, pat, err := am.GetAccountFromPAT(token)
account, user, pat, err := am.GetAccountFromPAT(context.Background(), token)
if err != nil {
t.Fatalf("Error when getting Account from PAT: %s", err)
}
@ -763,7 +764,7 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) {
func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
store := newStore(t)
account := newAccountWithId("account_id", "testuser", "")
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
hashedToken := sha256.Sum256([]byte(token))
@ -778,7 +779,7 @@ func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
},
},
}
err := store.SaveAccount(account)
err := store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatalf("Error when saving account: %s", err)
}
@ -787,12 +788,12 @@ func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
Store: store,
}
err = am.MarkPATUsed("tokenId")
err = am.MarkPATUsed(context.Background(), "tokenId")
if err != nil {
t.Fatalf("Error when marking PAT used: %s", err)
}
account, err = am.Store.GetAccount("account_id")
account, err = am.Store.GetAccount(context.Background(), "account_id")
if err != nil {
t.Fatalf("Error when getting account: %s", err)
}
@ -807,7 +808,7 @@ func TestAccountManager_PrivateAccount(t *testing.T) {
}
userId := "test_user"
account, err := manager.GetOrCreateAccountByUser(userId, "")
account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, "")
if err != nil {
t.Fatal(err)
}
@ -815,7 +816,7 @@ func TestAccountManager_PrivateAccount(t *testing.T) {
t.Fatalf("expected to create an account for a user %s", userId)
}
account, err = manager.Store.GetAccountByUser(userId)
account, err = manager.Store.GetAccountByUser(context.Background(), userId)
if err != nil {
t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId)
}
@ -834,7 +835,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
userId := "test_user"
domain := "hotmail.com"
account, err := manager.GetOrCreateAccountByUser(userId, domain)
account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, domain)
if err != nil {
t.Fatal(err)
}
@ -848,7 +849,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
domain = "gmail.com"
account, err = manager.GetOrCreateAccountByUser(userId, domain)
account, err = manager.GetOrCreateAccountByUser(context.Background(), userId, domain)
if err != nil {
t.Fatalf("got the following error while retrieving existing acc: %v", err)
}
@ -871,7 +872,7 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
userId := "test_user"
account, err := manager.GetAccountByUserOrAccountID(userId, "", "")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, "", "")
if err != nil {
t.Fatal(err)
}
@ -880,20 +881,20 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
return
}
_, err = manager.GetAccountByUserOrAccountID("", account.Id, "")
_, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
if err != nil {
t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", account.Id)
}
_, err = manager.GetAccountByUserOrAccountID("", "", "")
_, err = manager.GetAccountByUserOrAccountID(context.Background(), "", "", "")
if err == nil {
t.Errorf("expected an error when user and account IDs are empty")
}
}
func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*Account, error) {
account := newAccountWithId(accountID, userID, domain)
err := am.Store.SaveAccount(account)
account := newAccountWithId(context.Background(), accountID, userID, domain)
err := am.Store.SaveAccount(context.Background(), account)
if err != nil {
return nil, err
}
@ -915,7 +916,7 @@ func TestAccountManager_GetAccount(t *testing.T) {
}
// AddAccount has been already tested so we can assume it is correct and compare results
getAccount, err := manager.Store.GetAccount(account.Id)
getAccount, err := manager.Store.GetAccount(context.Background(), account.Id)
if err != nil {
t.Fatal(err)
return
@ -952,12 +953,12 @@ func TestAccountManager_DeleteAccount(t *testing.T) {
t.Fatal(err)
}
err = manager.DeleteAccount(account.Id, userId)
err = manager.DeleteAccount(context.Background(), account.Id, userId)
if err != nil {
t.Fatal(err)
}
getAccount, err := manager.Store.GetAccount(account.Id)
getAccount, err := manager.Store.GetAccount(context.Background(), account.Id)
if err == nil {
t.Fatal(fmt.Errorf("expected to get an error when trying to get deleted account, got %v", getAccount))
}
@ -978,7 +979,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
serial := account.Network.CurrentSerial() // should be 0
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
if err != nil {
t.Fatal("error creating setup key")
return
@ -997,7 +998,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
expectedPeerKey := key.PublicKey().String()
expectedSetupKey := setupKey.Key
peer, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
})
@ -1006,7 +1007,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
return
}
account, err = manager.Store.GetAccount(account.Id)
account, err = manager.Store.GetAccount(context.Background(), account.Id)
if err != nil {
t.Fatal(err)
return
@ -1045,7 +1046,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
return
}
account, err := manager.GetOrCreateAccountByUser(userID, "netbird.cloud")
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "netbird.cloud")
if err != nil {
t.Fatal(err)
}
@ -1065,7 +1066,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
expectedPeerKey := key.PublicKey().String()
expectedUserID := userID
peer, _, _, err := manager.AddPeer("", userID, &nbpeer.Peer{
peer, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
})
@ -1074,7 +1075,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
return
}
account, err = manager.Store.GetAccount(account.Id)
account, err = manager.Store.GetAccount(context.Background(), account.Id)
if err != nil {
t.Fatal(err)
return
@ -1121,7 +1122,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
t.Fatal(err)
}
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
if err != nil {
t.Fatal("error creating setup key")
return
@ -1140,7 +1141,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
}
expectedPeerKey := key.PublicKey().String()
peer, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
})
@ -1156,14 +1157,14 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
peer2 := getPeer()
peer3 := getPeer()
account, err = manager.Store.GetAccount(account.Id)
account, err = manager.Store.GetAccount(context.Background(), account.Id)
if err != nil {
t.Fatal(err)
return
}
updMsg := manager.peersUpdateManager.CreateChannel(peer1.ID)
defer manager.peersUpdateManager.CloseChannel(peer1.ID)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
group := group.Group{
ID: "group-id",
@ -1197,7 +1198,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
}
}()
if err := manager.SaveGroup(account.Id, userID, &group); err != nil {
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil {
t.Errorf("save group: %v", err)
return
}
@ -1217,7 +1218,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
}
}()
if err := manager.DeletePolicy(account.Id, account.Policies[0].ID, userID); err != nil {
if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
t.Errorf("delete default rule: %v", err)
return
}
@ -1237,7 +1238,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
}
}()
if err := manager.SavePolicy(account.Id, userID, &policy); err != nil {
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy); err != nil {
t.Errorf("delete default rule: %v", err)
return
}
@ -1256,7 +1257,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
}
}()
if err := manager.DeletePeer(account.Id, peer3.ID, userID); err != nil {
if err := manager.DeletePeer(context.Background(), account.Id, peer3.ID, userID); err != nil {
t.Errorf("delete peer: %v", err)
return
}
@ -1277,9 +1278,9 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
}()
// clean policy is pre requirement for delete group
_ = manager.DeletePolicy(account.Id, policy.ID, userID)
_ = manager.DeletePolicy(context.Background(), account.Id, policy.ID, userID)
if err := manager.DeleteGroup(account.Id, "", group.ID); err != nil {
if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil {
t.Errorf("delete group: %v", err)
return
}
@ -1301,7 +1302,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
t.Fatal(err)
}
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
if err != nil {
t.Fatal("error creating setup key")
return
@ -1315,7 +1316,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
peerKey := key.PublicKey().String()
peer, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
Key: peerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: peerKey},
})
@ -1324,12 +1325,12 @@ func TestAccountManager_DeletePeer(t *testing.T) {
return
}
err = manager.DeletePeer(account.Id, peerKey, userID)
err = manager.DeletePeer(context.Background(), account.Id, peerKey, userID)
if err != nil {
return
}
account, err = manager.Store.GetAccount(account.Id)
account, err = manager.Store.GetAccount(context.Background(), account.Id)
if err != nil {
t.Fatal(err)
return
@ -1357,7 +1358,7 @@ func getEvent(t *testing.T, accountID string, manager AccountManager, eventType
case <-time.After(time.Second):
t.Fatal("no PeerAddedWithSetupKey event was generated")
default:
events, err := manager.GetEvents(accountID, userID)
events, err := manager.GetEvents(context.Background(), accountID, userID)
if err != nil {
t.Fatal(err)
}
@ -1389,7 +1390,7 @@ func TestGetUsersFromAccount(t *testing.T) {
account.Users[user.Id] = user
}
userInfos, err := manager.GetUsersFromAccount(accountId, "1")
userInfos, err := manager.GetUsersFromAccount(context.Background(), accountId, "1")
if err != nil {
t.Fatal(err)
}
@ -1500,7 +1501,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
},
}
routes := account.getRoutesToSync("peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}})
routes := account.getRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}})
assert.Len(t, routes, 2)
routeIDs := make(map[route.ID]struct{}, 2)
@ -1510,7 +1511,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
assert.Contains(t, routeIDs, route.ID("route-2"))
assert.Contains(t, routeIDs, route.ID("route-3"))
emptyRoutes := account.getRoutesToSync("peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}})
emptyRoutes := account.getRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}})
assert.Len(t, emptyRoutes, 0)
}
@ -1645,7 +1646,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account")
assert.NotNil(t, account.Settings)
@ -1657,23 +1658,23 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
_, err = manager.GetAccountByUserOrAccountID(userID, "", "")
_, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
peer, _, _, err := manager.AddPeer("", userID, &nbpeer.Peer{
peer, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
Key: key.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
LoginExpirationEnabled: true,
})
require.NoError(t, err, "unable to add peer")
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account)
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected")
account, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
})
@ -1682,10 +1683,10 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
wg := &sync.WaitGroup{}
wg.Add(2)
manager.peerLoginExpiry = &MockScheduler{
CancelFunc: func(IDs []string) {
CancelFunc: func(ctx context.Context, IDs []string) {
wg.Done()
},
ScheduleFunc: func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
wg.Done()
},
}
@ -1693,11 +1694,11 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
// disable expiration first
update := peer.Copy()
update.LoginExpirationEnabled = false
_, err = manager.UpdatePeer(account.Id, userID, update)
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, update)
require.NoError(t, err, "unable to update peer")
// enabling expiration should trigger the routine
update.LoginExpirationEnabled = true
_, err = manager.UpdatePeer(account.Id, userID, update)
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, update)
require.NoError(t, err, "unable to update peer")
failed := waitTimeout(wg, time.Second)
@ -1710,18 +1711,18 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
_, _, _, err = manager.AddPeer("", userID, &nbpeer.Peer{
_, _, _, err = manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
Key: key.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
LoginExpirationEnabled: true,
})
require.NoError(t, err, "unable to add peer")
_, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
})
@ -1730,18 +1731,18 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
wg := &sync.WaitGroup{}
wg.Add(2)
manager.peerLoginExpiry = &MockScheduler{
CancelFunc: func(IDs []string) {
CancelFunc: func(ctx context.Context, IDs []string) {
wg.Done()
},
ScheduleFunc: func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
wg.Done()
},
}
account, err = manager.GetAccountByUserOrAccountID(userID, "", "")
account, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to get the account")
// when we mark peer as connected, the peer login expiration routine should trigger
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account)
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected")
failed := waitTimeout(wg, time.Second)
@ -1754,35 +1755,35 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
_, err = manager.GetAccountByUserOrAccountID(userID, "", "")
_, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
_, _, _, err = manager.AddPeer("", userID, &nbpeer.Peer{
_, _, _, err = manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
Key: key.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
LoginExpirationEnabled: true,
})
require.NoError(t, err, "unable to add peer")
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account)
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected")
wg := &sync.WaitGroup{}
wg.Add(2)
manager.peerLoginExpiry = &MockScheduler{
CancelFunc: func(IDs []string) {
CancelFunc: func(ctx context.Context, IDs []string) {
wg.Done()
},
ScheduleFunc: func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
wg.Done()
},
}
// enabling PeerLoginExpirationEnabled should trigger the expiration job
account, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
})
@ -1795,7 +1796,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
wg.Add(1)
// disabling PeerLoginExpirationEnabled should trigger cancel
_, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: false,
})
@ -1810,10 +1811,10 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account")
updated, err := manager.UpdateAccountSettings(account.Id, userID, &Settings{
updated, err := manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: false,
})
@ -1821,19 +1822,19 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)
account, err = manager.GetAccountByUserOrAccountID("", account.Id, "")
account, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
require.NoError(t, err, "unable to get account by ID")
assert.False(t, account.Settings.PeerLoginExpirationEnabled)
assert.Equal(t, account.Settings.PeerLoginExpiration, time.Hour)
_, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
PeerLoginExpiration: time.Second,
PeerLoginExpirationEnabled: false,
})
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour")
_, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
PeerLoginExpiration: time.Hour * 24 * 181,
PeerLoginExpirationEnabled: false,
})
@ -2294,7 +2295,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) {
}
eventStore := &activity.InMemoryEventStore{}
manager, err := BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{})
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{})
if err != nil {
return nil, err
}
@ -2305,7 +2306,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) {
func createStore(t *testing.T) (Store, error) {
t.Helper()
dataDir := t.TempDir()
store, cleanUp, err := NewTestStoreFromJson(dataDir)
store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir)
if err != nil {
return nil, err
}

View File

@ -1,6 +1,7 @@
package sqlite
import (
"context"
"database/sql"
"encoding/json"
"fmt"
@ -86,7 +87,7 @@ type Store struct {
}
// NewSQLiteStore creates a new Store with an event table if not exists.
func NewSQLiteStore(dataDir string, encryptionKey string) (*Store, error) {
func NewSQLiteStore(ctx context.Context, dataDir string, encryptionKey string) (*Store, error) {
dbFile := filepath.Join(dataDir, eventSinkDB)
db, err := sql.Open("sqlite3", dbFile)
if err != nil {
@ -111,7 +112,7 @@ func NewSQLiteStore(dataDir string, encryptionKey string) (*Store, error) {
return nil, err
}
err = updateDeletedUsersTable(db)
err = updateDeletedUsersTable(ctx, db)
if err != nil {
_ = db.Close()
return nil, err
@ -153,7 +154,7 @@ func NewSQLiteStore(dataDir string, encryptionKey string) (*Store, error) {
return s, nil
}
func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) {
func (store *Store) processResult(ctx context.Context, result *sql.Rows) ([]*activity.Event, error) {
events := make([]*activity.Event, 0)
var cryptErr error
for result.Next() {
@ -235,14 +236,14 @@ func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) {
}
if cryptErr != nil {
log.Warnf("%s", cryptErr)
log.WithContext(ctx).Warnf("%s", cryptErr)
}
return events, nil
}
// Get returns "limit" number of events from index ordered descending or ascending by a timestamp
func (store *Store) Get(accountID string, offset, limit int, descending bool) ([]*activity.Event, error) {
func (store *Store) Get(ctx context.Context, accountID string, offset, limit int, descending bool) ([]*activity.Event, error) {
stmt := store.selectDescStatement
if !descending {
stmt = store.selectAscStatement
@ -254,11 +255,11 @@ func (store *Store) Get(accountID string, offset, limit int, descending bool) ([
}
defer result.Close() //nolint
return store.processResult(result)
return store.processResult(ctx, result)
}
// Save an event in the SQLite events table end encrypt the "email" element in meta map
func (store *Store) Save(event *activity.Event) (*activity.Event, error) {
func (store *Store) Save(_ context.Context, event *activity.Event) (*activity.Event, error) {
var jsonMeta string
meta, err := store.saveDeletedUserEmailAndNameInEncrypted(event)
if err != nil {
@ -317,15 +318,15 @@ func (store *Store) saveDeletedUserEmailAndNameInEncrypted(event *activity.Event
}
// Close the Store
func (store *Store) Close() error {
func (store *Store) Close(_ context.Context) error {
if store.db != nil {
return store.db.Close()
}
return nil
}
func updateDeletedUsersTable(db *sql.DB) error {
log.Debugf("check deleted_users table version")
func updateDeletedUsersTable(ctx context.Context, db *sql.DB) error {
log.WithContext(ctx).Debugf("check deleted_users table version")
rows, err := db.Query(`PRAGMA table_info(deleted_users);`)
if err != nil {
return err
@ -360,7 +361,7 @@ func updateDeletedUsersTable(db *sql.DB) error {
return nil
}
log.Debugf("update delted_users table")
log.WithContext(ctx).Debugf("update delted_users table")
_, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN name TEXT;`)
return err
}

View File

@ -1,6 +1,7 @@
package sqlite
import (
"context"
"fmt"
"testing"
"time"
@ -13,17 +14,17 @@ import (
func TestNewSQLiteStore(t *testing.T) {
dataDir := t.TempDir()
key, _ := GenerateKey()
store, err := NewSQLiteStore(dataDir, key)
store, err := NewSQLiteStore(context.Background(), dataDir, key)
if err != nil {
t.Fatal(err)
return
}
defer store.Close() //nolint
defer store.Close(context.Background()) //nolint
accountID := "account_1"
for i := 0; i < 10; i++ {
_, err = store.Save(&activity.Event{
_, err = store.Save(context.Background(), &activity.Event{
Timestamp: time.Now().UTC(),
Activity: activity.PeerAddedByUser,
InitiatorID: "user_" + fmt.Sprint(i),
@ -36,7 +37,7 @@ func TestNewSQLiteStore(t *testing.T) {
}
}
result, err := store.Get(accountID, 0, 10, false)
result, err := store.Get(context.Background(), accountID, 0, 10, false)
if err != nil {
t.Fatal(err)
return
@ -45,7 +46,7 @@ func TestNewSQLiteStore(t *testing.T) {
assert.Len(t, result, 10)
assert.True(t, result[0].Timestamp.Before(result[len(result)-1].Timestamp))
result, err = store.Get(accountID, 0, 5, true)
result, err = store.Get(context.Background(), accountID, 0, 5, true)
if err != nil {
t.Fatal(err)
return

View File

@ -1,15 +1,18 @@
package activity
import "sync"
import (
"context"
"sync"
)
// Store provides an interface to store or stream events.
type Store interface {
// Save an event in the store
Save(event *Event) (*Event, error)
Save(ctx context.Context, event *Event) (*Event, error)
// Get returns "limit" number of events from the "offset" index ordered descending or ascending by a timestamp
Get(accountID string, offset, limit int, descending bool) ([]*Event, error)
Get(ctx context.Context, accountID string, offset, limit int, descending bool) ([]*Event, error)
// Close the sink flushing events if necessary
Close() error
Close(ctx context.Context) error
}
// InMemoryEventStore implements the Store interface storing data in-memory
@ -20,7 +23,7 @@ type InMemoryEventStore struct {
}
// Save sets the Event.ID to 1
func (store *InMemoryEventStore) Save(event *Event) (*Event, error) {
func (store *InMemoryEventStore) Save(_ context.Context, event *Event) (*Event, error) {
store.mu.Lock()
defer store.mu.Unlock()
if store.events == nil {
@ -33,7 +36,7 @@ func (store *InMemoryEventStore) Save(event *Event) (*Event, error) {
}
// Get returns a list of ALL events that belong to the given accountID without taking offset, limit and order into consideration
func (store *InMemoryEventStore) Get(accountID string, offset, limit int, descending bool) ([]*Event, error) {
func (store *InMemoryEventStore) Get(_ context.Context, accountID string, offset, limit int, descending bool) ([]*Event, error) {
store.mu.Lock()
defer store.mu.Unlock()
events := make([]*Event, 0)
@ -46,7 +49,7 @@ func (store *InMemoryEventStore) Get(accountID string, offset, limit int, descen
}
// Close cleans up the event list
func (store *InMemoryEventStore) Close() error {
func (store *InMemoryEventStore) Close(_ context.Context) error {
store.mu.Lock()
defer store.mu.Unlock()
store.events = make([]*Event, 0)

View File

@ -0,0 +1,8 @@
package context
const (
RequestIDKey = "requestID"
AccountIDKey = "accountID"
UserIDKey = "userID"
PeerIDKey = "peerID"
)

View File

@ -1,6 +1,7 @@
package server
import (
"context"
"fmt"
"strconv"
@ -34,11 +35,11 @@ func (d DNSSettings) Copy() DNSSettings {
}
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
func (am *DefaultAccountManager) GetDNSSettings(accountID string, userID string) (*DNSSettings, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
@ -56,11 +57,11 @@ func (am *DefaultAccountManager) GetDNSSettings(accountID string, userID string)
}
// SaveDNSSettings validates a user role and updates the account's DNS settings
func (am *DefaultAccountManager) SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
unlock := am.Store.AcquireAccountWriteLock(accountID)
func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
@ -89,7 +90,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(accountID string, userID string
account.DNSSettings = dnsSettingsToSave.Copy()
account.Network.IncSerial()
if err = am.Store.SaveAccount(account); err != nil {
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
@ -97,17 +98,17 @@ func (am *DefaultAccountManager) SaveDNSSettings(accountID string, userID string
for _, id := range addedGroups {
group := account.GetGroup(id)
meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
}
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
for _, id := range removedGroups {
group := account.GetGroup(id)
meta := map[string]any{"group": group.Name, "group_id": group.ID}
am.StoreEvent(userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
}
am.updateAccountPeers(account)
am.updateAccountPeers(ctx, account)
return nil
}
@ -149,9 +150,9 @@ func toProtocolDNSConfig(update nbdns.Config) *proto.DNSConfig {
return protoUpdate
}
func getPeersCustomZone(account *Account, dnsDomain string) nbdns.CustomZone {
func getPeersCustomZone(ctx context.Context, account *Account, dnsDomain string) nbdns.CustomZone {
if dnsDomain == "" {
log.Errorf("no dns domain is set, returning empty zone")
log.WithContext(ctx).Errorf("no dns domain is set, returning empty zone")
return nbdns.CustomZone{}
}
@ -161,7 +162,7 @@ func getPeersCustomZone(account *Account, dnsDomain string) nbdns.CustomZone {
for _, peer := range account.Peers {
if peer.DNSLabel == "" {
log.Errorf("found a peer with empty dns label. It was probably caused by a invalid character in its name. Peer Name: %s", peer.Name)
log.WithContext(ctx).Errorf("found a peer with empty dns label. It was probably caused by a invalid character in its name. Peer Name: %s", peer.Name)
continue
}
@ -210,14 +211,14 @@ func peerIsNameserver(peer *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool {
return false
}
func addPeerLabelsToAccount(account *Account, peerLabels lookupMap) {
func addPeerLabelsToAccount(ctx context.Context, account *Account, peerLabels lookupMap) {
for _, peer := range account.Peers {
label, err := getPeerHostLabel(peer.Name, peerLabels)
if err != nil {
log.Errorf("got an error while generating a peer host label. Peer name %s, error: %v. Trying with the peer's meta hostname", peer.Name, err)
log.WithContext(ctx).Errorf("got an error while generating a peer host label. Peer name %s, error: %v. Trying with the peer's meta hostname", peer.Name, err)
label, err = getPeerHostLabel(peer.Meta.Hostname, peerLabels)
if err != nil {
log.Errorf("got another error while generating a peer host label with hostname. Peer hostname %s, error: %v. Skipping", peer.Meta.Hostname, err)
log.WithContext(ctx).Errorf("got another error while generating a peer host label with hostname. Peer hostname %s, error: %v. Skipping", peer.Meta.Hostname, err)
continue
}
}

View File

@ -1,6 +1,7 @@
package server
import (
"context"
"net/netip"
"testing"
@ -35,7 +36,7 @@ func TestGetDNSSettings(t *testing.T) {
t.Fatal("failed to init testing account")
}
dnsSettings, err := am.GetDNSSettings(account.Id, dnsAdminUserID)
dnsSettings, err := am.GetDNSSettings(context.Background(), account.Id, dnsAdminUserID)
if err != nil {
t.Fatalf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err)
}
@ -48,12 +49,12 @@ func TestGetDNSSettings(t *testing.T) {
DisabledManagementGroups: []string{group1ID},
}
err = am.Store.SaveAccount(account)
err = am.Store.SaveAccount(context.Background(), account)
if err != nil {
t.Error("failed to save testing account with new DNS settings")
}
dnsSettings, err = am.GetDNSSettings(account.Id, dnsAdminUserID)
dnsSettings, err = am.GetDNSSettings(context.Background(), account.Id, dnsAdminUserID)
if err != nil {
t.Errorf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err)
}
@ -62,7 +63,7 @@ func TestGetDNSSettings(t *testing.T) {
t.Errorf("DNS settings should have one disabled mgmt group, groups: %s", dnsSettings.DisabledManagementGroups)
}
_, err = am.GetDNSSettings(account.Id, dnsRegularUserID)
_, err = am.GetDNSSettings(context.Background(), account.Id, dnsRegularUserID)
if err == nil {
t.Errorf("An error should be returned when getting the DNS settings with a regular user")
}
@ -122,7 +123,7 @@ func TestSaveDNSSettings(t *testing.T) {
t.Error("failed to init testing account")
}
err = am.SaveDNSSettings(account.Id, testCase.userID, testCase.inputSettings)
err = am.SaveDNSSettings(context.Background(), account.Id, testCase.userID, testCase.inputSettings)
if err != nil {
if testCase.shouldFail {
return
@ -130,7 +131,7 @@ func TestSaveDNSSettings(t *testing.T) {
t.Error(err)
}
updatedAccount, err := am.Store.GetAccount(account.Id)
updatedAccount, err := am.Store.GetAccount(context.Background(), account.Id)
if err != nil {
t.Errorf("should be able to retrieve updated account, got err: %s", err)
}
@ -164,7 +165,7 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
t.Error("failed to init testing account")
}
newAccountDNSConfig, err := am.GetNetworkMap(peer1.ID)
newAccountDNSConfig, err := am.GetNetworkMap(context.Background(), peer1.ID)
require.NoError(t, err)
require.Len(t, newAccountDNSConfig.DNSConfig.CustomZones, 1, "default DNS config should have one custom zone for peers")
require.True(t, newAccountDNSConfig.DNSConfig.ServiceEnable, "default DNS config should have local DNS service enabled")
@ -173,14 +174,14 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
dnsSettings := account.DNSSettings.Copy()
dnsSettings.DisabledManagementGroups = append(dnsSettings.DisabledManagementGroups, dnsGroup1ID)
account.DNSSettings = dnsSettings
err = am.Store.SaveAccount(account)
err = am.Store.SaveAccount(context.Background(), account)
require.NoError(t, err)
updatedAccountDNSConfig, err := am.GetNetworkMap(peer1.ID)
updatedAccountDNSConfig, err := am.GetNetworkMap(context.Background(), peer1.ID)
require.NoError(t, err)
require.Len(t, updatedAccountDNSConfig.DNSConfig.CustomZones, 0, "updated DNS config should have no custom zone when peer belongs to a disabled group")
require.False(t, updatedAccountDNSConfig.DNSConfig.ServiceEnable, "updated DNS config should have local DNS service disabled when peer belongs to a disabled group")
peer2AccountDNSConfig, err := am.GetNetworkMap(peer2.ID)
peer2AccountDNSConfig, err := am.GetNetworkMap(context.Background(), peer2.ID)
require.NoError(t, err)
require.Len(t, peer2AccountDNSConfig.DNSConfig.CustomZones, 1, "DNS config should have one custom zone for peers not in the disabled group")
require.True(t, peer2AccountDNSConfig.DNSConfig.ServiceEnable, "DNS config should have DNS service enabled for peers not in the disabled group")
@ -194,13 +195,13 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
return nil, err
}
eventStore := &activity.InMemoryEventStore{}
return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{})
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{})
}
func createDNSStore(t *testing.T) (Store, error) {
t.Helper()
dataDir := t.TempDir()
store, cleanUp, err := NewTestStoreFromJson(dataDir)
store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir)
if err != nil {
return nil, err
}
@ -244,28 +245,28 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
domain := "example.com"
account := newAccountWithId(dnsAccountID, dnsAdminUserID, domain)
account := newAccountWithId(context.Background(), dnsAccountID, dnsAdminUserID, domain)
account.Users[dnsRegularUserID] = &User{
Id: dnsRegularUserID,
Role: UserRoleUser,
}
err := am.Store.SaveAccount(account)
err := am.Store.SaveAccount(context.Background(), account)
if err != nil {
return nil, err
}
savedPeer1, _, _, err := am.AddPeer("", dnsAdminUserID, peer1)
savedPeer1, _, _, err := am.AddPeer(context.Background(), "", dnsAdminUserID, peer1)
if err != nil {
return nil, err
}
_, _, _, err = am.AddPeer("", dnsAdminUserID, peer2)
_, _, _, err = am.AddPeer(context.Background(), "", dnsAdminUserID, peer2)
if err != nil {
return nil, err
}
account, err = am.Store.GetAccount(account.Id)
account, err = am.Store.GetAccount(context.Background(), account.Id)
if err != nil {
return nil, err
}
@ -312,10 +313,10 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
Groups: []string{allGroup.ID},
}
err = am.Store.SaveAccount(account)
err = am.Store.SaveAccount(context.Background(), account)
if err != nil {
return nil, err
}
return am.Store.GetAccount(account.Id)
return am.Store.GetAccount(context.Background(), account.Id)
}

View File

@ -1,6 +1,7 @@
package server
import (
"context"
"sync"
"time"
@ -51,13 +52,15 @@ func NewEphemeralManager(store Store, accountManager AccountManager) *EphemeralM
// LoadInitialPeers load from the database the ephemeral type of peers and schedule a cleanup procedure to the head
// of the linked list (to the most deprecated peer). At the end of cleanup it schedules the next cleanup to the new
// head.
func (e *EphemeralManager) LoadInitialPeers() {
func (e *EphemeralManager) LoadInitialPeers(ctx context.Context) {
e.peersLock.Lock()
defer e.peersLock.Unlock()
e.loadEphemeralPeers()
e.loadEphemeralPeers(ctx)
if e.headPeer != nil {
e.timer = time.AfterFunc(ephemeralLifeTime, e.cleanup)
e.timer = time.AfterFunc(ephemeralLifeTime, func() {
e.cleanup(ctx)
})
}
}
@ -73,12 +76,12 @@ func (e *EphemeralManager) Stop() {
// OnPeerConnected remove the peer from the linked list of ephemeral peers. Because it has been called when the peer
// is active the manager will not delete it while it is active.
func (e *EphemeralManager) OnPeerConnected(peer *nbpeer.Peer) {
func (e *EphemeralManager) OnPeerConnected(ctx context.Context, peer *nbpeer.Peer) {
if !peer.Ephemeral {
return
}
log.Tracef("remove peer from ephemeral list: %s", peer.ID)
log.WithContext(ctx).Tracef("remove peer from ephemeral list: %s", peer.ID)
e.peersLock.Lock()
defer e.peersLock.Unlock()
@ -94,16 +97,16 @@ func (e *EphemeralManager) OnPeerConnected(peer *nbpeer.Peer) {
// OnPeerDisconnected add the peer to the linked list of ephemeral peers. Because of the peer
// is inactive it will be deleted after the ephemeralLifeTime period.
func (e *EphemeralManager) OnPeerDisconnected(peer *nbpeer.Peer) {
func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.Peer) {
if !peer.Ephemeral {
return
}
log.Tracef("add peer to ephemeral list: %s", peer.ID)
log.WithContext(ctx).Tracef("add peer to ephemeral list: %s", peer.ID)
a, err := e.store.GetAccountByPeerID(peer.ID)
a, err := e.store.GetAccountByPeerID(context.Background(), peer.ID)
if err != nil {
log.Errorf("failed to add peer to ephemeral list: %s", err)
log.WithContext(ctx).Errorf("failed to add peer to ephemeral list: %s", err)
return
}
@ -116,12 +119,14 @@ func (e *EphemeralManager) OnPeerDisconnected(peer *nbpeer.Peer) {
e.addPeer(peer.ID, a, newDeadLine())
if e.timer == nil {
e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), e.cleanup)
e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), func() {
e.cleanup(ctx)
})
}
}
func (e *EphemeralManager) loadEphemeralPeers() {
accounts := e.store.GetAllAccounts()
func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
accounts := e.store.GetAllAccounts(context.Background())
t := newDeadLine()
count := 0
for _, a := range accounts {
@ -132,10 +137,10 @@ func (e *EphemeralManager) loadEphemeralPeers() {
}
}
}
log.Debugf("loaded ephemeral peer(s): %d", count)
log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", count)
}
func (e *EphemeralManager) cleanup() {
func (e *EphemeralManager) cleanup(ctx context.Context) {
log.Tracef("on ephemeral cleanup")
deletePeers := make(map[string]*ephemeralPeer)
@ -154,7 +159,9 @@ func (e *EphemeralManager) cleanup() {
}
if e.headPeer != nil {
e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), e.cleanup)
e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), func() {
e.cleanup(ctx)
})
} else {
e.timer = nil
}
@ -162,10 +169,10 @@ func (e *EphemeralManager) cleanup() {
e.peersLock.Unlock()
for id, p := range deletePeers {
log.Debugf("delete ephemeral peer: %s", id)
err := e.accountManager.DeletePeer(p.account.Id, id, activity.SystemInitiator)
log.WithContext(ctx).Debugf("delete ephemeral peer: %s", id)
err := e.accountManager.DeletePeer(ctx, p.account.Id, id, activity.SystemInitiator)
if err != nil {
log.Errorf("failed to delete ephemeral peer: %s", err)
log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err)
}
}
}

View File

@ -1,6 +1,7 @@
package server
import (
"context"
"fmt"
"testing"
"time"
@ -13,11 +14,11 @@ type MockStore struct {
account *Account
}
func (s *MockStore) GetAllAccounts() []*Account {
func (s *MockStore) GetAllAccounts(_ context.Context) []*Account {
return []*Account{s.account}
}
func (s *MockStore) GetAccountByPeerID(peerId string) (*Account, error) {
func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*Account, error) {
_, ok := s.account.Peers[peerId]
if ok {
return s.account, nil
@ -31,7 +32,7 @@ type MocAccountManager struct {
store *MockStore
}
func (a MocAccountManager) DeletePeer(accountID, peerID, userID string) error {
func (a MocAccountManager) DeletePeer(_ context.Context, accountID, peerID, userID string) error {
delete(a.store.account.Peers, peerID)
return nil //nolint:nil
}
@ -52,9 +53,9 @@ func TestNewManager(t *testing.T) {
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
mgr := NewEphemeralManager(store, am)
mgr.loadEphemeralPeers()
mgr.loadEphemeralPeers(context.Background())
startTime = startTime.Add(ephemeralLifeTime + 1)
mgr.cleanup()
mgr.cleanup(context.Background())
if len(store.account.Peers) != numberOfPeers {
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", numberOfPeers, len(store.account.Peers))
@ -77,11 +78,11 @@ func TestNewManagerPeerConnected(t *testing.T) {
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
mgr := NewEphemeralManager(store, am)
mgr.loadEphemeralPeers()
mgr.OnPeerConnected(store.account.Peers["ephemeral_peer_0"])
mgr.loadEphemeralPeers(context.Background())
mgr.OnPeerConnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
startTime = startTime.Add(ephemeralLifeTime + 1)
mgr.cleanup()
mgr.cleanup(context.Background())
expected := numberOfPeers + 1
if len(store.account.Peers) != expected {
@ -105,15 +106,15 @@ func TestNewManagerPeerDisconnected(t *testing.T) {
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
mgr := NewEphemeralManager(store, am)
mgr.loadEphemeralPeers()
mgr.loadEphemeralPeers(context.Background())
for _, v := range store.account.Peers {
mgr.OnPeerConnected(v)
mgr.OnPeerConnected(context.Background(), v)
}
mgr.OnPeerDisconnected(store.account.Peers["ephemeral_peer_0"])
mgr.OnPeerDisconnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
startTime = startTime.Add(ephemeralLifeTime + 1)
mgr.cleanup()
mgr.cleanup(context.Background())
expected := numberOfPeers + numberOfEphemeralPeers - 1
if len(store.account.Peers) != expected {
@ -122,7 +123,7 @@ func TestNewManagerPeerDisconnected(t *testing.T) {
}
func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) {
store.account = newAccountWithId("my account", "", "")
store.account = newAccountWithId(context.Background(), "my account", "", "")
for i := 0; i < numberOfPeers; i++ {
peerId := fmt.Sprintf("peer_%d", i)

View File

@ -1,6 +1,7 @@
package server
import (
"context"
"fmt"
"time"
@ -11,11 +12,11 @@ import (
)
// GetEvents returns a list of activity events of an account
func (am *DefaultAccountManager) GetEvents(accountID, userID string) ([]*activity.Event, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
@ -29,7 +30,7 @@ func (am *DefaultAccountManager) GetEvents(accountID, userID string) ([]*activit
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view events")
}
events, err := am.eventStore.Get(accountID, 0, 10000, true)
events, err := am.eventStore.Get(ctx, accountID, 0, 10000, true)
if err != nil {
return nil, err
}
@ -54,10 +55,10 @@ func (am *DefaultAccountManager) GetEvents(accountID, userID string) ([]*activit
return filtered, nil
}
func (am *DefaultAccountManager) StoreEvent(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
func (am *DefaultAccountManager) StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
go func() {
_, err := am.eventStore.Save(&activity.Event{
_, err := am.eventStore.Save(ctx, &activity.Event{
Timestamp: time.Now().UTC(),
Activity: activityID,
InitiatorID: initiatorID,
@ -67,7 +68,7 @@ func (am *DefaultAccountManager) StoreEvent(initiatorID, targetID, accountID str
})
if err != nil {
// todo add metric
log.Errorf("received an error while storing an activity event, error: %s", err)
log.WithContext(ctx).Errorf("received an error while storing an activity event, error: %s", err)
}
}()

View File

@ -1,6 +1,7 @@
package server
import (
"context"
"testing"
"time"
@ -13,7 +14,7 @@ func generateAndStoreEvents(t *testing.T, manager *DefaultAccountManager, typ ac
accountID string, count int) {
t.Helper()
for i := 0; i < count; i++ {
_, err := manager.eventStore.Save(&activity.Event{
_, err := manager.eventStore.Save(context.Background(), &activity.Event{
Timestamp: time.Now().UTC(),
Activity: typ,
InitiatorID: initiatorID,
@ -35,32 +36,32 @@ func TestDefaultAccountManager_GetEvents(t *testing.T) {
accountID := "accountID"
t.Run("get empty events list", func(t *testing.T) {
events, err := manager.GetEvents(accountID, userID)
events, err := manager.GetEvents(context.Background(), accountID, userID)
if err != nil {
return
}
assert.Len(t, events, 0)
_ = manager.eventStore.Close() //nolint
_ = manager.eventStore.Close(context.Background()) //nolint
})
t.Run("get events", func(t *testing.T) {
generateAndStoreEvents(t, manager, activity.PeerAddedByUser, userID, "peer", accountID, 10)
events, err := manager.GetEvents(accountID, userID)
events, err := manager.GetEvents(context.Background(), accountID, userID)
if err != nil {
return
}
assert.Len(t, events, 10)
_ = manager.eventStore.Close() //nolint
_ = manager.eventStore.Close(context.Background()) //nolint
})
t.Run("get events without duplicates", func(t *testing.T) {
generateAndStoreEvents(t, manager, activity.UserJoined, userID, "", accountID, 10)
events, err := manager.GetEvents(accountID, userID)
events, err := manager.GetEvents(context.Background(), accountID, userID)
if err != nil {
return
}
assert.Len(t, events, 1)
_ = manager.eventStore.Close() //nolint
_ = manager.eventStore.Close(context.Background()) //nolint
})
}

View File

@ -1,6 +1,7 @@
package server
import (
"context"
"os"
"path/filepath"
"strings"
@ -48,8 +49,8 @@ type FileStore struct {
type StoredAccount struct{}
// NewFileStore restores a store from the file located in the datadir
func NewFileStore(dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) {
fs, err := restore(filepath.Join(dataDir, storeFileName))
func NewFileStore(ctx context.Context, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) {
fs, err := restore(ctx, filepath.Join(dataDir, storeFileName))
if err != nil {
return nil, err
}
@ -58,27 +59,27 @@ func NewFileStore(dataDir string, metrics telemetry.AppMetrics) (*FileStore, err
}
// NewFilestoreFromSqliteStore restores a store from Sqlite and stores to Filestore json in the file located in datadir
func NewFilestoreFromSqliteStore(sqlStore *SqlStore, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) {
store, err := NewFileStore(dataDir, metrics)
func NewFilestoreFromSqliteStore(ctx context.Context, sqlStore *SqlStore, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) {
store, err := NewFileStore(ctx, dataDir, metrics)
if err != nil {
return nil, err
}
err = store.SaveInstallationID(sqlStore.GetInstallationID())
err = store.SaveInstallationID(ctx, sqlStore.GetInstallationID())
if err != nil {
return nil, err
}
for _, account := range sqlStore.GetAllAccounts() {
for _, account := range sqlStore.GetAllAccounts(ctx) {
store.Accounts[account.Id] = account
}
return store, store.persist(store.storeFile)
return store, store.persist(ctx, store.storeFile)
}
// restore the state of the store from the file.
// Creates a new empty store file if doesn't exist
func restore(file string) (*FileStore, error) {
func restore(ctx context.Context, file string) (*FileStore, error) {
if _, err := os.Stat(file); os.IsNotExist(err) {
// create a new FileStore if previously didn't exist (e.g. first run)
s := &FileStore{
@ -95,7 +96,7 @@ func restore(file string) (*FileStore, error) {
storeFile: file,
}
err = s.persist(file)
err = s.persist(ctx, file)
if err != nil {
return nil, err
}
@ -165,7 +166,7 @@ func restore(file string) (*FileStore, error) {
// for data migration. Can be removed once most base will be with labels
existingLabels := account.getPeerDNSLabels()
if len(existingLabels) != len(account.Peers) {
addPeerLabelsToAccount(account, existingLabels)
addPeerLabelsToAccount(ctx, account, existingLabels)
}
// TODO: delete this block after migration
@ -178,7 +179,7 @@ func restore(file string) (*FileStore, error) {
allGroup, err := account.GetGroupAll()
if err != nil {
log.Errorf("unable to find the All group, this should happen only when migrate from a version that didn't support groups. Error: %v", err)
log.WithContext(ctx).Errorf("unable to find the All group, this should happen only when migrate from a version that didn't support groups. Error: %v", err)
// if the All group didn't exist we probably don't have routes to update
continue
}
@ -236,7 +237,7 @@ func restore(file string) (*FileStore, error) {
}
// we need this persist to apply changes we made to account.Peers (we set them to Disconnected)
err = store.persist(store.storeFile)
err = store.persist(ctx, store.storeFile)
if err != nil {
return nil, err
}
@ -246,7 +247,7 @@ func restore(file string) (*FileStore, error) {
// persist account data to a file
// It is recommended to call it with locking FileStore.mux
func (s *FileStore) persist(file string) error {
func (s *FileStore) persist(ctx context.Context, file string) error {
start := time.Now()
err := util.WriteJson(file, s)
if err != nil {
@ -256,23 +257,23 @@ func (s *FileStore) persist(file string) error {
if s.metrics != nil {
s.metrics.StoreMetrics().CountPersistenceDuration(took)
}
log.Debugf("took %d ms to persist the FileStore", took.Milliseconds())
log.WithContext(ctx).Debugf("took %d ms to persist the FileStore", took.Milliseconds())
return nil
}
// AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock
func (s *FileStore) AcquireGlobalLock() (unlock func()) {
log.Debugf("acquiring global lock")
func (s *FileStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
log.WithContext(ctx).Debugf("acquiring global lock")
start := time.Now()
s.globalAccountLock.Lock()
unlock = func() {
s.globalAccountLock.Unlock()
log.Debugf("released global lock in %v", time.Since(start))
log.WithContext(ctx).Debugf("released global lock in %v", time.Since(start))
}
took := time.Since(start)
log.Debugf("took %v to acquire global lock", took)
log.WithContext(ctx).Debugf("took %v to acquire global lock", took)
if s.metrics != nil {
s.metrics.StoreMetrics().CountGlobalLockAcquisitionDuration(took)
}
@ -281,8 +282,8 @@ func (s *FileStore) AcquireGlobalLock() (unlock func()) {
}
// AcquireAccountWriteLock acquires account lock for writing to a resource and returns a function that releases the lock
func (s *FileStore) AcquireAccountWriteLock(accountID string) (unlock func()) {
log.Debugf("acquiring lock for account %s", accountID)
func (s *FileStore) AcquireAccountWriteLock(ctx context.Context, accountID string) (unlock func()) {
log.WithContext(ctx).Debugf("acquiring lock for account %s", accountID)
start := time.Now()
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{})
mtx := value.(*sync.Mutex)
@ -290,7 +291,7 @@ func (s *FileStore) AcquireAccountWriteLock(accountID string) (unlock func()) {
unlock = func() {
mtx.Unlock()
log.Debugf("released lock for account %s in %v", accountID, time.Since(start))
log.WithContext(ctx).Debugf("released lock for account %s in %v", accountID, time.Since(start))
}
return unlock
@ -298,11 +299,11 @@ func (s *FileStore) AcquireAccountWriteLock(accountID string) (unlock func()) {
// AcquireAccountReadLock AcquireAccountWriteLock acquires account lock for reading a resource and returns a function that releases the lock
// This method is still returns a write lock as file store can't handle read locks
func (s *FileStore) AcquireAccountReadLock(accountID string) (unlock func()) {
return s.AcquireAccountWriteLock(accountID)
func (s *FileStore) AcquireAccountReadLock(ctx context.Context, accountID string) (unlock func()) {
return s.AcquireAccountWriteLock(ctx, accountID)
}
func (s *FileStore) SaveAccount(account *Account) error {
func (s *FileStore) SaveAccount(ctx context.Context, account *Account) error {
s.mux.Lock()
defer s.mux.Unlock()
@ -338,10 +339,10 @@ func (s *FileStore) SaveAccount(account *Account) error {
s.PrivateDomain2AccountID[accountCopy.Domain] = accountCopy.Id
}
return s.persist(s.storeFile)
return s.persist(ctx, s.storeFile)
}
func (s *FileStore) DeleteAccount(account *Account) error {
func (s *FileStore) DeleteAccount(ctx context.Context, account *Account) error {
s.mux.Lock()
defer s.mux.Unlock()
@ -373,7 +374,7 @@ func (s *FileStore) DeleteAccount(account *Account) error {
delete(s.Accounts, account.Id)
return s.persist(s.storeFile)
return s.persist(ctx, s.storeFile)
}
// DeleteHashedPAT2TokenIDIndex removes an entry from the indexing map HashedPAT2TokenID
@ -397,7 +398,7 @@ func (s *FileStore) DeleteTokenID2UserIDIndex(tokenID string) error {
}
// GetAccountByPrivateDomain returns account by private domain
func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) {
func (s *FileStore) GetAccountByPrivateDomain(_ context.Context, domain string) (*Account, error) {
s.mux.Lock()
defer s.mux.Unlock()
@ -415,7 +416,7 @@ func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) {
}
// GetAccountBySetupKey returns account by setup key id
func (s *FileStore) GetAccountBySetupKey(setupKey string) (*Account, error) {
func (s *FileStore) GetAccountBySetupKey(_ context.Context, setupKey string) (*Account, error) {
s.mux.Lock()
defer s.mux.Unlock()
@ -433,7 +434,7 @@ func (s *FileStore) GetAccountBySetupKey(setupKey string) (*Account, error) {
}
// GetTokenIDByHashedToken returns the id of a personal access token by its hashed secret
func (s *FileStore) GetTokenIDByHashedToken(token string) (string, error) {
func (s *FileStore) GetTokenIDByHashedToken(_ context.Context, token string) (string, error) {
s.mux.Lock()
defer s.mux.Unlock()
@ -446,7 +447,7 @@ func (s *FileStore) GetTokenIDByHashedToken(token string) (string, error) {
}
// GetUserByTokenID returns a User object a tokenID belongs to
func (s *FileStore) GetUserByTokenID(tokenID string) (*User, error) {
func (s *FileStore) GetUserByTokenID(_ context.Context, tokenID string) (*User, error) {
s.mux.Lock()
defer s.mux.Unlock()
@ -469,7 +470,7 @@ func (s *FileStore) GetUserByTokenID(tokenID string) (*User, error) {
}
// GetAllAccounts returns all accounts
func (s *FileStore) GetAllAccounts() (all []*Account) {
func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) {
s.mux.Lock()
defer s.mux.Unlock()
for _, a := range s.Accounts {
@ -490,7 +491,7 @@ func (s *FileStore) getAccount(accountID string) (*Account, error) {
}
// GetAccount returns an account for ID
func (s *FileStore) GetAccount(accountID string) (*Account, error) {
func (s *FileStore) GetAccount(_ context.Context, accountID string) (*Account, error) {
s.mux.Lock()
defer s.mux.Unlock()
@ -503,7 +504,7 @@ func (s *FileStore) GetAccount(accountID string) (*Account, error) {
}
// GetAccountByUser returns a user account
func (s *FileStore) GetAccountByUser(userID string) (*Account, error) {
func (s *FileStore) GetAccountByUser(_ context.Context, userID string) (*Account, error) {
s.mux.Lock()
defer s.mux.Unlock()
@ -521,7 +522,7 @@ func (s *FileStore) GetAccountByUser(userID string) (*Account, error) {
}
// GetAccountByPeerID returns an account for a given peer ID
func (s *FileStore) GetAccountByPeerID(peerID string) (*Account, error) {
func (s *FileStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) {
s.mux.Lock()
defer s.mux.Unlock()
@ -539,7 +540,7 @@ func (s *FileStore) GetAccountByPeerID(peerID string) (*Account, error) {
// check Account.Peers for a match
if _, ok := account.Peers[peerID]; !ok {
delete(s.PeerID2AccountID, peerID)
log.Warnf("removed stale peerID %s to accountID %s index", peerID, accountID)
log.WithContext(ctx).Warnf("removed stale peerID %s to accountID %s index", peerID, accountID)
return nil, status.NewPeerNotFoundError(peerID)
}
@ -547,7 +548,7 @@ func (s *FileStore) GetAccountByPeerID(peerID string) (*Account, error) {
}
// GetAccountByPeerPubKey returns an account for a given peer WireGuard public key
func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
func (s *FileStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) {
s.mux.Lock()
defer s.mux.Unlock()
@ -572,14 +573,14 @@ func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
}
if stale {
delete(s.PeerKeyID2AccountID, peerKey)
log.Warnf("removed stale peerKey %s to accountID %s index", peerKey, accountID)
log.WithContext(ctx).Warnf("removed stale peerKey %s to accountID %s index", peerKey, accountID)
return nil, status.NewPeerNotFoundError(peerKey)
}
return account.Copy(), nil
}
func (s *FileStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) {
func (s *FileStore) GetAccountIDByPeerPubKey(_ context.Context, peerKey string) (string, error) {
s.mux.Lock()
defer s.mux.Unlock()
@ -603,7 +604,7 @@ func (s *FileStore) GetAccountIDByUserID(userID string) (string, error) {
return accountID, nil
}
func (s *FileStore) GetAccountIDBySetupKey(setupKey string) (string, error) {
func (s *FileStore) GetAccountIDBySetupKey(_ context.Context, setupKey string) (string, error) {
s.mux.Lock()
defer s.mux.Unlock()
@ -615,7 +616,7 @@ func (s *FileStore) GetAccountIDBySetupKey(setupKey string) (string, error) {
return accountID, nil
}
func (s *FileStore) GetPeerByPeerPubKey(peerKey string) (*nbpeer.Peer, error) {
func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, peerKey string) (*nbpeer.Peer, error) {
s.mux.Lock()
defer s.mux.Unlock()
@ -638,7 +639,7 @@ func (s *FileStore) GetPeerByPeerPubKey(peerKey string) (*nbpeer.Peer, error) {
return nil, status.NewPeerNotFoundError(peerKey)
}
func (s *FileStore) GetAccountSettings(accountID string) (*Settings, error) {
func (s *FileStore) GetAccountSettings(_ context.Context, accountID string) (*Settings, error) {
s.mux.Lock()
defer s.mux.Unlock()
@ -656,13 +657,13 @@ func (s *FileStore) GetInstallationID() string {
}
// SaveInstallationID saves the installation ID
func (s *FileStore) SaveInstallationID(ID string) error {
func (s *FileStore) SaveInstallationID(ctx context.Context, ID string) error {
s.mux.Lock()
defer s.mux.Unlock()
s.InstallationID = ID
return s.persist(s.storeFile)
return s.persist(ctx, s.storeFile)
}
// SavePeerStatus stores the PeerStatus in memory. It doesn't attempt to persist data to speed up things.
@ -732,13 +733,13 @@ func (s *FileStore) GetPostureCheckByChecksDefinition(accountID string, checks *
}
// Close the FileStore persisting data to disk
func (s *FileStore) Close() error {
func (s *FileStore) Close(ctx context.Context) error {
s.mux.Lock()
defer s.mux.Unlock()
log.Infof("closing FileStore")
log.WithContext(ctx).Infof("closing FileStore")
return s.persist(s.storeFile)
return s.persist(ctx, s.storeFile)
}
// GetStoreEngine returns FileStoreEngine

View File

@ -1,6 +1,7 @@
package server
import (
"context"
"crypto/sha256"
"net"
"path/filepath"
@ -27,12 +28,12 @@ func TestStalePeerIndices(t *testing.T) {
t.Fatal(err)
}
store, err := NewFileStore(storeDir, nil)
store, err := NewFileStore(context.Background(), storeDir, nil)
if err != nil {
return
}
account, err := store.GetAccount("bf1c8084-ba50-4ce7-9439-34653001fc3b")
account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
require.NoError(t, err)
peerID := "some_peer"
@ -42,24 +43,24 @@ func TestStalePeerIndices(t *testing.T) {
Key: peerKey,
}
err = store.SaveAccount(account)
err = store.SaveAccount(context.Background(), account)
require.NoError(t, err)
account.DeletePeer(peerID)
err = store.SaveAccount(account)
err = store.SaveAccount(context.Background(), account)
require.NoError(t, err)
_, err = store.GetAccountByPeerID(peerID)
_, err = store.GetAccountByPeerID(context.Background(), peerID)
require.Error(t, err, "expecting to get an error when found stale index")
_, err = store.GetAccountByPeerPubKey(peerKey)
_, err = store.GetAccountByPeerPubKey(context.Background(), peerKey)
require.Error(t, err, "expecting to get an error when found stale index")
}
func TestNewStore(t *testing.T) {
store := newStore(t)
defer store.Close()
defer store.Close(context.Background())
if store.Accounts == nil || len(store.Accounts) != 0 {
t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
@ -88,9 +89,9 @@ func TestNewStore(t *testing.T) {
func TestSaveAccount(t *testing.T) {
store := newStore(t)
defer store.Close()
defer store.Close(context.Background())
account := newAccountWithId("account_id", "testuser", "")
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
setupKey := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey
account.Peers["testpeer"] = &nbpeer.Peer{
@ -103,7 +104,7 @@ func TestSaveAccount(t *testing.T) {
}
// SaveAccount should trigger persist
err := store.SaveAccount(account)
err := store.SaveAccount(context.Background(), account)
if err != nil {
return
}
@ -133,11 +134,11 @@ func TestDeleteAccount(t *testing.T) {
t.Fatal(err)
}
store, err := NewFileStore(storeDir, nil)
store, err := NewFileStore(context.Background(), storeDir, nil)
if err != nil {
t.Fatal(err)
}
defer store.Close()
defer store.Close(context.Background())
var account *Account
for _, a := range store.Accounts {
@ -147,7 +148,7 @@ func TestDeleteAccount(t *testing.T) {
require.NotNil(t, account, "failed to restore a FileStore file and get at least one account")
err = store.DeleteAccount(account)
err = store.DeleteAccount(context.Background(), account)
require.NoError(t, err, "failed to delete account, error: %v", err)
_, ok := store.Accounts[account.Id]
@ -183,9 +184,9 @@ func TestDeleteAccount(t *testing.T) {
func TestStore(t *testing.T) {
store := newStore(t)
defer store.Close()
defer store.Close(context.Background())
account := newAccountWithId("account_id", "testuser", "")
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
account.Peers["testpeer"] = &nbpeer.Peer{
Key: "peerkey",
SetupKey: "peerkeysetupkey",
@ -228,12 +229,12 @@ func TestStore(t *testing.T) {
})
// SaveAccount should trigger persist
err := store.SaveAccount(account)
err := store.SaveAccount(context.Background(), account)
if err != nil {
return
}
restored, err := NewFileStore(store.storeFile, nil)
restored, err := NewFileStore(context.Background(), store.storeFile, nil)
if err != nil {
return
}
@ -281,7 +282,7 @@ func TestRestore(t *testing.T) {
t.Fatal(err)
}
store, err := NewFileStore(storeDir, nil)
store, err := NewFileStore(context.Background(), storeDir, nil)
if err != nil {
return
}
@ -319,7 +320,7 @@ func TestRestoreGroups_Migration(t *testing.T) {
t.Fatal(err)
}
store, err := NewFileStore(storeDir, nil)
store, err := NewFileStore(context.Background(), storeDir, nil)
if err != nil {
return
}
@ -332,11 +333,11 @@ func TestRestoreGroups_Migration(t *testing.T) {
Name: "All",
},
}
err = store.SaveAccount(account)
err = store.SaveAccount(context.Background(), account)
require.NoError(t, err, "failed to save account")
// restore account with default group with empty Issue field
if store, err = NewFileStore(storeDir, nil); err != nil {
if store, err = NewFileStore(context.Background(), storeDir, nil); err != nil {
return
}
account = store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"]
@ -353,18 +354,18 @@ func TestGetAccountByPrivateDomain(t *testing.T) {
t.Fatal(err)
}
store, err := NewFileStore(storeDir, nil)
store, err := NewFileStore(context.Background(), storeDir, nil)
if err != nil {
return
}
existingDomain := "test.com"
account, err := store.GetAccountByPrivateDomain(existingDomain)
account, err := store.GetAccountByPrivateDomain(context.Background(), existingDomain)
require.NoError(t, err, "should found account")
require.Equal(t, existingDomain, account.Domain, "domains should match")
_, err = store.GetAccountByPrivateDomain("missing-domain.com")
_, err = store.GetAccountByPrivateDomain(context.Background(), "missing-domain.com")
require.Error(t, err, "should return error on domain lookup")
}
@ -382,7 +383,7 @@ func TestFileStore_GetAccount(t *testing.T) {
t.Fatal(err)
}
store, err := NewFileStore(storeDir, nil)
store, err := NewFileStore(context.Background(), storeDir, nil)
if err != nil {
t.Fatal(err)
}
@ -393,7 +394,7 @@ func TestFileStore_GetAccount(t *testing.T) {
return
}
account, err := store.GetAccount(expected.Id)
account, err := store.GetAccount(context.Background(), expected.Id)
if err != nil {
t.Fatal(err)
}
@ -424,13 +425,13 @@ func TestFileStore_GetTokenIDByHashedToken(t *testing.T) {
t.Fatal(err)
}
store, err := NewFileStore(storeDir, nil)
store, err := NewFileStore(context.Background(), storeDir, nil)
if err != nil {
t.Fatal(err)
}
hashedToken := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].HashedToken
tokenID, err := store.GetTokenIDByHashedToken(hashedToken)
tokenID, err := store.GetTokenIDByHashedToken(context.Background(), hashedToken)
if err != nil {
t.Fatal(err)
}
@ -441,7 +442,7 @@ func TestFileStore_GetTokenIDByHashedToken(t *testing.T) {
func TestFileStore_DeleteHashedPAT2TokenIDIndex(t *testing.T) {
store := newStore(t)
defer store.Close()
defer store.Close(context.Background())
store.HashedPAT2TokenID["someHashedToken"] = "someTokenId"
err := store.DeleteHashedPAT2TokenIDIndex("someHashedToken")
@ -478,13 +479,13 @@ func TestFileStore_GetTokenIDByHashedToken_Failure(t *testing.T) {
t.Fatal(err)
}
store, err := NewFileStore(storeDir, nil)
store, err := NewFileStore(context.Background(), storeDir, nil)
if err != nil {
t.Fatal(err)
}
wrongToken := sha256.Sum256([]byte("someNotValidTokenThatFails1234"))
_, err = store.GetTokenIDByHashedToken(string(wrongToken[:]))
_, err = store.GetTokenIDByHashedToken(context.Background(), string(wrongToken[:]))
assert.Error(t, err, "GetTokenIDByHashedToken should throw error if token invalid")
}
@ -503,13 +504,13 @@ func TestFileStore_GetUserByTokenID(t *testing.T) {
t.Fatal(err)
}
store, err := NewFileStore(storeDir, nil)
store, err := NewFileStore(context.Background(), storeDir, nil)
if err != nil {
t.Fatal(err)
}
tokenID := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].ID
user, err := store.GetUserByTokenID(tokenID)
user, err := store.GetUserByTokenID(context.Background(), tokenID)
if err != nil {
t.Fatal(err)
}
@ -531,13 +532,13 @@ func TestFileStore_GetUserByTokenID_Failure(t *testing.T) {
t.Fatal(err)
}
store, err := NewFileStore(storeDir, nil)
store, err := NewFileStore(context.Background(), storeDir, nil)
if err != nil {
t.Fatal(err)
}
wrongTokenID := "someNonExistingTokenID"
_, err = store.GetUserByTokenID(wrongTokenID)
_, err = store.GetUserByTokenID(context.Background(), wrongTokenID)
assert.Error(t, err, "GetUserByTokenID should throw error if tokenID invalid")
}
@ -550,7 +551,7 @@ func TestFileStore_SavePeerStatus(t *testing.T) {
t.Fatal(err)
}
store, err := NewFileStore(storeDir, nil)
store, err := NewFileStore(context.Background(), storeDir, nil)
if err != nil {
return
}
@ -576,7 +577,7 @@ func TestFileStore_SavePeerStatus(t *testing.T) {
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()},
}
err = store.SaveAccount(account)
err = store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatal(err)
}
@ -602,11 +603,11 @@ func TestFileStore_SavePeerLocation(t *testing.T) {
t.Fatal(err)
}
store, err := NewFileStore(storeDir, nil)
store, err := NewFileStore(context.Background(), storeDir, nil)
if err != nil {
return
}
account, err := store.GetAccount("bf1c8084-ba50-4ce7-9439-34653001fc3b")
account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
require.NoError(t, err)
peer := &nbpeer.Peer{
@ -625,7 +626,7 @@ func TestFileStore_SavePeerLocation(t *testing.T) {
assert.Error(t, err)
account.Peers[peer.ID] = peer
err = store.SaveAccount(account)
err = store.SaveAccount(context.Background(), account)
require.NoError(t, err)
peer.Location.ConnectionIP = net.ParseIP("35.1.1.1")
@ -636,7 +637,7 @@ func TestFileStore_SavePeerLocation(t *testing.T) {
err = store.SavePeerLocation(account.Id, account.Peers[peer.ID])
assert.NoError(t, err)
account, err = store.GetAccount(account.Id)
account, err = store.GetAccount(context.Background(), account.Id)
require.NoError(t, err)
actual := account.Peers[peer.ID].Location
@ -645,7 +646,7 @@ func TestFileStore_SavePeerLocation(t *testing.T) {
func newStore(t *testing.T) *FileStore {
t.Helper()
store, err := NewFileStore(t.TempDir(), nil)
store, err := NewFileStore(context.Background(), t.TempDir(), nil)
if err != nil {
t.Errorf("failed creating a new store")
}

View File

@ -2,6 +2,7 @@ package geolocation
import (
"bytes"
"context"
"fmt"
"net"
"os"
@ -52,7 +53,7 @@ type Country struct {
CountryName string
}
func NewGeolocation(dataDir string) (*Geolocation, error) {
func NewGeolocation(ctx context.Context, dataDir string) (*Geolocation, error) {
if err := loadGeolocationDatabases(dataDir); err != nil {
return nil, fmt.Errorf("failed to load MaxMind databases: %v", err)
}
@ -68,7 +69,7 @@ func NewGeolocation(dataDir string) (*Geolocation, error) {
return nil, err
}
locationDB, err := NewSqliteStore(dataDir)
locationDB, err := NewSqliteStore(ctx, dataDir)
if err != nil {
return nil, err
}
@ -83,7 +84,7 @@ func NewGeolocation(dataDir string) (*Geolocation, error) {
stopCh: make(chan struct{}),
}
go geo.reloader()
go geo.reloader(ctx)
return geo, nil
}
@ -165,19 +166,19 @@ func (gl *Geolocation) Stop() error {
return nil
}
func (gl *Geolocation) reloader() {
func (gl *Geolocation) reloader(ctx context.Context) {
for {
select {
case <-gl.stopCh:
return
case <-time.After(gl.reloadCheckInterval):
if err := gl.locationDB.reload(); err != nil {
log.Errorf("geonames db reload failed: %s", err)
if err := gl.locationDB.reload(ctx); err != nil {
log.WithContext(ctx).Errorf("geonames db reload failed: %s", err)
}
newSha256sum1, err := calculateFileSHA256(gl.mmdbPath)
if err != nil {
log.Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err)
log.WithContext(ctx).Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err)
continue
}
if !bytes.Equal(gl.sha256sum, newSha256sum1) {
@ -186,30 +187,30 @@ func (gl *Geolocation) reloader() {
time.Sleep(50 * time.Millisecond)
newSha256sum2, err := calculateFileSHA256(gl.mmdbPath)
if err != nil {
log.Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err)
log.WithContext(ctx).Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err)
continue
}
if !bytes.Equal(newSha256sum1, newSha256sum2) {
log.Errorf("sha256 sum changed during reloading of '%s'", gl.mmdbPath)
log.WithContext(ctx).Errorf("sha256 sum changed during reloading of '%s'", gl.mmdbPath)
continue
}
err = gl.reload(newSha256sum2)
err = gl.reload(ctx, newSha256sum2)
if err != nil {
log.Errorf("mmdb reload failed: %s", err)
log.WithContext(ctx).Errorf("mmdb reload failed: %s", err)
}
} else {
log.Tracef("No changes in '%s', no need to reload. Next check is in %.0f seconds.",
log.WithContext(ctx).Tracef("No changes in '%s', no need to reload. Next check is in %.0f seconds.",
gl.mmdbPath, gl.reloadCheckInterval.Seconds())
}
}
}
}
func (gl *Geolocation) reload(newSha256sum []byte) error {
func (gl *Geolocation) reload(ctx context.Context, newSha256sum []byte) error {
gl.mux.Lock()
defer gl.mux.Unlock()
log.Infof("Reloading '%s'", gl.mmdbPath)
log.WithContext(ctx).Infof("Reloading '%s'", gl.mmdbPath)
err := gl.db.Close()
if err != nil {
@ -224,7 +225,7 @@ func (gl *Geolocation) reload(newSha256sum []byte) error {
gl.db = db
gl.sha256sum = newSha256sum
log.Infof("Successfully reloaded '%s'", gl.mmdbPath)
log.WithContext(ctx).Infof("Successfully reloaded '%s'", gl.mmdbPath)
return nil
}

View File

@ -2,6 +2,7 @@ package geolocation
import (
"bytes"
"context"
"fmt"
"path/filepath"
"runtime"
@ -50,10 +51,10 @@ type SqliteStore struct {
sha256sum []byte
}
func NewSqliteStore(dataDir string) (*SqliteStore, error) {
func NewSqliteStore(ctx context.Context, dataDir string) (*SqliteStore, error) {
file := filepath.Join(dataDir, GeoSqliteDBFile)
db, err := connectDB(file)
db, err := connectDB(ctx, file)
if err != nil {
return nil, err
}
@ -115,13 +116,13 @@ func (s *SqliteStore) GetCitiesByCountry(countryISOCode string) ([]City, error)
}
// reload attempts to reload the SqliteStore's database if the database file has changed.
func (s *SqliteStore) reload() error {
func (s *SqliteStore) reload(ctx context.Context) error {
s.mux.Lock()
defer s.mux.Unlock()
newSha256sum1, err := calculateFileSHA256(s.filePath)
if err != nil {
log.Errorf("failed to calculate sha256 sum for '%s': %s", s.filePath, err)
log.WithContext(ctx).Errorf("failed to calculate sha256 sum for '%s': %s", s.filePath, err)
}
if !bytes.Equal(s.sha256sum, newSha256sum1) {
@ -136,11 +137,11 @@ func (s *SqliteStore) reload() error {
return fmt.Errorf("sha256 sum changed during reloading of '%s'", s.filePath)
}
log.Infof("Reloading '%s'", s.filePath)
log.WithContext(ctx).Infof("Reloading '%s'", s.filePath)
_ = s.close()
s.closed = true
newDb, err := connectDB(s.filePath)
newDb, err := connectDB(ctx, s.filePath)
if err != nil {
return err
}
@ -148,9 +149,9 @@ func (s *SqliteStore) reload() error {
s.closed = false
s.db = newDb
log.Infof("Successfully reloaded '%s'", s.filePath)
log.WithContext(ctx).Infof("Successfully reloaded '%s'", s.filePath)
} else {
log.Tracef("No changes in '%s', no need to reload", s.filePath)
log.WithContext(ctx).Tracef("No changes in '%s', no need to reload", s.filePath)
}
return nil
@ -168,10 +169,10 @@ func (s *SqliteStore) close() error {
}
// connectDB connects to an SQLite database and prepares it by setting up an in-memory database.
func connectDB(filePath string) (*gorm.DB, error) {
func connectDB(ctx context.Context, filePath string) (*gorm.DB, error) {
start := time.Now()
defer func() {
log.Debugf("took %v to setup geoname db", time.Since(start))
log.WithContext(ctx).Debugf("took %v to setup geoname db", time.Since(start))
}()
_, err := fileExists(filePath)

View File

@ -1,6 +1,7 @@
package server
import (
"context"
"fmt"
"github.com/rs/xid"
@ -21,11 +22,11 @@ func (e *GroupLinkError) Error() string {
}
// GetGroup object of the peers
func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
@ -48,11 +49,11 @@ func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*n
}
// GetAllGroups returns all groups in an account
func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID string, userID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
@ -75,11 +76,11 @@ func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) (
}
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
@ -108,11 +109,11 @@ func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*n
}
// SaveGroup object of the peers
func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *nbgroup.Group) error {
unlock := am.Store.AcquireAccountWriteLock(accountID)
func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *nbgroup.Group) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
@ -150,11 +151,11 @@ func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *n
account.Groups[newGroup.ID] = newGroup
account.Network.IncSerial()
if err = am.Store.SaveAccount(account); err != nil {
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
am.updateAccountPeers(account)
am.updateAccountPeers(ctx, account)
// the following snippet tracks the activity and stores the group events in the event store.
// It has to happen after all the operations have been successfully performed.
@ -165,16 +166,16 @@ func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *n
removedPeers = difference(oldGroup.Peers, newGroup.Peers)
} else {
addedPeers = append(addedPeers, newGroup.Peers...)
am.StoreEvent(userID, newGroup.ID, accountID, activity.GroupCreated, newGroup.EventMeta())
am.StoreEvent(ctx, userID, newGroup.ID, accountID, activity.GroupCreated, newGroup.EventMeta())
}
for _, p := range addedPeers {
peer := account.Peers[p]
if peer == nil {
log.Errorf("peer %s not found under account %s while saving group", p, accountID)
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
continue
}
am.StoreEvent(userID, peer.ID, accountID, activity.GroupAddedToPeer,
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer,
map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peer.IP.String(),
"peer_fqdn": peer.FQDN(am.GetDNSDomain()),
@ -184,10 +185,10 @@ func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *n
for _, p := range removedPeers {
peer := account.Peers[p]
if peer == nil {
log.Errorf("peer %s not found under account %s while saving group", p, accountID)
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
continue
}
am.StoreEvent(userID, peer.ID, accountID, activity.GroupRemovedFromPeer,
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer,
map[string]any{
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peer.IP.String(),
"peer_fqdn": peer.FQDN(am.GetDNSDomain()),
@ -213,11 +214,11 @@ func difference(a, b []string) []string {
}
// DeleteGroup object of the peers
func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) error {
unlock := am.Store.AcquireAccountWriteLock(accountId)
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountId)
defer unlock()
account, err := am.Store.GetAccount(accountId)
account, err := am.Store.GetAccount(ctx, accountId)
if err != nil {
return err
}
@ -315,23 +316,23 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string)
delete(account.Groups, groupID)
account.Network.IncSerial()
if err = am.Store.SaveAccount(account); err != nil {
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
am.StoreEvent(userId, groupID, accountId, activity.GroupDeleted, g.EventMeta())
am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, g.EventMeta())
am.updateAccountPeers(account)
am.updateAccountPeers(ctx, account)
return nil
}
// ListGroups objects of the peers
func (am *DefaultAccountManager) ListGroups(accountID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
@ -345,11 +346,11 @@ func (am *DefaultAccountManager) ListGroups(accountID string) ([]*nbgroup.Group,
}
// GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerID string) error {
unlock := am.Store.AcquireAccountWriteLock(accountID)
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
@ -371,21 +372,21 @@ func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerID string)
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(account); err != nil {
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
am.updateAccountPeers(account)
am.updateAccountPeers(ctx, account)
return nil
}
// GroupDeletePeer removes peer from the group
func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerID string) error {
unlock := am.Store.AcquireAccountWriteLock(accountID)
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
@ -399,13 +400,13 @@ func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerID stri
for i, itemID := range group.Peers {
if itemID == peerID {
group.Peers = append(group.Peers[:i], group.Peers[i+1:]...)
if err := am.Store.SaveAccount(account); err != nil {
if err := am.Store.SaveAccount(ctx, account); err != nil {
return err
}
}
}
am.updateAccountPeers(account)
am.updateAccountPeers(ctx, account)
return nil
}

View File

@ -1,6 +1,7 @@
package server
import (
"context"
"errors"
"testing"
@ -26,7 +27,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
}
for _, group := range account.Groups {
group.Issued = nbgroup.GroupIssuedIntegration
err = am.SaveGroup(account.Id, groupAdminUserID, group)
err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group)
if err != nil {
t.Errorf("should allow to create %s groups", nbgroup.GroupIssuedIntegration)
}
@ -34,7 +35,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
for _, group := range account.Groups {
group.Issued = nbgroup.GroupIssuedJWT
err = am.SaveGroup(account.Id, groupAdminUserID, group)
err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group)
if err != nil {
t.Errorf("should allow to create %s groups", nbgroup.GroupIssuedJWT)
}
@ -42,7 +43,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
for _, group := range account.Groups {
group.Issued = nbgroup.GroupIssuedAPI
group.ID = ""
err = am.SaveGroup(account.Id, groupAdminUserID, group)
err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group)
if err == nil {
t.Errorf("should not create api group with the same name, %s", group.Name)
}
@ -104,7 +105,7 @@ func TestDefaultAccountManager_DeleteGroup(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
err = am.DeleteGroup(account.Id, groupAdminUserID, testCase.groupID)
err = am.DeleteGroup(context.Background(), account.Id, groupAdminUserID, testCase.groupID)
if err == nil {
t.Errorf("delete %s group successfully", testCase.groupID)
return
@ -225,7 +226,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) {
Id: "example user",
AutoGroups: []string{groupForUsers.ID},
}
account := newAccountWithId(accountID, groupAdminUserID, domain)
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain)
account.Routes[routeResource.ID] = routeResource
account.Routes[routePeerGroupResource.ID] = routePeerGroupResource
account.NameServerGroups[nameServerGroup.ID] = nameServerGroup
@ -233,18 +234,18 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) {
account.SetupKeys[setupKey.Id] = setupKey
account.Users[user.Id] = user
err := am.Store.SaveAccount(account)
err := am.Store.SaveAccount(context.Background(), account)
if err != nil {
return nil, err
}
_ = am.SaveGroup(accountID, groupAdminUserID, groupForRoute)
_ = am.SaveGroup(accountID, groupAdminUserID, groupForRoute2)
_ = am.SaveGroup(accountID, groupAdminUserID, groupForNameServerGroups)
_ = am.SaveGroup(accountID, groupAdminUserID, groupForPolicies)
_ = am.SaveGroup(accountID, groupAdminUserID, groupForSetupKeys)
_ = am.SaveGroup(accountID, groupAdminUserID, groupForUsers)
_ = am.SaveGroup(accountID, groupAdminUserID, groupForIntegration)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute2)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForNameServerGroups)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForPolicies)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForSetupKeys)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForUsers)
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration)
return am.Store.GetAccount(account.Id)
return am.Store.GetAccount(context.Background(), account.Id)
}

View File

@ -11,12 +11,14 @@ import (
pb "github.com/golang/protobuf/proto" // nolint
"github.com/golang/protobuf/ptypes/timestamp"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
"github.com/netbirdio/netbird/management/server/posture"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/jwtclaims"
@ -40,7 +42,7 @@ type GRPCServer struct {
}
// NewServer creates a new Management server
func NewServer(config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager, turnCredentialsManager TURNCredentialsManager, appMetrics telemetry.AppMetrics, ephemeralManager *EphemeralManager) (*GRPCServer, error) {
func NewServer(ctx context.Context, config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager, turnCredentialsManager TURNCredentialsManager, appMetrics telemetry.AppMetrics, ephemeralManager *EphemeralManager) (*GRPCServer, error) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, err
@ -50,6 +52,7 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) {
jwtValidator, err = jwtclaims.NewJWTValidator(
ctx,
config.HttpConfig.AuthIssuer,
config.GetAuthAudiences(),
config.HttpConfig.AuthKeysLocation,
@ -59,7 +62,7 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
return nil, status.Errorf(codes.Internal, "unable to create new jwt middleware, err: %v", err)
}
} else {
log.Debug("unable to use http config to create new jwt middleware")
log.WithContext(ctx).Debug("unable to use http config to create new jwt middleware")
}
if appMetrics != nil {
@ -126,47 +129,61 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSyncRequest()
}
realIP := getRealIP(srv.Context())
log.Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, realIP.String())
ctx := srv.Context()
realIP := getRealIP(ctx)
syncReq := &proto.SyncRequest{}
peerKey, err := s.parseRequest(req, syncReq)
peerKey, err := s.parseRequest(ctx, req, syncReq)
if err != nil {
return err
}
//nolint
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
if err != nil {
// this case should not happen and already indicates an issue but we don't want the system to fail due to being unable to log in detail
accountID = "UNKNOWN"
}
//nolint
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, realIP.String())
if syncReq.GetMeta() == nil {
log.Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
}
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(peerKey.String(), extractPeerMeta(syncReq.GetMeta()), realIP)
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP)
if err != nil {
return mapError(err)
return mapError(ctx, err)
}
err = s.sendInitialSync(peerKey, peer, netMap, postureChecks, srv)
err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv)
if err != nil {
log.Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
return err
}
updates := s.peersUpdateManager.CreateChannel(peer.ID)
updates := s.peersUpdateManager.CreateChannel(ctx, peer.ID)
s.ephemeralManager.OnPeerConnected(peer)
s.ephemeralManager.OnPeerConnected(ctx, peer)
if s.config.TURNConfig.TimeBasedCredentials {
s.turnCredentialsManager.SetupRefresh(peer.ID)
s.turnCredentialsManager.SetupRefresh(ctx, peer.ID)
}
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart))
}
return s.handleUpdates(peerKey, peer, updates, srv)
return s.handleUpdates(ctx, peerKey, peer, updates, srv)
}
// handleUpdates sends updates to the connected peer until the updates channel is closed.
func (s *GRPCServer) handleUpdates(peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
for {
select {
// condition when there are some updates
@ -176,21 +193,21 @@ func (s *GRPCServer) handleUpdates(peerKey wgtypes.Key, peer *nbpeer.Peer, updat
}
if !open {
log.Debugf("updates channel for peer %s was closed", peerKey.String())
s.cancelPeerRoutines(peer)
log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String())
s.cancelPeerRoutines(ctx, peer)
return nil
}
log.Debugf("received an update for peer %s", peerKey.String())
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
if err := s.sendUpdate(peerKey, peer, update, srv); err != nil {
if err := s.sendUpdate(ctx, peerKey, peer, update, srv); err != nil {
return err
}
// condition when client <-> server connection has been terminated
case <-srv.Context().Done():
// happens when connection drops, e.g. client disconnects
log.Debugf("stream of peer %s has been closed", peerKey.String())
s.cancelPeerRoutines(peer)
log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String())
s.cancelPeerRoutines(ctx, peer)
return srv.Context().Err()
}
}
@ -198,10 +215,10 @@ func (s *GRPCServer) handleUpdates(peerKey wgtypes.Key, peer *nbpeer.Peer, updat
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
// then sends the encrypted message to the connected peer via the sync server.
func (s *GRPCServer) sendUpdate(peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error {
func (s *GRPCServer) sendUpdate(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error {
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update)
if err != nil {
s.cancelPeerRoutines(peer)
s.cancelPeerRoutines(ctx, peer)
return status.Errorf(codes.Internal, "failed processing update message")
}
err = srv.SendMsg(&proto.EncryptedMessage{
@ -209,37 +226,37 @@ func (s *GRPCServer) sendUpdate(peerKey wgtypes.Key, peer *nbpeer.Peer, update *
Body: encryptedResp,
})
if err != nil {
s.cancelPeerRoutines(peer)
s.cancelPeerRoutines(ctx, peer)
return status.Errorf(codes.Internal, "failed sending update message")
}
log.Debugf("sent an update to peer %s", peerKey.String())
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
return nil
}
func (s *GRPCServer) cancelPeerRoutines(peer *nbpeer.Peer) {
s.peersUpdateManager.CloseChannel(peer.ID)
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, peer *nbpeer.Peer) {
s.peersUpdateManager.CloseChannel(ctx, peer.ID)
s.turnCredentialsManager.CancelRefresh(peer.ID)
_ = s.accountManager.CancelPeerRoutines(peer)
s.ephemeralManager.OnPeerDisconnected(peer)
_ = s.accountManager.CancelPeerRoutines(ctx, peer)
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
}
func (s *GRPCServer) validateToken(jwtToken string) (string, error) {
func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) {
if s.jwtValidator == nil {
return "", status.Error(codes.Internal, "no jwt validator set")
}
token, err := s.jwtValidator.ValidateAndParse(jwtToken)
token, err := s.jwtValidator.ValidateAndParse(ctx, jwtToken)
if err != nil {
return "", status.Errorf(codes.InvalidArgument, "invalid jwt token, err: %v", err)
}
claims := s.jwtClaimsExtractor.FromToken(token)
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
_, _, err = s.accountManager.GetAccountFromToken(claims)
_, _, err = s.accountManager.GetAccountFromToken(ctx, claims)
if err != nil {
return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err)
}
if err := s.accountManager.CheckUserAccessByJWTGroups(claims); err != nil {
if err := s.accountManager.CheckUserAccessByJWTGroups(ctx, claims); err != nil {
return "", status.Errorf(codes.PermissionDenied, err.Error())
}
@ -247,7 +264,7 @@ func (s *GRPCServer) validateToken(jwtToken string) (string, error) {
}
// maps internal internalStatus.Error to gRPC status.Error
func mapError(err error) error {
func mapError(ctx context.Context, err error) error {
if e, ok := internalStatus.FromError(err); ok {
switch e.Type() {
case internalStatus.PermissionDenied:
@ -263,11 +280,11 @@ func mapError(err error) error {
default:
}
}
log.Errorf("got an unhandled error: %s", err)
log.WithContext(ctx).Errorf("got an unhandled error: %s", err)
return status.Errorf(codes.Internal, "failed handling request")
}
func extractPeerMeta(meta *proto.PeerSystemMeta) nbpeer.PeerSystemMeta {
func extractPeerMeta(ctx context.Context, meta *proto.PeerSystemMeta) nbpeer.PeerSystemMeta {
if meta == nil {
return nbpeer.PeerSystemMeta{}
}
@ -281,7 +298,7 @@ func extractPeerMeta(meta *proto.PeerSystemMeta) nbpeer.PeerSystemMeta {
for _, addr := range meta.GetNetworkAddresses() {
netAddr, err := netip.ParsePrefix(addr.GetNetIP())
if err != nil {
log.Warnf("failed to parse netip address, %s: %v", addr.GetNetIP(), err)
log.WithContext(ctx).Warnf("failed to parse netip address, %s: %v", addr.GetNetIP(), err)
continue
}
networkAddresses = append(networkAddresses, nbpeer.NetworkAddress{
@ -321,10 +338,10 @@ func extractPeerMeta(meta *proto.PeerSystemMeta) nbpeer.PeerSystemMeta {
}
}
func (s *GRPCServer) parseRequest(req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) {
func (s *GRPCServer) parseRequest(ctx context.Context, req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) {
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil {
log.Warnf("error while parsing peer's WireGuard public key %s.", req.WgPubKey)
log.WithContext(ctx).Warnf("error while parsing peer's WireGuard public key %s.", req.WgPubKey)
return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", req.WgPubKey)
}
@ -351,22 +368,32 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
s.appMetrics.GRPCMetrics().CountLoginRequest()
}
realIP := getRealIP(ctx)
log.Debugf("Login request from peer [%s] [%s]", req.WgPubKey, realIP.String())
log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, realIP.String())
loginReq := &proto.LoginRequest{}
peerKey, err := s.parseRequest(req, loginReq)
peerKey, err := s.parseRequest(ctx, req, loginReq)
if err != nil {
return nil, err
}
//nolint
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
if err != nil {
// this case should not happen and already indicates an issue but we don't want the system to fail due to being unable to log in detail
accountID = "UNKNOWN"
}
//nolint
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
if loginReq.GetMeta() == nil {
msg := status.Errorf(codes.FailedPrecondition,
"peer system meta has to be provided to log in. Peer %s, remote addr %s", peerKey.String(), realIP)
log.Warn(msg)
log.WithContext(ctx).Warn(msg)
return nil, msg
}
userID, err := s.processJwtToken(loginReq, peerKey)
userID, err := s.processJwtToken(ctx, loginReq, peerKey)
if err != nil {
return nil, err
}
@ -376,33 +403,33 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
sshKey = loginReq.GetPeerKeys().GetSshPubKey()
}
peer, netMap, postureChecks, err := s.accountManager.LoginPeer(PeerLogin{
peer, netMap, postureChecks, err := s.accountManager.LoginPeer(ctx, PeerLogin{
WireGuardPubKey: peerKey.String(),
SSHKey: string(sshKey),
Meta: extractPeerMeta(loginReq.GetMeta()),
Meta: extractPeerMeta(ctx, loginReq.GetMeta()),
UserID: userID,
SetupKey: loginReq.GetSetupKey(),
ConnectionIP: realIP,
})
if err != nil {
log.Warnf("failed logging in peer %s: %s", peerKey, err)
return nil, mapError(err)
log.WithContext(ctx).Warnf("failed logging in peer %s: %s", peerKey, err)
return nil, mapError(ctx, err)
}
// if the login request contains setup key then it is a registration request
if loginReq.GetSetupKey() != "" {
s.ephemeralManager.OnPeerDisconnected(peer)
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
}
// if peer has reached this point then it has logged in
loginResp := &proto.LoginResponse{
WiretrusteeConfig: toWiretrusteeConfig(s.config, nil),
PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain()),
Checks: toProtocolChecks(postureChecks),
Checks: toProtocolChecks(ctx, postureChecks),
}
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp)
if err != nil {
log.Warnf("failed encrypting peer %s message", peer.ID)
log.WithContext(ctx).Warnf("failed encrypting peer %s message", peer.ID)
return nil, status.Errorf(codes.Internal, "failed logging in peer")
}
@ -417,16 +444,16 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
//
// The user ID can be empty if the token is not provided, which is acceptable if the peer is already
// registered or if it uses a setup key to register.
func (s *GRPCServer) processJwtToken(loginReq *proto.LoginRequest, peerKey wgtypes.Key) (string, error) {
func (s *GRPCServer) processJwtToken(ctx context.Context, loginReq *proto.LoginRequest, peerKey wgtypes.Key) (string, error) {
userID := ""
if loginReq.GetJwtToken() != "" {
var err error
for i := 0; i < 3; i++ {
userID, err = s.validateToken(loginReq.GetJwtToken())
userID, err = s.validateToken(ctx, loginReq.GetJwtToken())
if err == nil {
break
}
log.Warnf("failed validating JWT token sent from peer %s with error %v. "+
log.WithContext(ctx).Warnf("failed validating JWT token sent from peer %s with error %v. "+
"Trying again as it may be due to the IdP cache issue", peerKey.String(), err)
time.Sleep(200 * time.Millisecond)
}
@ -520,7 +547,7 @@ func toRemotePeerConfig(peers []*nbpeer.Peer, dnsName string) []*proto.RemotePee
return remotePeers
}
func toSyncResponse(config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string, checks []*posture.Checks) *proto.SyncResponse {
func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string, checks []*posture.Checks) *proto.SyncResponse {
wtConfig := toWiretrusteeConfig(config, turnCredentials)
pConfig := toPeerConfig(peer, networkMap.Network, dnsName)
@ -551,7 +578,7 @@ func toSyncResponse(config *Config, peer *nbpeer.Peer, turnCredentials *TURNCred
FirewallRules: firewallRules,
FirewallRulesIsEmpty: len(firewallRules) == 0,
},
Checks: toProtocolChecks(checks),
Checks: toProtocolChecks(ctx, checks),
}
}
@ -561,7 +588,7 @@ func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Em
}
// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization
func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error {
func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error {
// make secret time based TURN credentials optional
var turnCredentials *TURNCredentials
if s.config.TURNConfig.TimeBasedCredentials {
@ -570,7 +597,7 @@ func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *nbpeer.Peer, net
} else {
turnCredentials = nil
}
plainResp := toSyncResponse(s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain(), postureChecks)
plainResp := toSyncResponse(ctx, s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain(), postureChecks)
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
if err != nil {
@ -583,7 +610,7 @@ func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *nbpeer.Peer, net
})
if err != nil {
log.Errorf("failed sending SyncResponse %v", err)
log.WithContext(ctx).Errorf("failed sending SyncResponse %v", err)
return status.Errorf(codes.Internal, "error handling request")
}
@ -597,14 +624,14 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil {
errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetDeviceAuthorizationFlow request.", req.WgPubKey)
log.Warn(errMSG)
log.WithContext(ctx).Warn(errMSG)
return nil, status.Error(codes.InvalidArgument, errMSG)
}
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.DeviceAuthorizationFlowRequest{})
if err != nil {
errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey)
log.Warn(errMSG)
log.WithContext(ctx).Warn(errMSG)
return nil, status.Error(codes.InvalidArgument, errMSG)
}
@ -645,18 +672,18 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.
// GetPKCEAuthorizationFlow returns a pkce authorization flow information
// This is used for initiating an Oauth 2 pkce authorization grant flow
// which will be used by our clients to Login
func (s *GRPCServer) GetPKCEAuthorizationFlow(_ context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
if err != nil {
errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetPKCEAuthorizationFlow request.", req.WgPubKey)
log.Warn(errMSG)
log.WithContext(ctx).Warn(errMSG)
return nil, status.Error(codes.InvalidArgument, errMSG)
}
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.PKCEAuthorizationFlowRequest{})
if err != nil {
errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey)
log.Warn(errMSG)
log.WithContext(ctx).Warn(errMSG)
return nil, status.Error(codes.InvalidArgument, errMSG)
}
@ -692,10 +719,10 @@ func (s *GRPCServer) GetPKCEAuthorizationFlow(_ context.Context, req *proto.Encr
// peer's under the same account of any updates.
func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
realIP := getRealIP(ctx)
log.Debugf("Sync meta request from peer [%s] [%s]", req.WgPubKey, realIP.String())
log.WithContext(ctx).Debugf("Sync meta request from peer [%s] [%s]", req.WgPubKey, realIP.String())
syncMetaReq := &proto.SyncMetaRequest{}
peerKey, err := s.parseRequest(req, syncMetaReq)
peerKey, err := s.parseRequest(ctx, req, syncMetaReq)
if err != nil {
return nil, err
}
@ -703,20 +730,20 @@ func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage)
if syncMetaReq.GetMeta() == nil {
msg := status.Errorf(codes.FailedPrecondition,
"peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
log.Warn(msg)
log.WithContext(ctx).Warn(msg)
return nil, msg
}
err = s.accountManager.SyncPeerMeta(peerKey.String(), extractPeerMeta(syncMetaReq.GetMeta()))
err = s.accountManager.SyncPeerMeta(ctx, peerKey.String(), extractPeerMeta(ctx, syncMetaReq.GetMeta()))
if err != nil {
return nil, mapError(err)
return nil, mapError(ctx, err)
}
return &proto.Empty{}, nil
}
// toProtocolChecks converts posture checks to protocol checks.
func toProtocolChecks(postureChecks []*posture.Checks) []*proto.Checks {
func toProtocolChecks(ctx context.Context, postureChecks []*posture.Checks) []*proto.Checks {
protoChecks := make([]*proto.Checks, 0, len(postureChecks))
for _, postureCheck := range postureChecks {
protoChecks = append(protoChecks, toProtocolCheck(postureCheck))

View File

@ -35,34 +35,34 @@ func NewAccountsHandler(accountManager server.AccountManager, authCfg AuthCfg) *
// GetAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account.
func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
if !(user.HasAdminPower() || user.IsServiceUser) {
util.WriteError(status.Errorf(status.PermissionDenied, "the user has no permission to access account data"), w)
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "the user has no permission to access account data"), w)
return
}
resp := toAccountResponse(account)
util.WriteJSONObject(w, []*api.Account{resp})
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
}
// UpdateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings)
func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
_, user, err := h.accountManager.GetAccountFromToken(claims)
_, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
accountID := vars["accountId"]
if len(accountID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid accountID ID"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid accountID ID"), w)
return
}
@ -96,15 +96,15 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
settings.JWTAllowGroups = *req.Settings.JwtAllowGroups
}
updatedAccount, err := h.accountManager.UpdateAccountSettings(accountID, user.Id, settings)
updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, user.Id, settings)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
resp := toAccountResponse(updatedAccount)
util.WriteJSONObject(w, &resp)
util.WriteJSONObject(r.Context(), w, &resp)
}
// DeleteAccount is a HTTP DELETE handler to delete an account
@ -118,17 +118,17 @@ func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request)
vars := mux.Vars(r)
targetAccountID := vars["accountId"]
if len(targetAccountID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid account ID"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid account ID"), w)
return
}
err := h.accountManager.DeleteAccount(targetAccountID, claims.UserId)
err := h.accountManager.DeleteAccount(r.Context(), targetAccountID, claims.UserId)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(w, emptyObject{})
util.WriteJSONObject(r.Context(), w, emptyObject{})
}
func toAccountResponse(account *server.Account) *api.Account {

View File

@ -2,6 +2,7 @@ package http
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
@ -22,10 +23,10 @@ import (
func initAccountsTestData(account *server.Account, admin *server.User) *AccountsHandler {
return &AccountsHandler{
accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
GetAccountFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return account, admin, nil
},
UpdateAccountSettingsFunc: func(accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
halfYearLimit := 180 * 24 * time.Hour
if newSettings.PeerLoginExpiration > halfYearLimit {
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")

View File

@ -32,16 +32,16 @@ func NewDNSSettingsHandler(accountManager server.AccountManager, authCfg AuthCfg
// GetDNSSettings returns the DNS settings for the account
func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
log.Error(err)
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
dnsSettings, err := h.accountManager.GetDNSSettings(account.Id, user.Id)
dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), account.Id, user.Id)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
@ -49,15 +49,15 @@ func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Reque
DisabledManagementGroups: dnsSettings.DisabledManagementGroups,
}
util.WriteJSONObject(w, apiDNSSettings)
util.WriteJSONObject(r.Context(), w, apiDNSSettings)
}
// UpdateDNSSettings handles update to DNS settings of an account
func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
@ -72,9 +72,9 @@ func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Re
DisabledManagementGroups: req.DisabledManagementGroups,
}
err = h.accountManager.SaveDNSSettings(account.Id, user.Id, updateDNSSettings)
err = h.accountManager.SaveDNSSettings(r.Context(), account.Id, user.Id, updateDNSSettings)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
@ -82,5 +82,5 @@ func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Re
DisabledManagementGroups: updateDNSSettings.DisabledManagementGroups,
}
util.WriteJSONObject(w, &resp)
util.WriteJSONObject(r.Context(), w, &resp)
}

View File

@ -2,6 +2,7 @@ package http
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
@ -42,16 +43,16 @@ var testingDNSSettingsAccount = &server.Account{
func initDNSSettingsTestData() *DNSSettingsHandler {
return &DNSSettingsHandler{
accountManager: &mock_server.MockAccountManager{
GetDNSSettingsFunc: func(accountID string, userID string) (*server.DNSSettings, error) {
GetDNSSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.DNSSettings, error) {
return &testingDNSSettingsAccount.DNSSettings, nil
},
SaveDNSSettingsFunc: func(accountID string, userID string, dnsSettingsToSave *server.DNSSettings) error {
SaveDNSSettingsFunc: func(ctx context.Context, accountID string, userID string, dnsSettingsToSave *server.DNSSettings) error {
if dnsSettingsToSave != nil {
return nil
}
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
},
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
GetAccountFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return testingDNSSettingsAccount, testingDNSSettingsAccount.Users[testDNSSettingsUserID], nil
},
},

View File

@ -1,6 +1,7 @@
package http
import (
"context"
"fmt"
"net/http"
@ -33,16 +34,16 @@ func NewEventsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ev
// GetAllEvents list of the given account
func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
log.Error(err)
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
accountEvents, err := h.accountManager.GetEvents(account.Id, user.Id)
accountEvents, err := h.accountManager.GetEvents(r.Context(), account.Id, user.Id)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
events := make([]*api.Event, len(accountEvents))
@ -50,20 +51,20 @@ func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
events[i] = toEventResponse(e)
}
err = h.fillEventsWithUserInfo(events, account.Id, user.Id)
err = h.fillEventsWithUserInfo(r.Context(), events, account.Id, user.Id)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(w, events)
util.WriteJSONObject(r.Context(), w, events)
}
func (h *EventsHandler) fillEventsWithUserInfo(events []*api.Event, accountId, userId string) error {
func (h *EventsHandler) fillEventsWithUserInfo(ctx context.Context, events []*api.Event, accountId, userId string) error {
// build email, name maps based on users
userInfos, err := h.accountManager.GetUsersFromAccount(accountId, userId)
userInfos, err := h.accountManager.GetUsersFromAccount(ctx, accountId, userId)
if err != nil {
log.Errorf("failed to get users from account: %s", err)
log.WithContext(ctx).Errorf("failed to get users from account: %s", err)
return err
}
@ -80,7 +81,7 @@ func (h *EventsHandler) fillEventsWithUserInfo(events []*api.Event, accountId, u
if event.InitiatorEmail == "" {
event.InitiatorEmail, ok = emails[event.InitiatorId]
if !ok {
log.Warnf("failed to resolve email for initiator: %s", event.InitiatorId)
log.WithContext(ctx).Warnf("failed to resolve email for initiator: %s", event.InitiatorId)
}
}

View File

@ -1,6 +1,7 @@
package http
import (
"context"
"encoding/json"
"io"
"net/http"
@ -22,13 +23,13 @@ import (
func initEventsTestData(account string, user *server.User, events ...*activity.Event) *EventsHandler {
return &EventsHandler{
accountManager: &mock_server.MockAccountManager{
GetEventsFunc: func(accountID, userID string) ([]*activity.Event, error) {
GetEventsFunc: func(_ context.Context, accountID, userID string) ([]*activity.Event, error) {
if accountID == account {
return events, nil
}
return []*activity.Event{}, nil
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",
@ -37,7 +38,7 @@ func initEventsTestData(account string, user *server.User, events ...*activity.E
},
}, user, nil
},
GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) {
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) {
return make([]*server.UserInfo, 0), nil
},
},

View File

@ -1,6 +1,7 @@
package http
import (
"context"
"encoding/json"
"io"
"net/http"
@ -35,13 +36,13 @@ func initGeolocationTestData(t *testing.T) *GeolocationsHandler {
err = util.CopyFileContents(geonamesDBPath, path.Join(tempDir, geolocation.GeoSqliteDBFile))
assert.NoError(t, err)
geo, err := geolocation.NewGeolocation(tempDir)
geo, err := geolocation.NewGeolocation(context.Background(), tempDir)
assert.NoError(t, err)
t.Cleanup(func() { _ = geo.Stop() })
return &GeolocationsHandler{
accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
return &server.Account{
Id: claims.AccountId,

View File

@ -40,19 +40,19 @@ func NewGeolocationsHandlerHandler(accountManager server.AccountManager, geoloca
// GetAllCountries retrieves a list of all countries
func (l *GeolocationsHandler) GetAllCountries(w http.ResponseWriter, r *http.Request) {
if err := l.authenticateUser(r); err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
if l.geolocationManager == nil {
// TODO: update error message to include geo db self hosted doc link when ready
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
return
}
allCountries, err := l.geolocationManager.GetAllCountries()
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
@ -60,32 +60,32 @@ func (l *GeolocationsHandler) GetAllCountries(w http.ResponseWriter, r *http.Req
for _, country := range allCountries {
countries = append(countries, toCountryResponse(country))
}
util.WriteJSONObject(w, countries)
util.WriteJSONObject(r.Context(), w, countries)
}
// GetCitiesByCountry retrieves a list of cities based on the given country code
func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http.Request) {
if err := l.authenticateUser(r); err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
countryCode := vars["country"]
if !countryCodeRegex.MatchString(countryCode) {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid country code"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid country code"), w)
return
}
if l.geolocationManager == nil {
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized. "+
util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized. "+
"Check the self-hosted Geo database documentation at https://docs.netbird.io/selfhosted/geo-support"), w)
return
}
allCities, err := l.geolocationManager.GetCitiesByCountry(countryCode)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
@ -93,12 +93,12 @@ func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http.
for _, city := range allCities {
cities = append(cities, toCityResponse(city))
}
util.WriteJSONObject(w, cities)
util.WriteJSONObject(r.Context(), w, cities)
}
func (l *GeolocationsHandler) authenticateUser(r *http.Request) error {
claims := l.claimsExtractor.FromRequestContext(r)
_, user, err := l.accountManager.GetAccountFromToken(claims)
_, user, err := l.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
return err
}

View File

@ -35,16 +35,16 @@ func NewGroupsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Gr
// GetAllGroups list for the account
func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
log.Error(err)
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
groups, err := h.accountManager.GetAllGroups(account.Id, user.Id)
groups, err := h.accountManager.GetAllGroups(r.Context(), account.Id, user.Id)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
@ -53,42 +53,42 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
groupsResponse = append(groupsResponse, toGroupResponse(account, group))
}
util.WriteJSONObject(w, groupsResponse)
util.WriteJSONObject(r.Context(), w, groupsResponse)
}
// UpdateGroup handles update to a group identified by a given ID
func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
groupID, ok := vars["groupId"]
if !ok {
util.WriteError(status.Errorf(status.InvalidArgument, "group ID field is missing"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "group ID field is missing"), w)
return
}
if len(groupID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "group ID can't be empty"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "group ID can't be empty"), w)
return
}
eg, ok := account.Groups[groupID]
if !ok {
util.WriteError(status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w)
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w)
return
}
allGroup, err := account.GetGroupAll()
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
if allGroup.ID == groupID {
util.WriteError(status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w)
return
}
@ -100,7 +100,7 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
}
if req.Name == "" {
util.WriteError(status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w)
return
}
@ -118,21 +118,21 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
IntegrationReference: eg.IntegrationReference,
}
if err := h.accountManager.SaveGroup(account.Id, user.Id, &group); err != nil {
log.Errorf("failed updating group %s under account %s %v", groupID, account.Id, err)
util.WriteError(err, w)
if err := h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group); err != nil {
log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, account.Id, err)
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(w, toGroupResponse(account, &group))
util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group))
}
// CreateGroup handles group creation request
func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
@ -144,7 +144,7 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
}
if req.Name == "" {
util.WriteError(status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w)
return
}
@ -160,62 +160,62 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
Issued: nbgroup.GroupIssuedAPI,
}
err = h.accountManager.SaveGroup(account.Id, user.Id, &group)
err = h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(w, toGroupResponse(account, &group))
util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group))
}
// DeleteGroup handles group deletion request
func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
aID := account.Id
groupID := mux.Vars(r)["groupId"]
if len(groupID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
return
}
allGroup, err := account.GetGroupAll()
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
if allGroup.ID == groupID {
util.WriteError(status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed"), w)
return
}
err = h.accountManager.DeleteGroup(aID, user.Id, groupID)
err = h.accountManager.DeleteGroup(r.Context(), aID, user.Id, groupID)
if err != nil {
_, ok := err.(*server.GroupLinkError)
if ok {
util.WriteErrorResponse(err.Error(), http.StatusBadRequest, w)
return
}
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(w, emptyObject{})
util.WriteJSONObject(r.Context(), w, emptyObject{})
}
// GetGroup returns a group
func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
@ -223,19 +223,19 @@ func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) {
case http.MethodGet:
groupID := mux.Vars(r)["groupId"]
if len(groupID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
return
}
group, err := h.accountManager.GetGroup(account.Id, groupID, user.Id)
group, err := h.accountManager.GetGroup(r.Context(), account.Id, groupID, user.Id)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(w, toGroupResponse(account, group))
util.WriteJSONObject(r.Context(), w, toGroupResponse(account, group))
default:
util.WriteError(status.Errorf(status.NotFound, "HTTP method not found"), w)
util.WriteError(r.Context(), status.Errorf(status.NotFound, "HTTP method not found"), w)
return
}
}

View File

@ -2,6 +2,7 @@ package http
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
@ -32,13 +33,13 @@ var TestPeers = map[string]*nbpeer.Peer{
func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler {
return &GroupsHandler{
accountManager: &mock_server.MockAccountManager{
SaveGroupFunc: func(accountID, userID string, group *nbgroup.Group) error {
SaveGroupFunc: func(_ context.Context, accountID, userID string, group *nbgroup.Group) error {
if !strings.HasPrefix(group.ID, "id-") {
group.ID = "id-was-set"
}
return nil
},
GetGroupFunc: func(_, groupID, _ string) (*nbgroup.Group, error) {
GetGroupFunc: func(_ context.Context, _, groupID, _ string) (*nbgroup.Group, error) {
if groupID != "idofthegroup" {
return nil, status.Errorf(status.NotFound, "not found")
}
@ -55,7 +56,7 @@ func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler {
Issued: nbgroup.GroupIssuedAPI,
}, nil
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",
@ -70,7 +71,7 @@ func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler {
},
}, user, nil
},
DeleteGroupFunc: func(accountID, userId, groupID string) error {
DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error {
if groupID == "linked-grp" {
return &server.GroupLinkError{
Resource: "something",

View File

@ -9,6 +9,7 @@ import (
"github.com/rs/cors"
"github.com/netbirdio/management-integrations/integrations"
s "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/middleware"
@ -57,6 +58,11 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
corsMiddleware := cors.AllowAll()
claimsExtractor = jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
)
acMiddleware := middleware.NewAccessControl(
authCfg.Audience,
authCfg.UserIDClaim,

View File

@ -1,6 +1,7 @@
package middleware
import (
"context"
"net/http"
"regexp"
@ -15,7 +16,7 @@ import (
)
// GetUser function defines a function to fetch user from Account by jwtclaims.AuthorizationClaims
type GetUser func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
type GetUser func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
// AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
type AccessControl struct {
@ -46,15 +47,15 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
claims := a.claimsExtract.FromRequestContext(r)
user, err := a.getUser(claims)
user, err := a.getUser(r.Context(), claims)
if err != nil {
log.Errorf("failed to get user from claims: %s", err)
util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w)
log.WithContext(r.Context()).Errorf("failed to get user from claims: %s", err)
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "invalid JWT"), w)
return
}
if user.IsBlocked() {
util.WriteError(status.Errorf(status.PermissionDenied, "the user has no access to the API or is blocked"), w)
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "the user has no access to the API or is blocked"), w)
return
}
@ -63,12 +64,12 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut:
if tokenPathRegexp.MatchString(r.URL.Path) {
log.Debugf("valid Path")
log.WithContext(r.Context()).Debugf("valid Path")
h.ServeHTTP(w, r)
return
}
util.WriteError(status.Errorf(status.PermissionDenied, "only users with admin power can perform this operation"), w)
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only users with admin power can perform this operation"), w)
return
}
}

View File

@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server"
nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
@ -19,16 +20,16 @@ import (
)
// GetAccountFromPATFunc function
type GetAccountFromPATFunc func(token string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
type GetAccountFromPATFunc func(ctx context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
// ValidateAndParseTokenFunc function
type ValidateAndParseTokenFunc func(token string) (*jwt.Token, error)
type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error)
// MarkPATUsedFunc function
type MarkPATUsedFunc func(token string) error
type MarkPATUsedFunc func(ctx context.Context, token string) error
// CheckUserAccessByJWTGroupsFunc function
type CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error
type CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
type AuthMiddleware struct {
@ -85,23 +86,27 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
case "bearer":
err := m.checkJWTFromRequest(w, r, auth)
if err != nil {
log.Errorf("Error when validating JWT claims: %s", err.Error())
util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w)
log.WithContext(r.Context()).Errorf("Error when validating JWT claims: %s", err.Error())
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
return
}
h.ServeHTTP(w, r)
case "token":
err := m.checkPATFromRequest(w, r, auth)
if err != nil {
log.Debugf("Error when validating PAT claims: %s", err.Error())
util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w)
log.WithContext(r.Context()).Debugf("Error when validating PAT claims: %s", err.Error())
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
return
}
h.ServeHTTP(w, r)
default:
util.WriteError(status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
return
}
claims := m.claimsExtractor.FromRequestContext(r)
//nolint
ctx := context.WithValue(r.Context(), nbContext.UserIDKey, claims.UserId)
//nolint
ctx = context.WithValue(ctx, nbContext.AccountIDKey, claims.AccountId)
h.ServeHTTP(w, r.WithContext(ctx))
})
}
@ -114,7 +119,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
return fmt.Errorf("Error extracting token: %w", err)
}
validatedToken, err := m.validateAndParseToken(token)
validatedToken, err := m.validateAndParseToken(r.Context(), token)
if err != nil {
return err
}
@ -123,7 +128,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
return nil
}
if err := m.verifyUserAccess(validatedToken); err != nil {
if err := m.verifyUserAccess(r.Context(), validatedToken); err != nil {
return err
}
@ -138,9 +143,9 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
// verifyUserAccess checks if a user, based on a validated JWT token,
// is allowed access, particularly in cases where the admin enabled JWT
// group propagation and designated certain groups with access permissions.
func (m *AuthMiddleware) verifyUserAccess(validatedToken *jwt.Token) error {
func (m *AuthMiddleware) verifyUserAccess(ctx context.Context, validatedToken *jwt.Token) error {
authClaims := m.claimsExtractor.FromToken(validatedToken)
return m.checkUserAccessByJWTGroups(authClaims)
return m.checkUserAccessByJWTGroups(ctx, authClaims)
}
// CheckPATFromRequest checks if the PAT is valid
@ -152,7 +157,7 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ
return fmt.Errorf("Error extracting token: %w", err)
}
account, user, pat, err := m.getAccountFromPAT(token)
account, user, pat, err := m.getAccountFromPAT(r.Context(), token)
if err != nil {
return fmt.Errorf("invalid Token: %w", err)
}
@ -160,7 +165,7 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ
return fmt.Errorf("token expired")
}
err = m.markPATUsed(pat.ID)
err = m.markPATUsed(r.Context(), pat.ID)
if err != nil {
return err
}

View File

@ -1,6 +1,7 @@
package middleware
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
@ -15,15 +16,16 @@ import (
)
const (
audience = "audience"
userIDClaim = "userIDClaim"
accountID = "accountID"
domain = "domain"
userID = "userID"
tokenID = "tokenID"
PAT = "nbp_PAT"
JWT = "JWT"
wrongToken = "wrongToken"
audience = "audience"
userIDClaim = "userIDClaim"
accountID = "accountID"
domain = "domain"
domainCategory = "domainCategory"
userID = "userID"
tokenID = "tokenID"
PAT = "nbp_PAT"
JWT = "JWT"
wrongToken = "wrongToken"
)
var testAccount = &server.Account{
@ -47,14 +49,14 @@ var testAccount = &server.Account{
},
}
func mockGetAccountFromPAT(token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
func mockGetAccountFromPAT(_ context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
if token == PAT {
return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil
}
return nil, nil, nil, fmt.Errorf("PAT invalid")
}
func mockValidateAndParseToken(token string) (*jwt.Token, error) {
func mockValidateAndParseToken(_ context.Context, token string) (*jwt.Token, error) {
if token == JWT {
return &jwt.Token{
Claims: jwt.MapClaims{
@ -67,14 +69,14 @@ func mockValidateAndParseToken(token string) (*jwt.Token, error) {
return nil, fmt.Errorf("JWT invalid")
}
func mockMarkPATUsed(token string) error {
func mockMarkPATUsed(_ context.Context, token string) error {
if token == tokenID {
return nil
}
return fmt.Errorf("Should never get reached")
}
func mockCheckUserAccessByJWTGroups(claims jwtclaims.AuthorizationClaims) error {
func mockCheckUserAccessByJWTGroups(_ context.Context, claims jwtclaims.AuthorizationClaims) error {
if testAccount.Id != claims.AccountId {
return fmt.Errorf("account with id %s does not exist", claims.AccountId)
}

View File

@ -56,7 +56,7 @@ func ShouldBypass(requestPath string, h http.Handler, w http.ResponseWriter, r *
for bypassPath := range bypassPaths {
matched, err := path.Match(bypassPath, requestPath)
if err != nil {
log.Errorf("Error matching path %s with %s from %s: %v", bypassPath, requestPath, GetList(), err)
log.WithContext(r.Context()).Errorf("Error matching path %s with %s from %s: %v", bypassPath, requestPath, GetList(), err)
continue
}
if matched {

View File

@ -36,16 +36,16 @@ func NewNameserversHandler(accountManager server.AccountManager, authCfg AuthCfg
// GetAllNameservers returns the list of nameserver groups for the account
func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
log.Error(err)
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
nsGroups, err := h.accountManager.ListNameServerGroups(account.Id, user.Id)
nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), account.Id, user.Id)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
@ -54,15 +54,15 @@ func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Re
apiNameservers = append(apiNameservers, toNameserverGroupResponse(r))
}
util.WriteJSONObject(w, apiNameservers)
util.WriteJSONObject(r.Context(), w, apiNameservers)
}
// CreateNameserverGroup handles nameserver group creation request
func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
@ -75,33 +75,33 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt
nsList, err := toServerNSList(req.Nameservers)
if err != nil {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid NS servers format"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid NS servers format"), w)
return
}
nsGroup, err := h.accountManager.CreateNameServerGroup(account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, user.Id, req.SearchDomainsEnabled)
nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, user.Id, req.SearchDomainsEnabled)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
resp := toNameserverGroupResponse(nsGroup)
util.WriteJSONObject(w, &resp)
util.WriteJSONObject(r.Context(), w, &resp)
}
// UpdateNameserverGroup handles update to a nameserver group identified by a given ID
func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
nsGroupID := mux.Vars(r)["nsgroupId"]
if len(nsGroupID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
return
}
@ -114,7 +114,7 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt
nsList, err := toServerNSList(req.Nameservers)
if err != nil {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid NS servers format"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid NS servers format"), w)
return
}
@ -130,66 +130,66 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt
SearchDomainsEnabled: req.SearchDomainsEnabled,
}
err = h.accountManager.SaveNameServerGroup(account.Id, user.Id, updatedNSGroup)
err = h.accountManager.SaveNameServerGroup(r.Context(), account.Id, user.Id, updatedNSGroup)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
resp := toNameserverGroupResponse(updatedNSGroup)
util.WriteJSONObject(w, &resp)
util.WriteJSONObject(r.Context(), w, &resp)
}
// DeleteNameserverGroup handles nameserver group deletion request
func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
nsGroupID := mux.Vars(r)["nsgroupId"]
if len(nsGroupID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
return
}
err = h.accountManager.DeleteNameServerGroup(account.Id, nsGroupID, user.Id)
err = h.accountManager.DeleteNameServerGroup(r.Context(), account.Id, nsGroupID, user.Id)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(w, emptyObject{})
util.WriteJSONObject(r.Context(), w, emptyObject{})
}
// GetNameserverGroup handles a nameserver group Get request identified by ID
func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
log.Error(err)
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
nsGroupID := mux.Vars(r)["nsgroupId"]
if len(nsGroupID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
return
}
nsGroup, err := h.accountManager.GetNameServerGroup(account.Id, user.Id, nsGroupID)
nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), account.Id, user.Id, nsGroupID)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
resp := toNameserverGroupResponse(nsGroup)
util.WriteJSONObject(w, &resp)
util.WriteJSONObject(r.Context(), w, &resp)
}
func toServerNSList(apiNSList []api.Nameserver) ([]nbdns.NameServer, error) {

View File

@ -2,6 +2,7 @@ package http
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
@ -61,13 +62,13 @@ var baseExistingNSGroup = &nbdns.NameServerGroup{
func initNameserversTestData() *NameserversHandler {
return &NameserversHandler{
accountManager: &mock_server.MockAccountManager{
GetNameServerGroupFunc: func(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
GetNameServerGroupFunc: func(_ context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
if nsGroupID == existingNSGroupID {
return baseExistingNSGroup.Copy(), nil
}
return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID)
},
CreateNameServerGroupFunc: func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, _ string, searchDomains bool) (*nbdns.NameServerGroup, error) {
CreateNameServerGroupFunc: func(_ context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, _ string, searchDomains bool) (*nbdns.NameServerGroup, error) {
return &nbdns.NameServerGroup{
ID: existingNSGroupID,
Name: name,
@ -80,16 +81,16 @@ func initNameserversTestData() *NameserversHandler {
SearchDomainsEnabled: searchDomains,
}, nil
},
DeleteNameServerGroupFunc: func(accountID, nsGroupID, _ string) error {
DeleteNameServerGroupFunc: func(_ context.Context, accountID, nsGroupID, _ string) error {
return nil
},
SaveNameServerGroupFunc: func(accountID, _ string, nsGroupToSave *nbdns.NameServerGroup) error {
SaveNameServerGroupFunc: func(_ context.Context, accountID, _ string, nsGroupToSave *nbdns.NameServerGroup) error {
if nsGroupToSave.ID == existingNSGroupID {
return nil
}
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID)
},
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return testingNSAccount, testingAccount.Users["test_user"], nil
},
},

View File

@ -34,22 +34,22 @@ func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATH
// GetAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user
func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
userID := vars["userId"]
if len(userID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}
pats, err := h.accountManager.GetAllPATs(account.Id, user.Id, userID)
pats, err := h.accountManager.GetAllPATs(r.Context(), account.Id, user.Id, userID)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
@ -58,53 +58,53 @@ func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) {
patResponse = append(patResponse, toPATResponse(pat))
}
util.WriteJSONObject(w, patResponse)
util.WriteJSONObject(r.Context(), w, patResponse)
}
// GetToken is HTTP GET handler that returns a personal access token for the given user
func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}
tokenID := vars["tokenId"]
if len(tokenID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid token ID"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid token ID"), w)
return
}
pat, err := h.accountManager.GetPAT(account.Id, user.Id, targetUserID, tokenID)
pat, err := h.accountManager.GetPAT(r.Context(), account.Id, user.Id, targetUserID, tokenID)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(w, toPATResponse(pat))
util.WriteJSONObject(r.Context(), w, toPATResponse(pat))
}
// CreateToken is HTTP POST handler that creates a personal access token for the given user
func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}
@ -115,44 +115,44 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
return
}
pat, err := h.accountManager.CreatePAT(account.Id, user.Id, targetUserID, req.Name, req.ExpiresIn)
pat, err := h.accountManager.CreatePAT(r.Context(), account.Id, user.Id, targetUserID, req.Name, req.ExpiresIn)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(w, toPATGeneratedResponse(pat))
util.WriteJSONObject(r.Context(), w, toPATGeneratedResponse(pat))
}
// DeleteToken is HTTP DELETE handler that deletes a personal access token for the given user
func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}
tokenID := vars["tokenId"]
if len(tokenID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid token ID"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid token ID"), w)
return
}
err = h.accountManager.DeletePAT(account.Id, user.Id, targetUserID, tokenID)
err = h.accountManager.DeletePAT(r.Context(), account.Id, user.Id, targetUserID, tokenID)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(w, emptyObject{})
util.WriteJSONObject(r.Context(), w, emptyObject{})
}
func toPATResponse(pat *server.PersonalAccessToken) *api.PersonalAccessToken {

View File

@ -2,6 +2,7 @@ package http
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
@ -63,7 +64,7 @@ var testAccount = &server.Account{
func initPATTestData() *PATHandler {
return &PATHandler{
accountManager: &mock_server.MockAccountManager{
CreatePATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
CreatePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
if accountID != existingAccountID {
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
}
@ -76,10 +77,10 @@ func initPATTestData() *PATHandler {
}, nil
},
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return testAccount, testAccount.Users[existingUserID], nil
},
DeletePATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
DeletePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
if accountID != existingAccountID {
return status.Errorf(status.NotFound, "account with ID %s not found", accountID)
}
@ -91,7 +92,7 @@ func initPATTestData() *PATHandler {
}
return nil
},
GetPATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
GetPATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
if accountID != existingAccountID {
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
}
@ -103,7 +104,7 @@ func initPATTestData() *PATHandler {
}
return testAccount.Users[existingUserID].PATs[existingTokenID], nil
},
GetAllPATsFunc: func(accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
GetAllPATsFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
if accountID != existingAccountID {
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
}

View File

@ -1,6 +1,7 @@
package http
import (
"context"
"encoding/json"
"fmt"
"net/http"
@ -47,16 +48,16 @@ func (h *PeersHandler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error)
return peerToReturn, nil
}
func (h *PeersHandler) getPeer(account *server.Account, peerID, userID string, w http.ResponseWriter) {
peer, err := h.accountManager.GetPeer(account.Id, peerID, userID)
func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, peerID, userID string, w http.ResponseWriter) {
peer, err := h.accountManager.GetPeer(ctx, account.Id, peerID, userID)
if err != nil {
util.WriteError(err, w)
util.WriteError(ctx, err, w)
return
}
peerToReturn, err := h.checkPeerStatus(peer)
if err != nil {
util.WriteError(err, w)
util.WriteError(ctx, err, w)
return
}
dnsDomain := h.accountManager.GetDNSDomain()
@ -65,19 +66,19 @@ func (h *PeersHandler) getPeer(account *server.Account, peerID, userID string, w
validPeers, err := h.accountManager.GetValidatedPeers(account)
if err != nil {
log.Errorf("failed to list appreoved peers: %v", err)
util.WriteError(fmt.Errorf("internal error"), w)
log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err)
util.WriteError(ctx, fmt.Errorf("internal error"), w)
return
}
netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain(), validPeers)
netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validPeers)
accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
_, valid := validPeers[peer.ID]
util.WriteJSONObject(w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, accessiblePeers, valid))
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, accessiblePeers, valid))
}
func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) {
func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) {
req := &api.PeerRequest{}
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
@ -99,9 +100,9 @@ func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, pe
}
}
peer, err := h.accountManager.UpdatePeer(account.Id, user.Id, update)
peer, err := h.accountManager.UpdatePeer(ctx, account.Id, user.Id, update)
if err != nil {
util.WriteError(err, w)
util.WriteError(ctx, err, w)
return
}
dnsDomain := h.accountManager.GetDNSDomain()
@ -110,75 +111,75 @@ func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, pe
validPeers, err := h.accountManager.GetValidatedPeers(account)
if err != nil {
log.Errorf("failed to list appreoved peers: %v", err)
util.WriteError(fmt.Errorf("internal error"), w)
log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err)
util.WriteError(ctx, fmt.Errorf("internal error"), w)
return
}
netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain(), validPeers)
netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validPeers)
accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
_, valid := validPeers[peer.ID]
util.WriteJSONObject(w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, accessiblePeers, valid))
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, accessiblePeers, valid))
}
func (h *PeersHandler) deletePeer(accountID, userID string, peerID string, w http.ResponseWriter) {
err := h.accountManager.DeletePeer(accountID, peerID, userID)
func (h *PeersHandler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) {
err := h.accountManager.DeletePeer(ctx, accountID, peerID, userID)
if err != nil {
log.Errorf("failed to delete peer: %v", err)
util.WriteError(err, w)
log.WithContext(ctx).Errorf("failed to delete peer: %v", err)
util.WriteError(ctx, err, w)
return
}
util.WriteJSONObject(w, emptyObject{})
util.WriteJSONObject(ctx, w, emptyObject{})
}
// HandlePeer handles all peer requests for GET, PUT and DELETE operations
func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
peerID := vars["peerId"]
if len(peerID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
return
}
switch r.Method {
case http.MethodDelete:
h.deletePeer(account.Id, user.Id, peerID, w)
h.deletePeer(r.Context(), account.Id, user.Id, peerID, w)
return
case http.MethodPut:
h.updatePeer(account, user, peerID, w, r)
h.updatePeer(r.Context(), account, user, peerID, w, r)
return
case http.MethodGet:
h.getPeer(account, peerID, user.Id, w)
h.getPeer(r.Context(), account, peerID, user.Id, w)
return
default:
util.WriteError(status.Errorf(status.NotFound, "unknown METHOD"), w)
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
}
}
// GetAllPeers returns a list of all peers associated with a provided account
func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
util.WriteError(status.Errorf(status.NotFound, "unknown METHOD"), w)
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
return
}
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
peers, err := h.accountManager.GetPeers(account.Id, user.Id)
peers, err := h.accountManager.GetPeers(r.Context(), account.Id, user.Id)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
@ -188,34 +189,34 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
for _, peer := range peers {
peerToReturn, err := h.checkPeerStatus(peer)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
accessiblePeerNumbers, _ := h.accessiblePeersNumber(account, peer.ID)
accessiblePeerNumbers, _ := h.accessiblePeersNumber(r.Context(), account, peer.ID)
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, accessiblePeerNumbers))
}
validPeersMap, err := h.accountManager.GetValidatedPeers(account)
if err != nil {
log.Errorf("failed to list appreoved peers: %v", err)
util.WriteError(fmt.Errorf("internal error"), w)
log.WithContext(r.Context()).Errorf("failed to list appreoved peers: %v", err)
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
return
}
h.setApprovalRequiredFlag(respBody, validPeersMap)
util.WriteJSONObject(w, respBody)
util.WriteJSONObject(r.Context(), w, respBody)
}
func (h *PeersHandler) accessiblePeersNumber(account *server.Account, peerID string) (int, error) {
func (h *PeersHandler) accessiblePeersNumber(ctx context.Context, account *server.Account, peerID string) (int, error) {
validatedPeersMap, err := h.accountManager.GetValidatedPeers(account)
if err != nil {
return 0, err
}
netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain(), validatedPeersMap)
netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validatedPeersMap)
return len(netMap.Peers) + len(netMap.OfflinePeers), nil
}

View File

@ -2,6 +2,7 @@ package http
import (
"bytes"
"context"
"encoding/json"
"io"
"net"
@ -29,7 +30,7 @@ const noUpdateChannelTestPeerID = "no-update-channel"
func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
return &PeersHandler{
accountManager: &mock_server.MockAccountManager{
UpdatePeerFunc: func(accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
var p *nbpeer.Peer
for _, peer := range peers {
if update.ID == peer.ID {
@ -42,7 +43,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
p.Name = update.Name
return p, nil
},
GetPeerFunc: func(accountID, peerID, userID string) (*nbpeer.Peer, error) {
GetPeerFunc: func(_ context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
var p *nbpeer.Peer
for _, peer := range peers {
if peerID == peer.ID {
@ -52,13 +53,13 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
}
return p, nil
},
GetPeersFunc: func(accountID, userID string) ([]*nbpeer.Peer, error) {
GetPeersFunc: func(_ context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
return peers, nil
},
GetDNSDomainFunc: func() string {
return "netbird.selfhosted"
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
return &server.Account{
Id: claims.AccountId,

View File

@ -35,15 +35,15 @@ func NewPoliciesHandler(accountManager server.AccountManager, authCfg AuthCfg) *
// GetAllPolicies list for the account
func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
accountPolicies, err := h.accountManager.ListPolicies(account.Id, user.Id)
accountPolicies, err := h.accountManager.ListPolicies(r.Context(), account.Id, user.Id)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
@ -51,28 +51,28 @@ func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) {
for _, policy := range accountPolicies {
resp := toPolicyResponse(account, policy)
if len(resp.Rules) == 0 {
util.WriteError(status.Errorf(status.Internal, "no rules in the policy"), w)
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
return
}
policies = append(policies, resp)
}
util.WriteJSONObject(w, policies)
util.WriteJSONObject(r.Context(), w, policies)
}
// UpdatePolicy handles update to a policy identified by a given ID
func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
policyID := vars["policyId"]
if len(policyID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
return
}
@ -84,7 +84,7 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
}
}
if policyIdx < 0 {
util.WriteError(status.Errorf(status.NotFound, "couldn't find policy id %s", policyID), w)
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find policy id %s", policyID), w)
return
}
@ -94,9 +94,9 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
// CreatePolicy handles policy creation request
func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
@ -118,12 +118,12 @@ func (h *Policies) savePolicy(
}
if req.Name == "" {
util.WriteError(status.Errorf(status.InvalidArgument, "policy name shouldn't be empty"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "policy name shouldn't be empty"), w)
return
}
if len(req.Rules) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "policy rules shouldn't be empty"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "policy rules shouldn't be empty"), w)
return
}
@ -137,31 +137,31 @@ func (h *Policies) savePolicy(
Enabled: req.Enabled,
Description: req.Description,
}
for _, r := range req.Rules {
for _, rule := range req.Rules {
pr := server.PolicyRule{
ID: policyID, //TODO: when policy can contain multiple rules, need refactor
Name: r.Name,
Destinations: groupMinimumsToStrings(account, r.Destinations),
Sources: groupMinimumsToStrings(account, r.Sources),
Bidirectional: r.Bidirectional,
ID: policyID, // TODO: when policy can contain multiple rules, need refactor
Name: rule.Name,
Destinations: groupMinimumsToStrings(account, rule.Destinations),
Sources: groupMinimumsToStrings(account, rule.Sources),
Bidirectional: rule.Bidirectional,
}
pr.Enabled = r.Enabled
if r.Description != nil {
pr.Description = *r.Description
pr.Enabled = rule.Enabled
if rule.Description != nil {
pr.Description = *rule.Description
}
switch r.Action {
switch rule.Action {
case api.PolicyRuleUpdateActionAccept:
pr.Action = server.PolicyTrafficActionAccept
case api.PolicyRuleUpdateActionDrop:
pr.Action = server.PolicyTrafficActionDrop
default:
util.WriteError(status.Errorf(status.InvalidArgument, "unknown action type"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown action type"), w)
return
}
switch r.Protocol {
switch rule.Protocol {
case api.PolicyRuleUpdateProtocolAll:
pr.Protocol = server.PolicyRuleProtocolALL
case api.PolicyRuleUpdateProtocolTcp:
@ -171,14 +171,14 @@ func (h *Policies) savePolicy(
case api.PolicyRuleUpdateProtocolIcmp:
pr.Protocol = server.PolicyRuleProtocolICMP
default:
util.WriteError(status.Errorf(status.InvalidArgument, "unknown protocol type: %v", r.Protocol), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown protocol type: %v", rule.Protocol), w)
return
}
if r.Ports != nil && len(*r.Ports) != 0 {
for _, v := range *r.Ports {
if rule.Ports != nil && len(*rule.Ports) != 0 {
for _, v := range *rule.Ports {
if port, err := strconv.Atoi(v); err != nil || port < 1 || port > 65535 {
util.WriteError(status.Errorf(status.InvalidArgument, "valid port value is in 1..65535 range"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "valid port value is in 1..65535 range"), w)
return
}
pr.Ports = append(pr.Ports, v)
@ -189,16 +189,16 @@ func (h *Policies) savePolicy(
switch pr.Protocol {
case server.PolicyRuleProtocolALL, server.PolicyRuleProtocolICMP:
if len(pr.Ports) != 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol ports is not allowed"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol ports is not allowed"), w)
return
}
if !pr.Bidirectional {
util.WriteError(status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w)
return
}
case server.PolicyRuleProtocolTCP, server.PolicyRuleProtocolUDP:
if !pr.Bidirectional && len(pr.Ports) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w)
return
}
}
@ -210,26 +210,26 @@ func (h *Policies) savePolicy(
policy.SourcePostureChecks = sourcePostureChecksToStrings(account, *req.SourcePostureChecks)
}
if err := h.accountManager.SavePolicy(account.Id, user.Id, &policy); err != nil {
util.WriteError(err, w)
if err := h.accountManager.SavePolicy(r.Context(), account.Id, user.Id, &policy); err != nil {
util.WriteError(r.Context(), err, w)
return
}
resp := toPolicyResponse(account, &policy)
if len(resp.Rules) == 0 {
util.WriteError(status.Errorf(status.Internal, "no rules in the policy"), w)
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
return
}
util.WriteJSONObject(w, resp)
util.WriteJSONObject(r.Context(), w, resp)
}
// DeletePolicy handles policy deletion request
func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
aID := account.Id
@ -237,24 +237,24 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
policyID := vars["policyId"]
if len(policyID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
return
}
if err = h.accountManager.DeletePolicy(aID, policyID, user.Id); err != nil {
util.WriteError(err, w)
if err = h.accountManager.DeletePolicy(r.Context(), aID, policyID, user.Id); err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(w, emptyObject{})
util.WriteJSONObject(r.Context(), w, emptyObject{})
}
// GetPolicy handles a group Get request identified by ID
func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
@ -263,25 +263,25 @@ func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
policyID := vars["policyId"]
if len(policyID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
return
}
policy, err := h.accountManager.GetPolicy(account.Id, policyID, user.Id)
policy, err := h.accountManager.GetPolicy(r.Context(), account.Id, policyID, user.Id)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
resp := toPolicyResponse(account, policy)
if len(resp.Rules) == 0 {
util.WriteError(status.Errorf(status.Internal, "no rules in the policy"), w)
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
return
}
util.WriteJSONObject(w, resp)
util.WriteJSONObject(r.Context(), w, resp)
default:
util.WriteError(status.Errorf(status.NotFound, "method not found"), w)
util.WriteError(r.Context(), status.Errorf(status.NotFound, "method not found"), w)
}
}

View File

@ -2,6 +2,7 @@ package http
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
@ -30,21 +31,21 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
}
return &Policies{
accountManager: &mock_server.MockAccountManager{
GetPolicyFunc: func(_, policyID, _ string) (*server.Policy, error) {
GetPolicyFunc: func(_ context.Context, _, policyID, _ string) (*server.Policy, error) {
policy, ok := testPolicies[policyID]
if !ok {
return nil, status.Errorf(status.NotFound, "policy not found")
}
return policy, nil
},
SavePolicyFunc: func(_, _ string, policy *server.Policy) error {
SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) error {
if !strings.HasPrefix(policy.ID, "id-") {
policy.ID = "id-was-set"
policy.Rules[0].ID = "id-was-set"
}
return nil
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
return &server.Account{
Id: claims.AccountId,

View File

@ -37,15 +37,15 @@ func NewPostureChecksHandler(accountManager server.AccountManager, geolocationMa
// GetAllPostureChecks list for the account
func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r)
account, user, err := p.accountManager.GetAccountFromToken(claims)
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
accountPostureChecks, err := p.accountManager.ListPostureChecks(account.Id, user.Id)
accountPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), account.Id, user.Id)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
@ -54,22 +54,22 @@ func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *htt
postureChecks = append(postureChecks, postureCheck.ToAPIResponse())
}
util.WriteJSONObject(w, postureChecks)
util.WriteJSONObject(r.Context(), w, postureChecks)
}
// UpdatePostureCheck handles update to a posture check identified by a given ID
func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r)
account, user, err := p.accountManager.GetAccountFromToken(claims)
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
postureChecksID := vars["postureCheckId"]
if len(postureChecksID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w)
return
}
@ -81,7 +81,7 @@ func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http
}
}
if postureChecksIdx < 0 {
util.WriteError(status.Errorf(status.NotFound, "couldn't find posture checks id %s", postureChecksID), w)
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find posture checks id %s", postureChecksID), w)
return
}
@ -91,9 +91,9 @@ func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http
// CreatePostureCheck handles posture check creation request
func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r)
account, user, err := p.accountManager.GetAccountFromToken(claims)
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
@ -103,50 +103,50 @@ func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http
// GetPostureCheck handles a posture check Get request identified by ID
func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r)
account, user, err := p.accountManager.GetAccountFromToken(claims)
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
postureChecksID := vars["postureCheckId"]
if len(postureChecksID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w)
return
}
postureChecks, err := p.accountManager.GetPostureChecks(account.Id, postureChecksID, user.Id)
postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), account.Id, postureChecksID, user.Id)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(w, postureChecks.ToAPIResponse())
util.WriteJSONObject(r.Context(), w, postureChecks.ToAPIResponse())
}
// DeletePostureCheck handles posture check deletion request
func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r)
account, user, err := p.accountManager.GetAccountFromToken(claims)
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
postureChecksID := vars["postureCheckId"]
if len(postureChecksID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w)
return
}
if err = p.accountManager.DeletePostureChecks(account.Id, postureChecksID, user.Id); err != nil {
util.WriteError(err, w)
if err = p.accountManager.DeletePostureChecks(r.Context(), account.Id, postureChecksID, user.Id); err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(w, emptyObject{})
util.WriteJSONObject(r.Context(), w, emptyObject{})
}
// savePostureChecks handles posture checks create and update
@ -169,7 +169,7 @@ func (p *PostureChecksHandler) savePostureChecks(
if geoLocationCheck := req.Checks.GeoLocationCheck; geoLocationCheck != nil {
if p.geolocationManager == nil {
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized. "+
util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized. "+
"Check the self-hosted Geo database documentation at https://docs.netbird.io/selfhosted/geo-support"), w)
return
}
@ -177,14 +177,14 @@ func (p *PostureChecksHandler) savePostureChecks(
postureChecks, err := posture.NewChecksFromAPIPostureCheckUpdate(req, postureChecksID)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
if err := p.accountManager.SavePostureChecks(account.Id, user.Id, postureChecks); err != nil {
util.WriteError(err, w)
if err := p.accountManager.SavePostureChecks(r.Context(), account.Id, user.Id, postureChecks); err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(w, postureChecks.ToAPIResponse())
util.WriteJSONObject(r.Context(), w, postureChecks.ToAPIResponse())
}

View File

@ -2,6 +2,7 @@ package http
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
@ -33,14 +34,14 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
return &PostureChecksHandler{
accountManager: &mock_server.MockAccountManager{
GetPostureChecksFunc: func(accountID, postureChecksID, userID string) (*posture.Checks, error) {
GetPostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
p, ok := testPostureChecks[postureChecksID]
if !ok {
return nil, status.Errorf(status.NotFound, "posture checks not found")
}
return p, nil
},
SavePostureChecksFunc: func(accountID, userID string, postureChecks *posture.Checks) error {
SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) error {
postureChecks.ID = "postureCheck"
testPostureChecks[postureChecks.ID] = postureChecks
@ -50,7 +51,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
return nil
},
DeletePostureChecksFunc: func(accountID, postureChecksID, userID string) error {
DeletePostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) error {
_, ok := testPostureChecks[postureChecksID]
if !ok {
return status.Errorf(status.NotFound, "posture checks not found")
@ -59,14 +60,14 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
return nil
},
ListPostureChecksFunc: func(accountID, userID string) ([]*posture.Checks, error) {
ListPostureChecksFunc: func(_ context.Context, accountID, userID string) ([]*posture.Checks, error) {
accountPostureChecks := make([]*posture.Checks, len(testPostureChecks))
for _, p := range testPostureChecks {
accountPostureChecks = append(accountPostureChecks, p)
}
return accountPostureChecks, nil
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
return &server.Account{
Id: claims.AccountId,

View File

@ -43,36 +43,36 @@ func NewRoutesHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ro
// GetAllRoutes returns the list of routes for the account
func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
routes, err := h.accountManager.ListRoutes(account.Id, user.Id)
routes, err := h.accountManager.ListRoutes(r.Context(), account.Id, user.Id)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
apiRoutes := make([]*api.Route, 0)
for _, r := range routes {
route, err := toRouteResponse(r)
for _, route := range routes {
route, err := toRouteResponse(route)
if err != nil {
util.WriteError(status.Errorf(status.Internal, failedToConvertRoute, err), w)
util.WriteError(r.Context(), status.Errorf(status.Internal, failedToConvertRoute, err), w)
return
}
apiRoutes = append(apiRoutes, route)
}
util.WriteJSONObject(w, apiRoutes)
util.WriteJSONObject(r.Context(), w, apiRoutes)
}
// CreateRoute handles route creation request
func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
@ -84,7 +84,7 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
}
if err := h.validateRoute(req); err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
@ -94,7 +94,7 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
if req.Domains != nil {
d, err := validateDomains(*req.Domains)
if err != nil {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w)
return
}
domains = d
@ -102,7 +102,7 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
} else if req.Network != nil {
networkType, newPrefix, err = route.ParseNetwork(*req.Network)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
}
@ -120,24 +120,24 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
// Do not allow non-Linux peers
if peer := account.GetPeer(peerId); peer != nil {
if peer.Meta.GoOS != "linux" {
util.WriteError(status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes"), w)
return
}
}
newRoute, err := h.accountManager.CreateRoute(account.Id, newPrefix, networkType, domains, peerId, peerGroupIds, req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id, req.KeepRoute)
newRoute, err := h.accountManager.CreateRoute(r.Context(), account.Id, newPrefix, networkType, domains, peerId, peerGroupIds, req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id, req.KeepRoute)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
routes, err := toRouteResponse(newRoute)
if err != nil {
util.WriteError(status.Errorf(status.Internal, failedToConvertRoute, err), w)
util.WriteError(r.Context(), status.Errorf(status.Internal, failedToConvertRoute, err), w)
return
}
util.WriteJSONObject(w, routes)
util.WriteJSONObject(r.Context(), w, routes)
}
func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) error {
@ -168,22 +168,22 @@ func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) erro
// UpdateRoute handles update to a route identified by a given ID
func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
routeID := vars["routeId"]
if len(routeID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w)
return
}
_, err = h.accountManager.GetRoute(account.Id, route.ID(routeID), user.Id)
_, err = h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
@ -195,7 +195,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
}
if err := h.validateRoute(req); err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
@ -207,7 +207,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
// do not allow non Linux peers
if peer := account.GetPeer(peerID); peer != nil {
if peer.Meta.GoOS != "linux" {
util.WriteError(status.Errorf(status.InvalidArgument, "non-linux peers are non supported as network routes"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "non-linux peers are non supported as network routes"), w)
return
}
}
@ -226,7 +226,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
if req.Domains != nil {
d, err := validateDomains(*req.Domains)
if err != nil {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w)
return
}
newRoute.Domains = d
@ -234,7 +234,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
} else if req.Network != nil {
newRoute.NetworkType, newRoute.Network, err = route.ParseNetwork(*req.Network)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
}
@ -247,73 +247,73 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
newRoute.PeerGroups = *req.PeerGroups
}
err = h.accountManager.SaveRoute(account.Id, user.Id, newRoute)
err = h.accountManager.SaveRoute(r.Context(), account.Id, user.Id, newRoute)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
routes, err := toRouteResponse(newRoute)
if err != nil {
util.WriteError(status.Errorf(status.Internal, failedToConvertRoute, err), w)
util.WriteError(r.Context(), status.Errorf(status.Internal, failedToConvertRoute, err), w)
return
}
util.WriteJSONObject(w, routes)
util.WriteJSONObject(r.Context(), w, routes)
}
// DeleteRoute handles route deletion request
func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
routeID := mux.Vars(r)["routeId"]
if len(routeID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w)
return
}
err = h.accountManager.DeleteRoute(account.Id, route.ID(routeID), user.Id)
err = h.accountManager.DeleteRoute(r.Context(), account.Id, route.ID(routeID), user.Id)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(w, emptyObject{})
util.WriteJSONObject(r.Context(), w, emptyObject{})
}
// GetRoute handles a route Get request identified by ID
func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
routeID := mux.Vars(r)["routeId"]
if len(routeID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w)
return
}
foundRoute, err := h.accountManager.GetRoute(account.Id, route.ID(routeID), user.Id)
foundRoute, err := h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id)
if err != nil {
util.WriteError(status.Errorf(status.NotFound, "route not found"), w)
util.WriteError(r.Context(), status.Errorf(status.NotFound, "route not found"), w)
return
}
routes, err := toRouteResponse(foundRoute)
if err != nil {
util.WriteError(status.Errorf(status.Internal, failedToConvertRoute, err), w)
util.WriteError(r.Context(), status.Errorf(status.Internal, failedToConvertRoute, err), w)
return
}
util.WriteJSONObject(w, routes)
util.WriteJSONObject(r.Context(), w, routes)
}
func toRouteResponse(serverRoute *route.Route) (*api.Route, error) {

View File

@ -2,6 +2,7 @@ package http
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
@ -89,7 +90,7 @@ var testingAccount = &server.Account{
func initRoutesTestData() *RoutesHandler {
return &RoutesHandler{
accountManager: &mock_server.MockAccountManager{
GetRouteFunc: func(_ string, routeID route.ID, _ string) (*route.Route, error) {
GetRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) (*route.Route, error) {
if routeID == existingRouteID {
return baseExistingRoute, nil
}
@ -104,7 +105,7 @@ func initRoutesTestData() *RoutesHandler {
}
return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
},
CreateRouteFunc: func(accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, _ string, keepRoute bool) (*route.Route, error) {
CreateRouteFunc: func(_ context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, _ string, keepRoute bool) (*route.Route, error) {
if peerID == notFoundPeerID {
return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
}
@ -126,19 +127,19 @@ func initRoutesTestData() *RoutesHandler {
KeepRoute: keepRoute,
}, nil
},
SaveRouteFunc: func(_, _ string, r *route.Route) error {
SaveRouteFunc: func(_ context.Context, _, _ string, r *route.Route) error {
if r.Peer == notFoundPeerID {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", r.Peer)
}
return nil
},
DeleteRouteFunc: func(_ string, routeID route.ID, _ string) error {
DeleteRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) error {
if routeID != existingRouteID {
return status.Errorf(status.NotFound, "Peer with ID %s not found", routeID)
}
return nil
},
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return testingAccount, testingAccount.Users["test_user"], nil
},
},

View File

@ -1,6 +1,7 @@
package http
import (
"context"
"encoding/json"
"net/http"
"time"
@ -34,9 +35,9 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authCfg AuthCfg)
// CreateSetupKey is a POST requests that creates a new SetupKey
func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
@ -48,13 +49,13 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
}
if req.Name == "" {
util.WriteError(status.Errorf(status.InvalidArgument, "setup key name shouldn't be empty"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key name shouldn't be empty"), w)
return
}
if !(server.SetupKeyType(req.Type) == server.SetupKeyReusable ||
server.SetupKeyType(req.Type) == server.SetupKeyOneOff) {
util.WriteError(status.Errorf(status.InvalidArgument, "unknown setup key type %s", req.Type), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown setup key type %s", req.Type), w)
return
}
@ -63,7 +64,7 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
day := time.Hour * 24
year := day * 365
if expiresIn < day || expiresIn > year {
util.WriteError(status.Errorf(status.InvalidArgument, "expiresIn should be between 1 day and 365 days"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "expiresIn should be between 1 day and 365 days"), w)
return
}
@ -75,54 +76,54 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
if req.Ephemeral != nil {
ephemeral = *req.Ephemeral
}
setupKey, err := h.accountManager.CreateSetupKey(account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn,
setupKey, err := h.accountManager.CreateSetupKey(r.Context(), account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn,
req.AutoGroups, req.UsageLimit, user.Id, ephemeral)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
writeSuccess(w, setupKey)
writeSuccess(r.Context(), w, setupKey)
}
// GetSetupKey is a GET request to get a SetupKey by ID
func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
keyID := vars["keyId"]
if len(keyID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid key ID"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid key ID"), w)
return
}
key, err := h.accountManager.GetSetupKey(account.Id, user.Id, keyID)
key, err := h.accountManager.GetSetupKey(r.Context(), account.Id, user.Id, keyID)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
writeSuccess(w, key)
writeSuccess(r.Context(), w, key)
}
// UpdateSetupKey is a PUT request to update server.SetupKey
func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
keyID := vars["keyId"]
if len(keyID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid key ID"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid key ID"), w)
return
}
@ -134,12 +135,12 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
}
if req.Name == "" {
util.WriteError(status.Errorf(status.InvalidArgument, "setup key name field is invalid: %s", req.Name), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key name field is invalid: %s", req.Name), w)
return
}
if req.AutoGroups == nil {
util.WriteError(status.Errorf(status.InvalidArgument, "setup key AutoGroups field is invalid"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key AutoGroups field is invalid"), w)
return
}
@ -149,26 +150,26 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
newKey.Name = req.Name
newKey.Id = keyID
newKey, err = h.accountManager.SaveSetupKey(account.Id, newKey, user.Id)
newKey, err = h.accountManager.SaveSetupKey(r.Context(), account.Id, newKey, user.Id)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
writeSuccess(w, newKey)
writeSuccess(r.Context(), w, newKey)
}
// GetAllSetupKeys is a GET request that returns a list of SetupKey
func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
setupKeys, err := h.accountManager.ListSetupKeys(account.Id, user.Id)
setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), account.Id, user.Id)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
@ -177,15 +178,15 @@ func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Reques
apiSetupKeys = append(apiSetupKeys, toResponseBody(key))
}
util.WriteJSONObject(w, apiSetupKeys)
util.WriteJSONObject(r.Context(), w, apiSetupKeys)
}
func writeSuccess(w http.ResponseWriter, key *server.SetupKey) {
func writeSuccess(ctx context.Context, w http.ResponseWriter, key *server.SetupKey) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(200)
err := json.NewEncoder(w).Encode(toResponseBody(key))
if err != nil {
util.WriteError(err, w)
util.WriteError(ctx, err, w)
return
}
}

View File

@ -2,6 +2,7 @@ package http
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
@ -33,7 +34,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
) *SetupKeysHandler {
return &SetupKeysHandler{
accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return &server.Account{
Id: testAccountID,
Domain: "hotmail.com",
@ -49,7 +50,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
},
}, user, nil
},
CreateSetupKeyFunc: func(_ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string,
CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string,
_ int, _ string, ephemeral bool,
) (*server.SetupKey, error) {
if keyName == newKey.Name || typ != newKey.Type {
@ -59,7 +60,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
}
return nil, fmt.Errorf("failed creating setup key")
},
GetSetupKeyFunc: func(accountID, userID, keyID string) (*server.SetupKey, error) {
GetSetupKeyFunc: func(_ context.Context, accountID, userID, keyID string) (*server.SetupKey, error) {
switch keyID {
case defaultKey.Id:
return defaultKey, nil
@ -70,14 +71,14 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
}
},
SaveSetupKeyFunc: func(accountID string, key *server.SetupKey, _ string) (*server.SetupKey, error) {
SaveSetupKeyFunc: func(_ context.Context, accountID string, key *server.SetupKey, _ string) (*server.SetupKey, error) {
if key.Id == updatedSetupKey.Id {
return updatedSetupKey, nil
}
return nil, status.Errorf(status.NotFound, "key %s not found", key.Id)
},
ListSetupKeysFunc: func(accountID, userID string) ([]*server.SetupKey, error) {
ListSetupKeysFunc: func(_ context.Context, accountID, userID string) ([]*server.SetupKey, error) {
return []*server.SetupKey{defaultKey}, nil
},
},

View File

@ -41,22 +41,22 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
}
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
userID := vars["userId"]
if len(userID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}
existingUser, ok := account.Users[userID]
if !ok {
util.WriteError(status.Errorf(status.NotFound, "couldn't find user with ID %s", userID), w)
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find user with ID %s", userID), w)
return
}
@ -74,11 +74,11 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
userRole := server.StrRoleToUserRole(req.Role)
if userRole == server.UserRoleUnknown {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user role"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user role"), w)
return
}
newUser, err := h.accountManager.SaveUser(account.Id, user.Id, &server.User{
newUser, err := h.accountManager.SaveUser(r.Context(), account.Id, user.Id, &server.User{
Id: userID,
Role: userRole,
AutoGroups: req.AutoGroups,
@ -88,10 +88,10 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
})
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(w, toUserResponse(newUser, claims.UserId))
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId))
}
// DeleteUser is a DELETE request to delete a user
@ -102,26 +102,26 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) {
}
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}
err = h.accountManager.DeleteUser(account.Id, user.Id, targetUserID)
err = h.accountManager.DeleteUser(r.Context(), account.Id, user.Id, targetUserID)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(w, emptyObject{})
util.WriteJSONObject(r.Context(), w, emptyObject{})
}
// CreateUser creates a User in the system with a status "invited" (effectively this is a user invite).
@ -132,9 +132,9 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
}
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
@ -146,7 +146,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
}
if server.StrRoleToUserRole(req.Role) == server.UserRoleUnknown {
util.WriteError(status.Errorf(status.InvalidArgument, "unknown user role %s", req.Role), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown user role %s", req.Role), w)
return
}
@ -160,7 +160,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
name = *req.Name
}
newUser, err := h.accountManager.CreateUser(account.Id, user.Id, &server.UserInfo{
newUser, err := h.accountManager.CreateUser(r.Context(), account.Id, user.Id, &server.UserInfo{
Email: email,
Name: name,
Role: req.Role,
@ -169,10 +169,10 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
Issued: server.UserIssuedAPI,
})
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(w, toUserResponse(newUser, claims.UserId))
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId))
}
// GetAllUsers returns a list of users of the account this user belongs to.
@ -184,42 +184,42 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) {
}
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
data, err := h.accountManager.GetUsersFromAccount(account.Id, user.Id)
data, err := h.accountManager.GetUsersFromAccount(r.Context(), account.Id, user.Id)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
serviceUser := r.URL.Query().Get("service_user")
users := make([]*api.User, 0)
for _, r := range data {
if r.NonDeletable {
for _, d := range data {
if d.NonDeletable {
continue
}
if serviceUser == "" {
users = append(users, toUserResponse(r, claims.UserId))
users = append(users, toUserResponse(d, claims.UserId))
continue
}
includeServiceUser, err := strconv.ParseBool(serviceUser)
log.Debugf("Should include service user: %v", includeServiceUser)
log.WithContext(r.Context()).Debugf("Should include service user: %v", includeServiceUser)
if err != nil {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid service_user query parameter"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid service_user query parameter"), w)
return
}
if includeServiceUser == r.IsServiceUser {
users = append(users, toUserResponse(r, claims.UserId))
if includeServiceUser == d.IsServiceUser {
users = append(users, toUserResponse(d, claims.UserId))
}
}
util.WriteJSONObject(w, users)
util.WriteJSONObject(r.Context(), w, users)
}
// InviteUser resend invitations to users who haven't activated their accounts,
@ -231,26 +231,26 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) {
}
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}
err = h.accountManager.InviteUser(account.Id, user.Id, targetUserID)
err = h.accountManager.InviteUser(r.Context(), account.Id, user.Id, targetUserID)
if err != nil {
util.WriteError(err, w)
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(w, emptyObject{})
util.WriteJSONObject(r.Context(), w, emptyObject{})
}
func toUserResponse(user *server.UserInfo, currenUserID string) *api.User {

View File

@ -2,6 +2,7 @@ package http
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
@ -63,10 +64,10 @@ var usersTestAccount = &server.Account{
func initUsersTestData() *UsersHandler {
return &UsersHandler{
accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
return usersTestAccount, usersTestAccount.Users[claims.UserId], nil
},
GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) {
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) {
users := make([]*server.UserInfo, 0)
for _, v := range usersTestAccount.Users {
users = append(users, &server.UserInfo{
@ -81,13 +82,13 @@ func initUsersTestData() *UsersHandler {
}
return users, nil
},
CreateUserFunc: func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) {
CreateUserFunc: func(_ context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) {
if userID != existingUserID {
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID)
}
return key, nil
},
DeleteUserFunc: func(accountID string, initiatorUserID string, targetUserID string) error {
DeleteUserFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string) error {
if targetUserID == notFoundUserID {
return status.Errorf(status.NotFound, "user with ID %s does not exists", targetUserID)
}
@ -96,7 +97,7 @@ func initUsersTestData() *UsersHandler {
}
return nil
},
SaveUserFunc: func(accountID, userID string, update *server.User) (*server.UserInfo, error) {
SaveUserFunc: func(_ context.Context, accountID, userID string, update *server.User) (*server.UserInfo, error) {
if update.Id == notFoundUserID {
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", update.Id)
}
@ -111,7 +112,7 @@ func initUsersTestData() *UsersHandler {
}
return info, nil
},
InviteUserFunc: func(accountID string, initiatorUserID string, targetUserID string) error {
InviteUserFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string) error {
if initiatorUserID != existingUserID {
return status.Errorf(status.NotFound, "user with ID %s does not exists", initiatorUserID)
}

View File

@ -1,6 +1,7 @@
package util
import (
"context"
"encoding/json"
"errors"
"fmt"
@ -19,12 +20,12 @@ type ErrorResponse struct {
}
// WriteJSONObject simply writes object to the HTTP response in JSON format
func WriteJSONObject(w http.ResponseWriter, obj interface{}) {
func WriteJSONObject(ctx context.Context, w http.ResponseWriter, obj interface{}) {
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
w.WriteHeader(http.StatusOK)
err := json.NewEncoder(w).Encode(obj)
if err != nil {
WriteError(err, w)
WriteError(ctx, err, w)
return
}
}
@ -76,8 +77,8 @@ func WriteErrorResponse(errMsg string, httpStatus int, w http.ResponseWriter) {
// WriteError converts an error to an JSON error response.
// If it is known internal error of type server.Error then it sets the messages from the error, a generic message otherwise
func WriteError(err error, w http.ResponseWriter) {
log.Errorf("got a handler error: %s", err.Error())
func WriteError(ctx context.Context, err error, w http.ResponseWriter) {
log.WithContext(ctx).Errorf("got a handler error: %s", err.Error())
errStatus, ok := status.FromError(err)
httpStatus := http.StatusInternalServerError
msg := "internal server error"
@ -106,7 +107,7 @@ func WriteError(err error, w http.ResponseWriter) {
msg = strings.ToLower(err.Error())
} else {
unhandledMSG := fmt.Sprintf("got unhandled error code, error: %s", err.Error())
log.Error(unhandledMSG)
log.WithContext(ctx).Error(unhandledMSG)
}
WriteErrorResponse(msg, httpStatus, w)

View File

@ -183,7 +183,7 @@ func (c *Auth0Credentials) jwtStillValid() bool {
}
// requestJWTToken performs request to get jwt token
func (c *Auth0Credentials) requestJWTToken() (*http.Response, error) {
func (c *Auth0Credentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
var res *http.Response
reqURL := c.clientConfig.AuthIssuer + "/oauth/token"
@ -200,7 +200,7 @@ func (c *Auth0Credentials) requestJWTToken() (*http.Response, error) {
req.Header.Add("content-type", "application/json")
log.Debug("requesting new jwt token for idp manager")
log.WithContext(ctx).Debug("requesting new jwt token for idp manager")
res, err = c.httpClient.Do(req)
if err != nil {
@ -247,7 +247,7 @@ func (c *Auth0Credentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTTo
}
// Authenticate retrieves access token to use the Auth0 Management API
func (c *Auth0Credentials) Authenticate() (JWTToken, error) {
func (c *Auth0Credentials) Authenticate(ctx context.Context) (JWTToken, error) {
c.mux.Lock()
defer c.mux.Unlock()
@ -260,14 +260,14 @@ func (c *Auth0Credentials) Authenticate() (JWTToken, error) {
return c.jwtToken, nil
}
res, err := c.requestJWTToken()
res, err := c.requestJWTToken(ctx)
if err != nil {
return c.jwtToken, err
}
defer func() {
err = res.Body.Close()
if err != nil {
log.Errorf("error while closing get jwt token response body: %v", err)
log.WithContext(ctx).Errorf("error while closing get jwt token response body: %v", err)
}
}()
@ -301,8 +301,8 @@ func requestByUserIDURL(authIssuer, userID string) string {
}
// GetAccount returns all the users for a given profile. Calls Auth0 API.
func (am *Auth0Manager) GetAccount(accountID string) ([]*UserData, error) {
jwtToken, err := am.credentials.Authenticate()
func (am *Auth0Manager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
jwtToken, err := am.credentials.Authenticate(ctx)
if err != nil {
return nil, err
}
@ -353,7 +353,7 @@ func (am *Auth0Manager) GetAccount(accountID string) ([]*UserData, error) {
return nil, err
}
log.Debugf("returned user batch for accountID %s on page %d, batch length %d", accountID, page, len(batch))
log.WithContext(ctx).Debugf("returned user batch for accountID %s on page %d, batch length %d", accountID, page, len(batch))
err = res.Body.Close()
if err != nil {
@ -365,7 +365,7 @@ func (am *Auth0Manager) GetAccount(accountID string) ([]*UserData, error) {
}
if len(batch) == 0 || len(batch) < resultsPerPage {
log.Debugf("finished loading users for accountID %s", accountID)
log.WithContext(ctx).Debugf("finished loading users for accountID %s", accountID)
return list, nil
}
}
@ -374,8 +374,8 @@ func (am *Auth0Manager) GetAccount(accountID string) ([]*UserData, error) {
}
// GetUserDataByID requests user data from auth0 via ID
func (am *Auth0Manager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
jwtToken, err := am.credentials.Authenticate()
func (am *Auth0Manager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
jwtToken, err := am.credentials.Authenticate(ctx)
if err != nil {
return nil, err
}
@ -414,7 +414,7 @@ func (am *Auth0Manager) GetUserDataByID(userID string, appMetadata AppMetadata)
defer func() {
err = res.Body.Close()
if err != nil {
log.Errorf("error while closing update user app metadata response body: %v", err)
log.WithContext(ctx).Errorf("error while closing update user app metadata response body: %v", err)
}
}()
@ -426,9 +426,9 @@ func (am *Auth0Manager) GetUserDataByID(userID string, appMetadata AppMetadata)
}
// UpdateUserAppMetadata updates user app metadata based on userId and metadata map
func (am *Auth0Manager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error {
func (am *Auth0Manager) UpdateUserAppMetadata(ctx context.Context, userID string, appMetadata AppMetadata) error {
jwtToken, err := am.credentials.Authenticate()
jwtToken, err := am.credentials.Authenticate(ctx)
if err != nil {
return err
}
@ -449,7 +449,7 @@ func (am *Auth0Manager) UpdateUserAppMetadata(userID string, appMetadata AppMeta
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
log.Debugf("updating IdP metadata for user %s", userID)
log.WithContext(ctx).Debugf("updating IdP metadata for user %s", userID)
res, err := am.httpClient.Do(req)
if err != nil {
@ -466,7 +466,7 @@ func (am *Auth0Manager) UpdateUserAppMetadata(userID string, appMetadata AppMeta
defer func() {
err = res.Body.Close()
if err != nil {
log.Errorf("error while closing update user app metadata response body: %v", err)
log.WithContext(ctx).Errorf("error while closing update user app metadata response body: %v", err)
}
}()
@ -530,9 +530,9 @@ func buildUserExportRequest() (string, error) {
}
func (am *Auth0Manager) createRequest(
method string, endpoint string, body io.Reader,
ctx context.Context, method string, endpoint string, body io.Reader,
) (*http.Request, error) {
jwtToken, err := am.credentials.Authenticate()
jwtToken, err := am.credentials.Authenticate(ctx)
if err != nil {
return nil, err
}
@ -548,8 +548,8 @@ func (am *Auth0Manager) createRequest(
return req, nil
}
func (am *Auth0Manager) createPostRequest(endpoint string, payloadStr string) (*http.Request, error) {
req, err := am.createRequest("POST", endpoint, strings.NewReader(payloadStr))
func (am *Auth0Manager) createPostRequest(ctx context.Context, endpoint string, payloadStr string) (*http.Request, error) {
req, err := am.createRequest(ctx, "POST", endpoint, strings.NewReader(payloadStr))
if err != nil {
return nil, err
}
@ -560,20 +560,20 @@ func (am *Auth0Manager) createPostRequest(endpoint string, payloadStr string) (*
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
func (am *Auth0Manager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
payloadString, err := buildUserExportRequest()
if err != nil {
return nil, err
}
exportJobReq, err := am.createPostRequest("/api/v2/jobs/users-exports", payloadString)
exportJobReq, err := am.createPostRequest(ctx, "/api/v2/jobs/users-exports", payloadString)
if err != nil {
return nil, err
}
jobResp, err := am.httpClient.Do(exportJobReq)
if err != nil {
log.Debugf("Couldn't get job response %v", err)
log.WithContext(ctx).Debugf("Couldn't get job response %v", err)
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
@ -583,7 +583,7 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
defer func() {
err = jobResp.Body.Close()
if err != nil {
log.Errorf("error while closing update user app metadata response body: %v", err)
log.WithContext(ctx).Errorf("error while closing update user app metadata response body: %v", err)
}
}()
if jobResp.StatusCode != 200 {
@ -597,13 +597,13 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
body, err := io.ReadAll(jobResp.Body)
if err != nil {
log.Debugf("Couldn't read export job response; %v", err)
log.WithContext(ctx).Debugf("Couldn't read export job response; %v", err)
return nil, err
}
err = am.helper.Unmarshal(body, &exportJobResp)
if err != nil {
log.Debugf("Couldn't unmarshal export job response; %v", err)
log.WithContext(ctx).Debugf("Couldn't unmarshal export job response; %v", err)
return nil, err
}
@ -614,16 +614,16 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
return nil, fmt.Errorf("couldn't get an batch id status %d, %s, response body: %v", jobResp.StatusCode, jobResp.Status, exportJobResp)
}
log.Debugf("batch id status %d, %s, response body: %v", jobResp.StatusCode, jobResp.Status, exportJobResp)
log.WithContext(ctx).Debugf("batch id status %d, %s, response body: %v", jobResp.StatusCode, jobResp.Status, exportJobResp)
done, downloadLink, err := am.checkExportJobStatus(exportJobResp.ID)
done, downloadLink, err := am.checkExportJobStatus(ctx, exportJobResp.ID)
if err != nil {
log.Debugf("Failed at getting status checks from exportJob; %v", err)
log.WithContext(ctx).Debugf("Failed at getting status checks from exportJob; %v", err)
return nil, err
}
if done {
return am.downloadProfileExport(downloadLink)
return am.downloadProfileExport(ctx, downloadLink)
}
return nil, fmt.Errorf("failed extracting user profiles from auth0")
@ -632,13 +632,13 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
// GetUserByEmail searches users with a given email. If no users have been found, this function returns an empty list.
// This function can return multiple users. This is due to the Auth0 internals - there could be multiple users with
// the same email but different connections that are considered as separate accounts (e.g., Google and username/password).
func (am *Auth0Manager) GetUserByEmail(email string) ([]*UserData, error) {
jwtToken, err := am.credentials.Authenticate()
func (am *Auth0Manager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
jwtToken, err := am.credentials.Authenticate(ctx)
if err != nil {
return nil, err
}
reqURL := am.authIssuer + "/api/v2/users-by-email?email=" + url.QueryEscape(email)
body, err := doGetReq(am.httpClient, reqURL, jwtToken.AccessToken)
body, err := doGetReq(ctx, am.httpClient, reqURL, jwtToken.AccessToken)
if err != nil {
return nil, err
}
@ -651,7 +651,7 @@ func (am *Auth0Manager) GetUserByEmail(email string) ([]*UserData, error) {
err = am.helper.Unmarshal(body, &userResp)
if err != nil {
log.Debugf("Couldn't unmarshal export job response; %v", err)
log.WithContext(ctx).Debugf("Couldn't unmarshal export job response; %v", err)
return nil, err
}
@ -659,13 +659,13 @@ func (am *Auth0Manager) GetUserByEmail(email string) ([]*UserData, error) {
}
// CreateUser creates a new user in Auth0 Idp and sends an invite
func (am *Auth0Manager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) {
func (am *Auth0Manager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
payloadString, err := buildCreateUserRequestPayload(email, name, accountID, invitedByEmail)
if err != nil {
return nil, err
}
req, err := am.createPostRequest("/api/v2/users", payloadString)
req, err := am.createPostRequest(ctx, "/api/v2/users", payloadString)
if err != nil {
return nil, err
}
@ -676,7 +676,7 @@ func (am *Auth0Manager) CreateUser(email, name, accountID, invitedByEmail string
resp, err := am.httpClient.Do(req)
if err != nil {
log.Debugf("Couldn't get job response %v", err)
log.WithContext(ctx).Debugf("Couldn't get job response %v", err)
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
@ -686,7 +686,7 @@ func (am *Auth0Manager) CreateUser(email, name, accountID, invitedByEmail string
defer func() {
err = resp.Body.Close()
if err != nil {
log.Errorf("error while closing create user response body: %v", err)
log.WithContext(ctx).Errorf("error while closing create user response body: %v", err)
}
}()
if !(resp.StatusCode == 200 || resp.StatusCode == 201) {
@ -700,13 +700,13 @@ func (am *Auth0Manager) CreateUser(email, name, accountID, invitedByEmail string
body, err := io.ReadAll(resp.Body)
if err != nil {
log.Debugf("Couldn't read export job response; %v", err)
log.WithContext(ctx).Debugf("Couldn't read export job response; %v", err)
return nil, err
}
err = am.helper.Unmarshal(body, &createResp)
if err != nil {
log.Debugf("Couldn't unmarshal export job response; %v", err)
log.WithContext(ctx).Debugf("Couldn't unmarshal export job response; %v", err)
return nil, err
}
@ -714,14 +714,14 @@ func (am *Auth0Manager) CreateUser(email, name, accountID, invitedByEmail string
return nil, fmt.Errorf("couldn't create user: response %v", resp)
}
log.Debugf("created user %s in account %s", createResp.ID, accountID)
log.WithContext(ctx).Debugf("created user %s in account %s", createResp.ID, accountID)
return &createResp, nil
}
// InviteUserByID resend invitations to users who haven't activated,
// their accounts prior to the expiration period.
func (am *Auth0Manager) InviteUserByID(userID string) error {
func (am *Auth0Manager) InviteUserByID(ctx context.Context, userID string) error {
userVerificationReq := userVerificationJobRequest{
UserID: userID,
}
@ -731,14 +731,14 @@ func (am *Auth0Manager) InviteUserByID(userID string) error {
return err
}
req, err := am.createPostRequest("/api/v2/jobs/verification-email", string(payload))
req, err := am.createPostRequest(ctx, "/api/v2/jobs/verification-email", string(payload))
if err != nil {
return err
}
resp, err := am.httpClient.Do(req)
if err != nil {
log.Debugf("Couldn't get job response %v", err)
log.WithContext(ctx).Debugf("Couldn't get job response %v", err)
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
@ -748,7 +748,7 @@ func (am *Auth0Manager) InviteUserByID(userID string) error {
defer func() {
err = resp.Body.Close()
if err != nil {
log.Errorf("error while closing invite user response body: %v", err)
log.WithContext(ctx).Errorf("error while closing invite user response body: %v", err)
}
}()
if !(resp.StatusCode == 200 || resp.StatusCode == 201) {
@ -762,15 +762,15 @@ func (am *Auth0Manager) InviteUserByID(userID string) error {
}
// DeleteUser from Auth0
func (am *Auth0Manager) DeleteUser(userID string) error {
req, err := am.createRequest(http.MethodDelete, "/api/v2/users/"+url.QueryEscape(userID), nil)
func (am *Auth0Manager) DeleteUser(ctx context.Context, userID string) error {
req, err := am.createRequest(ctx, http.MethodDelete, "/api/v2/users/"+url.QueryEscape(userID), nil)
if err != nil {
return err
}
resp, err := am.httpClient.Do(req)
if err != nil {
log.Debugf("execute delete request: %v", err)
log.WithContext(ctx).Debugf("execute delete request: %v", err)
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
@ -780,7 +780,7 @@ func (am *Auth0Manager) DeleteUser(userID string) error {
defer func() {
err = resp.Body.Close()
if err != nil {
log.Errorf("close delete request body: %v", err)
log.WithContext(ctx).Errorf("close delete request body: %v", err)
}
}()
if resp.StatusCode != 204 {
@ -795,20 +795,20 @@ func (am *Auth0Manager) DeleteUser(userID string) error {
// GetAllConnections returns detailed list of all connections filtered by given params.
// Note this method is not part of the IDP Manager interface as this is Auth0 specific.
func (am *Auth0Manager) GetAllConnections(strategy []string) ([]Connection, error) {
func (am *Auth0Manager) GetAllConnections(ctx context.Context, strategy []string) ([]Connection, error) {
var connections []Connection
q := make(url.Values)
q.Set("strategy", strings.Join(strategy, ","))
req, err := am.createRequest(http.MethodGet, "/api/v2/connections?"+q.Encode(), nil)
req, err := am.createRequest(ctx, http.MethodGet, "/api/v2/connections?"+q.Encode(), nil)
if err != nil {
return connections, err
}
resp, err := am.httpClient.Do(req)
if err != nil {
log.Debugf("execute get connections request: %v", err)
log.WithContext(ctx).Debugf("execute get connections request: %v", err)
if am.appMetrics != nil {
am.appMetrics.IDPMetrics().CountRequestError()
}
@ -818,7 +818,7 @@ func (am *Auth0Manager) GetAllConnections(strategy []string) ([]Connection, erro
defer func() {
err = resp.Body.Close()
if err != nil {
log.Errorf("close get connections request body: %v", err)
log.WithContext(ctx).Errorf("close get connections request body: %v", err)
}
}()
if resp.StatusCode != 200 {
@ -830,13 +830,13 @@ func (am *Auth0Manager) GetAllConnections(strategy []string) ([]Connection, erro
body, err := io.ReadAll(resp.Body)
if err != nil {
log.Debugf("Couldn't read get connections response; %v", err)
log.WithContext(ctx).Debugf("Couldn't read get connections response; %v", err)
return connections, err
}
err = am.helper.Unmarshal(body, &connections)
if err != nil {
log.Debugf("Couldn't unmarshal get connection response; %v", err)
log.WithContext(ctx).Debugf("Couldn't unmarshal get connection response; %v", err)
return connections, err
}
@ -845,23 +845,23 @@ func (am *Auth0Manager) GetAllConnections(strategy []string) ([]Connection, erro
// checkExportJobStatus checks the status of the job created at CreateExportUsersJob.
// If the status is "completed", then return the downloadLink
func (am *Auth0Manager) checkExportJobStatus(jobID string) (bool, string, error) {
ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second)
func (am *Auth0Manager) checkExportJobStatus(ctx context.Context, jobID string) (bool, string, error) {
ctx, cancel := context.WithTimeout(ctx, 90*time.Second)
defer cancel()
retry := time.NewTicker(10 * time.Second)
for {
select {
case <-ctx.Done():
log.Debugf("Export job status stopped...\n")
log.WithContext(ctx).Debugf("Export job status stopped...\n")
return false, "", ctx.Err()
case <-retry.C:
jwtToken, err := am.credentials.Authenticate()
jwtToken, err := am.credentials.Authenticate(ctx)
if err != nil {
return false, "", err
}
statusURL := am.authIssuer + "/api/v2/jobs/" + jobID
body, err := doGetReq(am.httpClient, statusURL, jwtToken.AccessToken)
body, err := doGetReq(ctx, am.httpClient, statusURL, jwtToken.AccessToken)
if err != nil {
return false, "", err
}
@ -872,7 +872,7 @@ func (am *Auth0Manager) checkExportJobStatus(jobID string) (bool, string, error)
return false, "", err
}
log.Debugf("current export job status is %v", status.Status)
log.WithContext(ctx).Debugf("current export job status is %v", status.Status)
if status.Status != "completed" {
continue
@ -884,8 +884,8 @@ func (am *Auth0Manager) checkExportJobStatus(jobID string) (bool, string, error)
}
// downloadProfileExport downloads user profiles from auth0 batch job
func (am *Auth0Manager) downloadProfileExport(location string) (map[string][]*UserData, error) {
body, err := doGetReq(am.httpClient, location, "")
func (am *Auth0Manager) downloadProfileExport(ctx context.Context, location string) (map[string][]*UserData, error) {
body, err := doGetReq(ctx, am.httpClient, location, "")
if err != nil {
return nil, err
}
@ -927,7 +927,7 @@ func (am *Auth0Manager) downloadProfileExport(location string) (map[string][]*Us
}
// Boilerplate implementation for Get Requests.
func doGetReq(client ManagerHTTPClient, url, accessToken string) ([]byte, error) {
func doGetReq(ctx context.Context, client ManagerHTTPClient, url, accessToken string) ([]byte, error) {
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
@ -945,7 +945,7 @@ func doGetReq(client ManagerHTTPClient, url, accessToken string) ([]byte, error)
defer func() {
err = res.Body.Close()
if err != nil {
log.Errorf("error while closing body for url %s: %v", url, err)
log.WithContext(ctx).Errorf("error while closing body for url %s: %v", url, err)
}
}()
body, err := io.ReadAll(res.Body)

View File

@ -1,6 +1,7 @@
package idp
import (
"context"
"encoding/json"
"fmt"
"io"
@ -60,7 +61,7 @@ type mockAuth0Credentials struct {
err error
}
func (mc *mockAuth0Credentials) Authenticate() (JWTToken, error) {
func (mc *mockAuth0Credentials) Authenticate(_ context.Context) (JWTToken, error) {
return mc.jwtToken, mc.err
}
@ -126,7 +127,7 @@ func TestAuth0_RequestJWTToken(t *testing.T) {
helper: testCase.helper,
}
res, err := creds.requestJWTToken()
res, err := creds.requestJWTToken(context.Background())
if err != nil {
if testCase.expectedFuncExitErrDiff != nil {
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
@ -295,7 +296,7 @@ func TestAuth0_Authenticate(t *testing.T) {
creds.jwtToken.expiresInTime = testCase.inputExpireToken
_, err := creds.Authenticate()
_, err := creds.Authenticate(context.Background())
if err != nil {
if testCase.expectedFuncExitErrDiff != nil {
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
@ -417,7 +418,7 @@ func TestAuth0_UpdateUserAppMetadata(t *testing.T) {
helper: testCase.helper,
}
err := manager.UpdateUserAppMetadata("1", testCase.appMetadata)
err := manager.UpdateUserAppMetadata(context.Background(), "1", testCase.appMetadata)
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
assert.Equal(t, testCase.expectedReqBody, jwtReqClient.reqBody, "request body should match")

View File

@ -116,7 +116,7 @@ func (ac *AuthentikCredentials) jwtStillValid() bool {
}
// requestJWTToken performs request to get jwt token.
func (ac *AuthentikCredentials) requestJWTToken() (*http.Response, error) {
func (ac *AuthentikCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
data := url.Values{}
data.Set("client_id", ac.clientConfig.ClientID)
data.Set("username", ac.clientConfig.Username)
@ -131,7 +131,7 @@ func (ac *AuthentikCredentials) requestJWTToken() (*http.Response, error) {
}
req.Header.Add("content-type", "application/x-www-form-urlencoded")
log.Debug("requesting new jwt token for authentik idp manager")
log.WithContext(ctx).Debug("requesting new jwt token for authentik idp manager")
resp, err := ac.httpClient.Do(req)
if err != nil {
@ -183,7 +183,7 @@ func (ac *AuthentikCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (
}
// Authenticate retrieves access token to use the authentik management API.
func (ac *AuthentikCredentials) Authenticate() (JWTToken, error) {
func (ac *AuthentikCredentials) Authenticate(ctx context.Context) (JWTToken, error) {
ac.mux.Lock()
defer ac.mux.Unlock()
@ -197,7 +197,7 @@ func (ac *AuthentikCredentials) Authenticate() (JWTToken, error) {
return ac.jwtToken, nil
}
resp, err := ac.requestJWTToken()
resp, err := ac.requestJWTToken(ctx)
if err != nil {
return ac.jwtToken, err
}
@ -214,13 +214,13 @@ func (ac *AuthentikCredentials) Authenticate() (JWTToken, error) {
}
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
func (am *AuthentikManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
func (am *AuthentikManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
return nil
}
// GetUserDataByID requests user data from authentik via ID.
func (am *AuthentikManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
ctx, err := am.authenticationContext()
func (am *AuthentikManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
ctx, err := am.authenticationContext(ctx)
if err != nil {
return nil, err
}
@ -254,8 +254,8 @@ func (am *AuthentikManager) GetUserDataByID(userID string, appMetadata AppMetada
}
// GetAccount returns all the users for a given profile.
func (am *AuthentikManager) GetAccount(accountID string) ([]*UserData, error) {
users, err := am.getAllUsers()
func (am *AuthentikManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
users, err := am.getAllUsers(ctx)
if err != nil {
return nil, err
}
@ -274,8 +274,8 @@ func (am *AuthentikManager) GetAccount(accountID string) ([]*UserData, error) {
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (am *AuthentikManager) GetAllAccounts() (map[string][]*UserData, error) {
users, err := am.getAllUsers()
func (am *AuthentikManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
users, err := am.getAllUsers(ctx)
if err != nil {
return nil, err
}
@ -291,12 +291,12 @@ func (am *AuthentikManager) GetAllAccounts() (map[string][]*UserData, error) {
}
// getAllUsers returns all users in a Authentik account.
func (am *AuthentikManager) getAllUsers() ([]*UserData, error) {
func (am *AuthentikManager) getAllUsers(ctx context.Context) ([]*UserData, error) {
users := make([]*UserData, 0)
page := int32(1)
for {
ctx, err := am.authenticationContext()
ctx, err := am.authenticationContext(ctx)
if err != nil {
return nil, err
}
@ -329,14 +329,14 @@ func (am *AuthentikManager) getAllUsers() ([]*UserData, error) {
}
// CreateUser creates a new user in authentik Idp and sends an invitation.
func (am *AuthentikManager) CreateUser(_, _, _, _ string) (*UserData, error) {
func (am *AuthentikManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
return nil, fmt.Errorf("method CreateUser not implemented")
}
// GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list.
func (am *AuthentikManager) GetUserByEmail(email string) ([]*UserData, error) {
ctx, err := am.authenticationContext()
func (am *AuthentikManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
ctx, err := am.authenticationContext(ctx)
if err != nil {
return nil, err
}
@ -368,13 +368,13 @@ func (am *AuthentikManager) GetUserByEmail(email string) ([]*UserData, error) {
// InviteUserByID resend invitations to users who haven't activated,
// their accounts prior to the expiration period.
func (am *AuthentikManager) InviteUserByID(_ string) error {
func (am *AuthentikManager) InviteUserByID(_ context.Context, _ string) error {
return fmt.Errorf("method InviteUserByID not implemented")
}
// DeleteUser from Authentik
func (am *AuthentikManager) DeleteUser(userID string) error {
ctx, err := am.authenticationContext()
func (am *AuthentikManager) DeleteUser(ctx context.Context, userID string) error {
ctx, err := am.authenticationContext(ctx)
if err != nil {
return err
}
@ -404,8 +404,8 @@ func (am *AuthentikManager) DeleteUser(userID string) error {
return nil
}
func (am *AuthentikManager) authenticationContext() (context.Context, error) {
jwtToken, err := am.credentials.Authenticate()
func (am *AuthentikManager) authenticationContext(ctx context.Context) (context.Context, error) {
jwtToken, err := am.credentials.Authenticate(ctx)
if err != nil {
return nil, err
}

View File

@ -1,6 +1,7 @@
package idp
import (
"context"
"fmt"
"io"
"strings"
@ -138,7 +139,7 @@ func TestAuthentikRequestJWTToken(t *testing.T) {
helper: testCase.helper,
}
resp, err := creds.requestJWTToken()
resp, err := creds.requestJWTToken(context.Background())
if err != nil {
if testCase.expectedFuncExitErrDiff != nil {
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
@ -304,7 +305,7 @@ func TestAuthentikAuthenticate(t *testing.T) {
}
creds.jwtToken.expiresInTime = testCase.inputExpireToken
_, err := creds.Authenticate()
_, err := creds.Authenticate(context.Background())
if err != nil {
if testCase.expectedFuncExitErrDiff != nil {
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")

View File

@ -1,6 +1,7 @@
package idp
import (
"context"
"fmt"
"io"
"net/http"
@ -110,7 +111,7 @@ func (ac *AzureCredentials) jwtStillValid() bool {
}
// requestJWTToken performs request to get jwt token.
func (ac *AzureCredentials) requestJWTToken() (*http.Response, error) {
func (ac *AzureCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
data := url.Values{}
data.Set("client_id", ac.clientConfig.ClientID)
data.Set("client_secret", ac.clientConfig.ClientSecret)
@ -132,7 +133,7 @@ func (ac *AzureCredentials) requestJWTToken() (*http.Response, error) {
}
req.Header.Add("content-type", "application/x-www-form-urlencoded")
log.Debug("requesting new jwt token for azure idp manager")
log.WithContext(ctx).Debug("requesting new jwt token for azure idp manager")
resp, err := ac.httpClient.Do(req)
if err != nil {
@ -184,7 +185,7 @@ func (ac *AzureCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTT
}
// Authenticate retrieves access token to use the azure Management API.
func (ac *AzureCredentials) Authenticate() (JWTToken, error) {
func (ac *AzureCredentials) Authenticate(ctx context.Context) (JWTToken, error) {
ac.mux.Lock()
defer ac.mux.Unlock()
@ -198,7 +199,7 @@ func (ac *AzureCredentials) Authenticate() (JWTToken, error) {
return ac.jwtToken, nil
}
resp, err := ac.requestJWTToken()
resp, err := ac.requestJWTToken(ctx)
if err != nil {
return ac.jwtToken, err
}
@ -215,16 +216,16 @@ func (ac *AzureCredentials) Authenticate() (JWTToken, error) {
}
// CreateUser creates a new user in azure AD Idp.
func (am *AzureManager) CreateUser(_, _, _, _ string) (*UserData, error) {
func (am *AzureManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
return nil, fmt.Errorf("method CreateUser not implemented")
}
// GetUserDataByID requests user data from keycloak via ID.
func (am *AzureManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
func (am *AzureManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
q := url.Values{}
q.Add("$select", profileFields)
body, err := am.get("users/"+userID, q)
body, err := am.get(ctx, "users/"+userID, q)
if err != nil {
return nil, err
}
@ -247,11 +248,11 @@ func (am *AzureManager) GetUserDataByID(userID string, appMetadata AppMetadata)
// GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list.
func (am *AzureManager) GetUserByEmail(email string) ([]*UserData, error) {
func (am *AzureManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
q := url.Values{}
q.Add("$select", profileFields)
body, err := am.get("users/"+email, q)
body, err := am.get(ctx, "users/"+email, q)
if err != nil {
return nil, err
}
@ -273,8 +274,8 @@ func (am *AzureManager) GetUserByEmail(email string) ([]*UserData, error) {
}
// GetAccount returns all the users for a given profile.
func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) {
users, err := am.getAllUsers()
func (am *AzureManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
users, err := am.getAllUsers(ctx)
if err != nil {
return nil, err
}
@ -293,8 +294,8 @@ func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) {
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (am *AzureManager) GetAllAccounts() (map[string][]*UserData, error) {
users, err := am.getAllUsers()
func (am *AzureManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
users, err := am.getAllUsers(ctx)
if err != nil {
return nil, err
}
@ -310,19 +311,19 @@ func (am *AzureManager) GetAllAccounts() (map[string][]*UserData, error) {
}
// UpdateUserAppMetadata updates user app metadata based on userID.
func (am *AzureManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
func (am *AzureManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
return nil
}
// InviteUserByID resend invitations to users who haven't activated,
// their accounts prior to the expiration period.
func (am *AzureManager) InviteUserByID(_ string) error {
func (am *AzureManager) InviteUserByID(_ context.Context, _ string) error {
return fmt.Errorf("method InviteUserByID not implemented")
}
// DeleteUser from Azure.
func (am *AzureManager) DeleteUser(userID string) error {
jwtToken, err := am.credentials.Authenticate()
func (am *AzureManager) DeleteUser(ctx context.Context, userID string) error {
jwtToken, err := am.credentials.Authenticate(ctx)
if err != nil {
return err
}
@ -335,7 +336,7 @@ func (am *AzureManager) DeleteUser(userID string) error {
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
log.Debugf("delete idp user %s", userID)
log.WithContext(ctx).Debugf("delete idp user %s", userID)
resp, err := am.httpClient.Do(req)
if err != nil {
@ -358,7 +359,7 @@ func (am *AzureManager) DeleteUser(userID string) error {
}
// getAllUsers returns all users in an Azure AD account.
func (am *AzureManager) getAllUsers() ([]*UserData, error) {
func (am *AzureManager) getAllUsers(ctx context.Context) ([]*UserData, error) {
users := make([]*UserData, 0)
q := url.Values{}
@ -366,7 +367,7 @@ func (am *AzureManager) getAllUsers() ([]*UserData, error) {
q.Add("$top", "500")
for nextLink := "users"; nextLink != ""; {
body, err := am.get(nextLink, q)
body, err := am.get(ctx, nextLink, q)
if err != nil {
return nil, err
}
@ -391,8 +392,8 @@ func (am *AzureManager) getAllUsers() ([]*UserData, error) {
}
// get perform Get requests.
func (am *AzureManager) get(resource string, q url.Values) ([]byte, error) {
jwtToken, err := am.credentials.Authenticate()
func (am *AzureManager) get(ctx context.Context, resource string, q url.Values) ([]byte, error) {
jwtToken, err := am.credentials.Authenticate(ctx)
if err != nil {
return nil, err
}

View File

@ -1,6 +1,7 @@
package idp
import (
"context"
"fmt"
"testing"
"time"
@ -101,7 +102,7 @@ func TestAzureAuthenticate(t *testing.T) {
}
creds.jwtToken.expiresInTime = testCase.inputExpireToken
_, err := creds.Authenticate()
_, err := creds.Authenticate(context.Background())
if err != nil {
if testCase.expectedFuncExitErrDiff != nil {
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")

View File

@ -39,12 +39,12 @@ type GoogleWorkspaceCredentials struct {
appMetrics telemetry.AppMetrics
}
func (gc *GoogleWorkspaceCredentials) Authenticate() (JWTToken, error) {
func (gc *GoogleWorkspaceCredentials) Authenticate(_ context.Context) (JWTToken, error) {
return JWTToken{}, nil
}
// NewGoogleWorkspaceManager creates a new instance of the GoogleWorkspaceManager.
func NewGoogleWorkspaceManager(config GoogleWorkspaceClientConfig, appMetrics telemetry.AppMetrics) (*GoogleWorkspaceManager, error) {
func NewGoogleWorkspaceManager(ctx context.Context, config GoogleWorkspaceClientConfig, appMetrics telemetry.AppMetrics) (*GoogleWorkspaceManager, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5
@ -66,7 +66,7 @@ func NewGoogleWorkspaceManager(config GoogleWorkspaceClientConfig, appMetrics te
}
// Create a new Admin SDK Directory service client
adminCredentials, err := getGoogleCredentials(config.ServiceAccountKey)
adminCredentials, err := getGoogleCredentials(ctx, config.ServiceAccountKey)
if err != nil {
return nil, err
}
@ -90,12 +90,12 @@ func NewGoogleWorkspaceManager(config GoogleWorkspaceClientConfig, appMetrics te
}
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
func (gm *GoogleWorkspaceManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
func (gm *GoogleWorkspaceManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
return nil
}
// GetUserDataByID requests user data from Google Workspace via ID.
func (gm *GoogleWorkspaceManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
func (gm *GoogleWorkspaceManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
user, err := gm.usersService.Get(userID).Do()
if err != nil {
return nil, err
@ -112,7 +112,7 @@ func (gm *GoogleWorkspaceManager) GetUserDataByID(userID string, appMetadata App
}
// GetAccount returns all the users for a given profile.
func (gm *GoogleWorkspaceManager) GetAccount(accountID string) ([]*UserData, error) {
func (gm *GoogleWorkspaceManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) {
users, err := gm.getAllUsers()
if err != nil {
return nil, err
@ -132,7 +132,7 @@ func (gm *GoogleWorkspaceManager) GetAccount(accountID string) ([]*UserData, err
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (gm *GoogleWorkspaceManager) GetAllAccounts() (map[string][]*UserData, error) {
func (gm *GoogleWorkspaceManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) {
users, err := gm.getAllUsers()
if err != nil {
return nil, err
@ -177,13 +177,13 @@ func (gm *GoogleWorkspaceManager) getAllUsers() ([]*UserData, error) {
}
// CreateUser creates a new user in Google Workspace and sends an invitation.
func (gm *GoogleWorkspaceManager) CreateUser(_, _, _, _ string) (*UserData, error) {
func (gm *GoogleWorkspaceManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
return nil, fmt.Errorf("method CreateUser not implemented")
}
// GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list.
func (gm *GoogleWorkspaceManager) GetUserByEmail(email string) ([]*UserData, error) {
func (gm *GoogleWorkspaceManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) {
user, err := gm.usersService.Get(email).Do()
if err != nil {
return nil, err
@ -201,12 +201,12 @@ func (gm *GoogleWorkspaceManager) GetUserByEmail(email string) ([]*UserData, err
// InviteUserByID resend invitations to users who haven't activated,
// their accounts prior to the expiration period.
func (gm *GoogleWorkspaceManager) InviteUserByID(_ string) error {
func (gm *GoogleWorkspaceManager) InviteUserByID(_ context.Context, _ string) error {
return fmt.Errorf("method InviteUserByID not implemented")
}
// DeleteUser from GoogleWorkspace.
func (gm *GoogleWorkspaceManager) DeleteUser(userID string) error {
func (gm *GoogleWorkspaceManager) DeleteUser(_ context.Context, userID string) error {
if err := gm.usersService.Delete(userID).Do(); err != nil {
return err
}
@ -222,8 +222,8 @@ func (gm *GoogleWorkspaceManager) DeleteUser(userID string) error {
// It decodes the base64-encoded serviceAccountKey and attempts to obtain credentials using it.
// If that fails, it falls back to using the default Google credentials path.
// It returns the retrieved credentials or an error if unsuccessful.
func getGoogleCredentials(serviceAccountKey string) (*google.Credentials, error) {
log.Debug("retrieving google credentials from the base64 encoded service account key")
func getGoogleCredentials(ctx context.Context, serviceAccountKey string) (*google.Credentials, error) {
log.WithContext(ctx).Debug("retrieving google credentials from the base64 encoded service account key")
decodeKey, err := base64.StdEncoding.DecodeString(serviceAccountKey)
if err != nil {
return nil, fmt.Errorf("failed to decode service account key: %w", err)
@ -239,8 +239,8 @@ func getGoogleCredentials(serviceAccountKey string) (*google.Credentials, error)
return creds, nil
}
log.Debugf("failed to retrieve Google credentials from ServiceAccountKey: %v", err)
log.Debug("falling back to default google credentials location")
log.WithContext(ctx).Debugf("failed to retrieve Google credentials from ServiceAccountKey: %v", err)
log.WithContext(ctx).Debug("falling back to default google credentials location")
creds, err = google.FindDefaultCredentials(
context.Background(),

View File

@ -1,6 +1,7 @@
package idp
import (
"context"
"fmt"
"net/http"
"strings"
@ -16,14 +17,14 @@ const (
// Manager idp manager interface
type Manager interface {
UpdateUserAppMetadata(userId string, appMetadata AppMetadata) error
GetUserDataByID(userId string, appMetadata AppMetadata) (*UserData, error)
GetAccount(accountId string) ([]*UserData, error)
GetAllAccounts() (map[string][]*UserData, error)
CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error)
GetUserByEmail(email string) ([]*UserData, error)
InviteUserByID(userID string) error
DeleteUser(userID string) error
UpdateUserAppMetadata(ctx context.Context, userId string, appMetadata AppMetadata) error
GetUserDataByID(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error)
GetAccount(ctx context.Context, accountId string) ([]*UserData, error)
GetAllAccounts(ctx context.Context) (map[string][]*UserData, error)
CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error)
GetUserByEmail(ctx context.Context, email string) ([]*UserData, error)
InviteUserByID(ctx context.Context, userID string) error
DeleteUser(ctx context.Context, userID string) error
}
// ClientConfig defines common client configuration for all IdP manager
@ -51,7 +52,7 @@ type Config struct {
// ManagerCredentials interface that authenticates using the credential of each type of idp
type ManagerCredentials interface {
Authenticate() (JWTToken, error)
Authenticate(ctx context.Context) (JWTToken, error)
}
// ManagerHTTPClient http client interface for API calls
@ -91,7 +92,7 @@ type JWTToken struct {
}
// NewManager returns a new idp manager based on the configuration that it receives
func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error) {
func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetrics) (Manager, error) {
if config.ClientConfig != nil {
config.ClientConfig.Issuer = strings.TrimSuffix(config.ClientConfig.Issuer, "/")
}
@ -175,7 +176,7 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error)
ServiceAccountKey: config.ExtraConfig["ServiceAccountKey"],
CustomerID: config.ExtraConfig["CustomerId"],
}
return NewGoogleWorkspaceManager(googleClientConfig, appMetrics)
return NewGoogleWorkspaceManager(ctx, googleClientConfig, appMetrics)
case "jumpcloud":
jumpcloudConfig := JumpCloudClientConfig{
APIToken: config.ExtraConfig["ApiToken"],

View File

@ -74,7 +74,7 @@ func NewJumpCloudManager(config JumpCloudClientConfig, appMetrics telemetry.AppM
}
// Authenticate retrieves access token to use the JumpCloud user API.
func (jc *JumpCloudCredentials) Authenticate() (JWTToken, error) {
func (jc *JumpCloudCredentials) Authenticate(_ context.Context) (JWTToken, error) {
return JWTToken{}, nil
}
@ -85,12 +85,12 @@ func (jm *JumpCloudManager) authenticationContext() context.Context {
}
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
func (jm *JumpCloudManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
func (jm *JumpCloudManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
return nil
}
// GetUserDataByID requests user data from JumpCloud via ID.
func (jm *JumpCloudManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
func (jm *JumpCloudManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
authCtx := jm.authenticationContext()
user, resp, err := jm.client.SystemusersApi.SystemusersGet(authCtx, userID, contentType, accept, nil)
if err != nil {
@ -116,7 +116,7 @@ func (jm *JumpCloudManager) GetUserDataByID(userID string, appMetadata AppMetada
}
// GetAccount returns all the users for a given profile.
func (jm *JumpCloudManager) GetAccount(accountID string) ([]*UserData, error) {
func (jm *JumpCloudManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) {
authCtx := jm.authenticationContext()
userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil)
if err != nil {
@ -148,7 +148,7 @@ func (jm *JumpCloudManager) GetAccount(accountID string) ([]*UserData, error) {
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (jm *JumpCloudManager) GetAllAccounts() (map[string][]*UserData, error) {
func (jm *JumpCloudManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) {
authCtx := jm.authenticationContext()
userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil)
if err != nil {
@ -177,13 +177,13 @@ func (jm *JumpCloudManager) GetAllAccounts() (map[string][]*UserData, error) {
}
// CreateUser creates a new user in JumpCloud Idp and sends an invitation.
func (jm *JumpCloudManager) CreateUser(_, _, _, _ string) (*UserData, error) {
func (jm *JumpCloudManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
return nil, fmt.Errorf("method CreateUser not implemented")
}
// GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list.
func (jm *JumpCloudManager) GetUserByEmail(email string) ([]*UserData, error) {
func (jm *JumpCloudManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) {
searchFilter := map[string]interface{}{
"searchFilter": map[string]interface{}{
"filter": []string{email},
@ -219,12 +219,12 @@ func (jm *JumpCloudManager) GetUserByEmail(email string) ([]*UserData, error) {
// InviteUserByID resend invitations to users who haven't activated,
// their accounts prior to the expiration period.
func (jm *JumpCloudManager) InviteUserByID(_ string) error {
func (jm *JumpCloudManager) InviteUserByID(_ context.Context, _ string) error {
return fmt.Errorf("method InviteUserByID not implemented")
}
// DeleteUser from jumpCloud directory
func (jm *JumpCloudManager) DeleteUser(userID string) error {
func (jm *JumpCloudManager) DeleteUser(_ context.Context, userID string) error {
authCtx := jm.authenticationContext()
_, resp, err := jm.client.SystemusersApi.SystemusersDelete(authCtx, userID, contentType, accept, nil)
if err != nil {

View File

@ -1,6 +1,7 @@
package idp
import (
"context"
"fmt"
"io"
"net/http"
@ -109,7 +110,7 @@ func (kc *KeycloakCredentials) jwtStillValid() bool {
}
// requestJWTToken performs request to get jwt token.
func (kc *KeycloakCredentials) requestJWTToken() (*http.Response, error) {
func (kc *KeycloakCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
data := url.Values{}
data.Set("client_id", kc.clientConfig.ClientID)
data.Set("client_secret", kc.clientConfig.ClientSecret)
@ -122,7 +123,7 @@ func (kc *KeycloakCredentials) requestJWTToken() (*http.Response, error) {
}
req.Header.Add("content-type", "application/x-www-form-urlencoded")
log.Debug("requesting new jwt token for keycloak idp manager")
log.WithContext(ctx).Debug("requesting new jwt token for keycloak idp manager")
resp, err := kc.httpClient.Do(req)
if err != nil {
@ -174,7 +175,7 @@ func (kc *KeycloakCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (J
}
// Authenticate retrieves access token to use the keycloak Management API.
func (kc *KeycloakCredentials) Authenticate() (JWTToken, error) {
func (kc *KeycloakCredentials) Authenticate(ctx context.Context) (JWTToken, error) {
kc.mux.Lock()
defer kc.mux.Unlock()
@ -188,7 +189,7 @@ func (kc *KeycloakCredentials) Authenticate() (JWTToken, error) {
return kc.jwtToken, nil
}
resp, err := kc.requestJWTToken()
resp, err := kc.requestJWTToken(ctx)
if err != nil {
return kc.jwtToken, err
}
@ -205,18 +206,18 @@ func (kc *KeycloakCredentials) Authenticate() (JWTToken, error) {
}
// CreateUser creates a new user in keycloak Idp and sends an invite.
func (km *KeycloakManager) CreateUser(_, _, _, _ string) (*UserData, error) {
func (km *KeycloakManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
return nil, fmt.Errorf("method CreateUser not implemented")
}
// GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list.
func (km *KeycloakManager) GetUserByEmail(email string) ([]*UserData, error) {
func (km *KeycloakManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
q := url.Values{}
q.Add("email", email)
q.Add("exact", "true")
body, err := km.get("users", q)
body, err := km.get(ctx, "users", q)
if err != nil {
return nil, err
}
@ -240,8 +241,8 @@ func (km *KeycloakManager) GetUserByEmail(email string) ([]*UserData, error) {
}
// GetUserDataByID requests user data from keycloak via ID.
func (km *KeycloakManager) GetUserDataByID(userID string, _ AppMetadata) (*UserData, error) {
body, err := km.get("users/"+userID, nil)
func (km *KeycloakManager) GetUserDataByID(ctx context.Context, userID string, _ AppMetadata) (*UserData, error) {
body, err := km.get(ctx, "users/"+userID, nil)
if err != nil {
return nil, err
}
@ -260,8 +261,8 @@ func (km *KeycloakManager) GetUserDataByID(userID string, _ AppMetadata) (*UserD
}
// GetAccount returns all the users for a given account profile.
func (km *KeycloakManager) GetAccount(accountID string) ([]*UserData, error) {
profiles, err := km.fetchAllUserProfiles()
func (km *KeycloakManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
profiles, err := km.fetchAllUserProfiles(ctx)
if err != nil {
return nil, err
}
@ -283,8 +284,8 @@ func (km *KeycloakManager) GetAccount(accountID string) ([]*UserData, error) {
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (km *KeycloakManager) GetAllAccounts() (map[string][]*UserData, error) {
profiles, err := km.fetchAllUserProfiles()
func (km *KeycloakManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
profiles, err := km.fetchAllUserProfiles(ctx)
if err != nil {
return nil, err
}
@ -303,19 +304,19 @@ func (km *KeycloakManager) GetAllAccounts() (map[string][]*UserData, error) {
}
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
func (km *KeycloakManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
func (km *KeycloakManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
return nil
}
// InviteUserByID resend invitations to users who haven't activated,
// their accounts prior to the expiration period.
func (km *KeycloakManager) InviteUserByID(_ string) error {
func (km *KeycloakManager) InviteUserByID(_ context.Context, _ string) error {
return fmt.Errorf("method InviteUserByID not implemented")
}
// DeleteUser from Keycloak by user ID.
func (km *KeycloakManager) DeleteUser(userID string) error {
jwtToken, err := km.credentials.Authenticate()
func (km *KeycloakManager) DeleteUser(ctx context.Context, userID string) error {
jwtToken, err := km.credentials.Authenticate(ctx)
if err != nil {
return err
}
@ -353,8 +354,8 @@ func (km *KeycloakManager) DeleteUser(userID string) error {
return nil
}
func (km *KeycloakManager) fetchAllUserProfiles() ([]keycloakProfile, error) {
totalUsers, err := km.totalUsersCount()
func (km *KeycloakManager) fetchAllUserProfiles(ctx context.Context) ([]keycloakProfile, error) {
totalUsers, err := km.totalUsersCount(ctx)
if err != nil {
return nil, err
}
@ -362,7 +363,7 @@ func (km *KeycloakManager) fetchAllUserProfiles() ([]keycloakProfile, error) {
q := url.Values{}
q.Add("max", fmt.Sprint(*totalUsers))
body, err := km.get("users", q)
body, err := km.get(ctx, "users", q)
if err != nil {
return nil, err
}
@ -377,8 +378,8 @@ func (km *KeycloakManager) fetchAllUserProfiles() ([]keycloakProfile, error) {
}
// get perform Get requests.
func (km *KeycloakManager) get(resource string, q url.Values) ([]byte, error) {
jwtToken, err := km.credentials.Authenticate()
func (km *KeycloakManager) get(ctx context.Context, resource string, q url.Values) ([]byte, error) {
jwtToken, err := km.credentials.Authenticate(ctx)
if err != nil {
return nil, err
}
@ -414,8 +415,8 @@ func (km *KeycloakManager) get(resource string, q url.Values) ([]byte, error) {
// totalUsersCount returns the total count of all user created.
// Used when fetching all registered accounts with pagination.
func (km *KeycloakManager) totalUsersCount() (*int, error) {
body, err := km.get("users/count", nil)
func (km *KeycloakManager) totalUsersCount(ctx context.Context) (*int, error) {
body, err := km.get(ctx, "users/count", nil)
if err != nil {
return nil, err
}

View File

@ -1,6 +1,7 @@
package idp
import (
"context"
"fmt"
"io"
"strings"
@ -128,7 +129,7 @@ func TestKeycloakRequestJWTToken(t *testing.T) {
helper: testCase.helper,
}
resp, err := creds.requestJWTToken()
resp, err := creds.requestJWTToken(context.Background())
if err != nil {
if testCase.expectedFuncExitErrDiff != nil {
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
@ -294,7 +295,7 @@ func TestKeycloakAuthenticate(t *testing.T) {
}
creds.jwtToken.expiresInTime = testCase.inputExpireToken
_, err := creds.Authenticate()
_, err := creds.Authenticate(context.Background())
if err != nil {
if testCase.expectedFuncExitErrDiff != nil {
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")

View File

@ -1,77 +1,79 @@
package idp
import "context"
// MockIDP is a mock implementation of the IDP interface
type MockIDP struct {
UpdateUserAppMetadataFunc func(userId string, appMetadata AppMetadata) error
GetUserDataByIDFunc func(userId string, appMetadata AppMetadata) (*UserData, error)
GetAccountFunc func(accountId string) ([]*UserData, error)
GetAllAccountsFunc func() (map[string][]*UserData, error)
CreateUserFunc func(email, name, accountID, invitedByEmail string) (*UserData, error)
GetUserByEmailFunc func(email string) ([]*UserData, error)
InviteUserByIDFunc func(userID string) error
DeleteUserFunc func(userID string) error
UpdateUserAppMetadataFunc func(ctx context.Context, userId string, appMetadata AppMetadata) error
GetUserDataByIDFunc func(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error)
GetAccountFunc func(ctx context.Context, accountId string) ([]*UserData, error)
GetAllAccountsFunc func(ctx context.Context) (map[string][]*UserData, error)
CreateUserFunc func(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error)
GetUserByEmailFunc func(ctx context.Context, email string) ([]*UserData, error)
InviteUserByIDFunc func(ctx context.Context, userID string) error
DeleteUserFunc func(ctx context.Context, userID string) error
}
// UpdateUserAppMetadata is a mock implementation of the IDP interface UpdateUserAppMetadata method
func (m *MockIDP) UpdateUserAppMetadata(userId string, appMetadata AppMetadata) error {
func (m *MockIDP) UpdateUserAppMetadata(ctx context.Context, userId string, appMetadata AppMetadata) error {
if m.UpdateUserAppMetadataFunc != nil {
return m.UpdateUserAppMetadataFunc(userId, appMetadata)
return m.UpdateUserAppMetadataFunc(ctx, userId, appMetadata)
}
return nil
}
// GetUserDataByID is a mock implementation of the IDP interface GetUserDataByID method
func (m *MockIDP) GetUserDataByID(userId string, appMetadata AppMetadata) (*UserData, error) {
func (m *MockIDP) GetUserDataByID(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error) {
if m.GetUserDataByIDFunc != nil {
return m.GetUserDataByIDFunc(userId, appMetadata)
return m.GetUserDataByIDFunc(ctx, userId, appMetadata)
}
return nil, nil
}
// GetAccount is a mock implementation of the IDP interface GetAccount method
func (m *MockIDP) GetAccount(accountId string) ([]*UserData, error) {
func (m *MockIDP) GetAccount(ctx context.Context, accountId string) ([]*UserData, error) {
if m.GetAccountFunc != nil {
return m.GetAccountFunc(accountId)
return m.GetAccountFunc(ctx, accountId)
}
return nil, nil
}
// GetAllAccounts is a mock implementation of the IDP interface GetAllAccounts method
func (m *MockIDP) GetAllAccounts() (map[string][]*UserData, error) {
func (m *MockIDP) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
if m.GetAllAccountsFunc != nil {
return m.GetAllAccountsFunc()
return m.GetAllAccountsFunc(ctx)
}
return nil, nil
}
// CreateUser is a mock implementation of the IDP interface CreateUser method
func (m *MockIDP) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) {
func (m *MockIDP) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
if m.CreateUserFunc != nil {
return m.CreateUserFunc(email, name, accountID, invitedByEmail)
return m.CreateUserFunc(ctx, email, name, accountID, invitedByEmail)
}
return nil, nil
}
// GetUserByEmail is a mock implementation of the IDP interface GetUserByEmail method
func (m *MockIDP) GetUserByEmail(email string) ([]*UserData, error) {
func (m *MockIDP) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
if m.GetUserByEmailFunc != nil {
return m.GetUserByEmailFunc(email)
return m.GetUserByEmailFunc(ctx, email)
}
return nil, nil
}
// InviteUserByID is a mock implementation of the IDP interface InviteUserByID method
func (m *MockIDP) InviteUserByID(userID string) error {
func (m *MockIDP) InviteUserByID(ctx context.Context, userID string) error {
if m.InviteUserByIDFunc != nil {
return m.InviteUserByIDFunc(userID)
return m.InviteUserByIDFunc(ctx, userID)
}
return nil
}
// DeleteUser is a mock implementation of the IDP interface DeleteUser method
func (m *MockIDP) DeleteUser(userID string) error {
func (m *MockIDP) DeleteUser(ctx context.Context, userID string) error {
if m.DeleteUserFunc != nil {
return m.DeleteUserFunc(userID)
return m.DeleteUserFunc(ctx, userID)
}
return nil
}

View File

@ -94,17 +94,17 @@ func NewOktaManager(config OktaClientConfig, appMetrics telemetry.AppMetrics) (*
}
// Authenticate retrieves access token to use the okta user API.
func (oc *OktaCredentials) Authenticate() (JWTToken, error) {
func (oc *OktaCredentials) Authenticate(_ context.Context) (JWTToken, error) {
return JWTToken{}, nil
}
// CreateUser creates a new user in okta Idp and sends an invitation.
func (om *OktaManager) CreateUser(_, _, _, _ string) (*UserData, error) {
func (om *OktaManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
return nil, fmt.Errorf("method CreateUser not implemented")
}
// GetUserDataByID requests user data from keycloak via ID.
func (om *OktaManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
func (om *OktaManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
user, resp, err := om.client.User.GetUser(context.Background(), userID)
if err != nil {
return nil, err
@ -132,7 +132,7 @@ func (om *OktaManager) GetUserDataByID(userID string, appMetadata AppMetadata) (
// GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list.
func (om *OktaManager) GetUserByEmail(email string) ([]*UserData, error) {
func (om *OktaManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) {
user, resp, err := om.client.User.GetUser(context.Background(), url.QueryEscape(email))
if err != nil {
return nil, err
@ -160,7 +160,7 @@ func (om *OktaManager) GetUserByEmail(email string) ([]*UserData, error) {
}
// GetAccount returns all the users for a given profile.
func (om *OktaManager) GetAccount(accountID string) ([]*UserData, error) {
func (om *OktaManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) {
users, err := om.getAllUsers()
if err != nil {
return nil, err
@ -180,7 +180,7 @@ func (om *OktaManager) GetAccount(accountID string) ([]*UserData, error) {
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (om *OktaManager) GetAllAccounts() (map[string][]*UserData, error) {
func (om *OktaManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) {
users, err := om.getAllUsers()
if err != nil {
return nil, err
@ -242,18 +242,18 @@ func (om *OktaManager) getAllUsers() ([]*UserData, error) {
}
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
func (om *OktaManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error {
func (om *OktaManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
return nil
}
// InviteUserByID resend invitations to users who haven't activated,
// their accounts prior to the expiration period.
func (om *OktaManager) InviteUserByID(_ string) error {
func (om *OktaManager) InviteUserByID(_ context.Context, _ string) error {
return fmt.Errorf("method InviteUserByID not implemented")
}
// DeleteUser from Okta
func (om *OktaManager) DeleteUser(userID string) error {
func (om *OktaManager) DeleteUser(_ context.Context, userID string) error {
resp, err := om.client.User.DeactivateOrDeleteUser(context.Background(), userID, nil)
if err != nil {
return err

View File

@ -1,6 +1,7 @@
package idp
import (
"context"
"fmt"
"io"
"net/http"
@ -149,7 +150,7 @@ func (zc *ZitadelCredentials) jwtStillValid() bool {
}
// requestJWTToken performs request to get jwt token.
func (zc *ZitadelCredentials) requestJWTToken() (*http.Response, error) {
func (zc *ZitadelCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
data := url.Values{}
data.Set("client_id", zc.clientConfig.ClientID)
data.Set("client_secret", zc.clientConfig.ClientSecret)
@ -163,7 +164,7 @@ func (zc *ZitadelCredentials) requestJWTToken() (*http.Response, error) {
}
req.Header.Add("content-type", "application/x-www-form-urlencoded")
log.Debug("requesting new jwt token for zitadel idp manager")
log.WithContext(ctx).Debug("requesting new jwt token for zitadel idp manager")
resp, err := zc.httpClient.Do(req)
if err != nil {
@ -215,7 +216,7 @@ func (zc *ZitadelCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JW
}
// Authenticate retrieves access token to use the Zitadel Management API.
func (zc *ZitadelCredentials) Authenticate() (JWTToken, error) {
func (zc *ZitadelCredentials) Authenticate(ctx context.Context) (JWTToken, error) {
zc.mux.Lock()
defer zc.mux.Unlock()
@ -229,7 +230,7 @@ func (zc *ZitadelCredentials) Authenticate() (JWTToken, error) {
return zc.jwtToken, nil
}
resp, err := zc.requestJWTToken()
resp, err := zc.requestJWTToken(ctx)
if err != nil {
return zc.jwtToken, err
}
@ -246,7 +247,7 @@ func (zc *ZitadelCredentials) Authenticate() (JWTToken, error) {
}
// CreateUser creates a new user in zitadel Idp and sends an invite via Zitadel.
func (zm *ZitadelManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) {
func (zm *ZitadelManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
firstLast := strings.SplitN(name, " ", 2)
var addUser = map[string]any{
@ -269,7 +270,7 @@ func (zm *ZitadelManager) CreateUser(email, name, accountID, invitedByEmail stri
return nil, err
}
body, err := zm.post("users/human/_import", string(payload))
body, err := zm.post(ctx, "users/human/_import", string(payload))
if err != nil {
return nil, err
}
@ -300,7 +301,7 @@ func (zm *ZitadelManager) CreateUser(email, name, accountID, invitedByEmail stri
// GetUserByEmail searches users with a given email.
// If no users have been found, this function returns an empty list.
func (zm *ZitadelManager) GetUserByEmail(email string) ([]*UserData, error) {
func (zm *ZitadelManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
searchByEmail := zitadelAttributes{
"queries": {
{
@ -316,7 +317,7 @@ func (zm *ZitadelManager) GetUserByEmail(email string) ([]*UserData, error) {
return nil, err
}
body, err := zm.post("users/_search", string(payload))
body, err := zm.post(ctx, "users/_search", string(payload))
if err != nil {
return nil, err
}
@ -340,8 +341,8 @@ func (zm *ZitadelManager) GetUserByEmail(email string) ([]*UserData, error) {
}
// GetUserDataByID requests user data from zitadel via ID.
func (zm *ZitadelManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
body, err := zm.get("users/"+userID, nil)
func (zm *ZitadelManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
body, err := zm.get(ctx, "users/"+userID, nil)
if err != nil {
return nil, err
}
@ -363,8 +364,8 @@ func (zm *ZitadelManager) GetUserDataByID(userID string, appMetadata AppMetadata
}
// GetAccount returns all the users for a given profile.
func (zm *ZitadelManager) GetAccount(accountID string) ([]*UserData, error) {
body, err := zm.post("users/_search", "")
func (zm *ZitadelManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
body, err := zm.post(ctx, "users/_search", "")
if err != nil {
return nil, err
}
@ -392,8 +393,8 @@ func (zm *ZitadelManager) GetAccount(accountID string) ([]*UserData, error) {
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (zm *ZitadelManager) GetAllAccounts() (map[string][]*UserData, error) {
body, err := zm.post("users/_search", "")
func (zm *ZitadelManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
body, err := zm.post(ctx, "users/_search", "")
if err != nil {
return nil, err
}
@ -419,7 +420,7 @@ func (zm *ZitadelManager) GetAllAccounts() (map[string][]*UserData, error) {
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
// Metadata values are base64 encoded.
func (zm *ZitadelManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
func (zm *ZitadelManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
return nil
}
@ -429,7 +430,7 @@ type inviteUserRequest struct {
// InviteUserByID resend invitations to users who haven't activated,
// their accounts prior to the expiration period.
func (zm *ZitadelManager) InviteUserByID(userID string) error {
func (zm *ZitadelManager) InviteUserByID(ctx context.Context, userID string) error {
inviteUser := inviteUserRequest{
Email: userID,
}
@ -440,14 +441,14 @@ func (zm *ZitadelManager) InviteUserByID(userID string) error {
}
// don't care about the body in the response
_, err = zm.post(fmt.Sprintf("users/%s/_resend_initialization", userID), string(payload))
_, err = zm.post(ctx, fmt.Sprintf("users/%s/_resend_initialization", userID), string(payload))
return err
}
// DeleteUser from Zitadel
func (zm *ZitadelManager) DeleteUser(userID string) error {
func (zm *ZitadelManager) DeleteUser(ctx context.Context, userID string) error {
resource := fmt.Sprintf("users/%s", userID)
if err := zm.delete(resource); err != nil {
if err := zm.delete(ctx, resource); err != nil {
return err
}
@ -459,8 +460,8 @@ func (zm *ZitadelManager) DeleteUser(userID string) error {
}
// post perform Post requests.
func (zm *ZitadelManager) post(resource string, body string) ([]byte, error) {
jwtToken, err := zm.credentials.Authenticate()
func (zm *ZitadelManager) post(ctx context.Context, resource string, body string) ([]byte, error) {
jwtToken, err := zm.credentials.Authenticate(ctx)
if err != nil {
return nil, err
}
@ -495,8 +496,8 @@ func (zm *ZitadelManager) post(resource string, body string) ([]byte, error) {
}
// delete perform Delete requests.
func (zm *ZitadelManager) delete(resource string) error {
jwtToken, err := zm.credentials.Authenticate()
func (zm *ZitadelManager) delete(ctx context.Context, resource string) error {
jwtToken, err := zm.credentials.Authenticate(ctx)
if err != nil {
return err
}
@ -531,8 +532,8 @@ func (zm *ZitadelManager) delete(resource string) error {
}
// get perform Get requests.
func (zm *ZitadelManager) get(resource string, q url.Values) ([]byte, error) {
jwtToken, err := zm.credentials.Authenticate()
func (zm *ZitadelManager) get(ctx context.Context, resource string, q url.Values) ([]byte, error) {
jwtToken, err := zm.credentials.Authenticate(ctx)
if err != nil {
return nil, err
}

View File

@ -1,6 +1,7 @@
package idp
import (
"context"
"fmt"
"io"
"strings"
@ -108,7 +109,7 @@ func TestZitadelRequestJWTToken(t *testing.T) {
helper: testCase.helper,
}
resp, err := creds.requestJWTToken()
resp, err := creds.requestJWTToken(context.Background())
if err != nil {
if testCase.expectedFuncExitErrDiff != nil {
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
@ -274,7 +275,7 @@ func TestZitadelAuthenticate(t *testing.T) {
}
creds.jwtToken.expiresInTime = testCase.inputExpireToken
_, err := creds.Authenticate()
_, err := creds.Authenticate(context.Background())
if err != nil {
if testCase.expectedFuncExitErrDiff != nil {
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")

View File

@ -1,9 +1,10 @@
package server
import (
"context"
"errors"
"github.com/google/martian/v3/log"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/account"
)
@ -19,22 +20,22 @@ import (
//
// Returns:
// - error: An error if any occurred during the process, otherwise returns nil
func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error {
ok, err := am.GroupValidation(accountID, groups)
func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error {
ok, err := am.GroupValidation(ctx, accountID, groups)
if err != nil {
log.Debugf("error validating groups: %s", err.Error())
log.WithContext(ctx).Debugf("error validating groups: %s", err.Error())
return err
}
if !ok {
log.Debugf("invalid groups")
log.WithContext(ctx).Debugf("invalid groups")
return errors.New("invalid groups")
}
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
a, err := am.Store.GetAccountByUser(userID)
a, err := am.Store.GetAccountByUser(ctx, userID)
if err != nil {
return err
}
@ -48,14 +49,14 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(accountID strin
a.Settings.Extra = extra
}
extra.IntegratedValidatorGroups = groups
return am.Store.SaveAccount(a)
return am.Store.SaveAccount(ctx, a)
}
func (am *DefaultAccountManager) GroupValidation(accountId string, groups []string) (bool, error) {
func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) {
if len(groups) == 0 {
return true, nil
}
accountsGroups, err := am.ListGroups(accountId)
accountsGroups, err := am.ListGroups(ctx, accountId)
if err != nil {
return false, err
}

View File

@ -1,6 +1,8 @@
package integrated_validator
import (
"context"
"github.com/netbirdio/netbird/management/server/account"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
@ -8,12 +10,12 @@ import (
// IntegratedValidator interface exists to avoid the circle dependencies
type IntegratedValidator interface {
ValidateExtraSettings(newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error
ValidatePeer(update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error)
PreparePeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer
IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error)
ValidateExtraSettings(ctx context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error
ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error)
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer
IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error)
GetValidatedPeers(accountID string, groups map[string]*nbgroup.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error)
PeerDeleted(accountID, peerID string) error
PeerDeleted(ctx context.Context, accountID, peerID string) error
SetPeerInvalidationListener(fn func(accountID string))
Stop()
Stop(ctx context.Context)
}

View File

@ -2,6 +2,7 @@ package jwtclaims
import (
"bytes"
"context"
"crypto/rsa"
"crypto/x509"
"encoding/base64"
@ -69,8 +70,8 @@ type JWTValidator struct {
}
// NewJWTValidator constructor
func NewJWTValidator(issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (*JWTValidator, error) {
keys, err := getPemKeys(keysLocation)
func NewJWTValidator(ctx context.Context, issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (*JWTValidator, error) {
keys, err := getPemKeys(ctx, keysLocation)
if err != nil {
return nil, err
}
@ -102,19 +103,19 @@ func NewJWTValidator(issuer string, audienceList []string, keysLocation string,
lock.Lock()
defer lock.Unlock()
refreshedKeys, err := getPemKeys(keysLocation)
refreshedKeys, err := getPemKeys(ctx, keysLocation)
if err != nil {
log.Debugf("cannot get JSONWebKey: %v, falling back to old keys", err)
log.WithContext(ctx).Debugf("cannot get JSONWebKey: %v, falling back to old keys", err)
refreshedKeys = keys
}
log.Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC())
log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC())
keys = refreshedKeys
}
}
cert, err := getPemCert(token, keys)
cert, err := getPemCert(ctx, token, keys)
if err != nil {
return nil, err
}
@ -136,19 +137,19 @@ func NewJWTValidator(issuer string, audienceList []string, keysLocation string,
}
// ValidateAndParse validates the token and returns the parsed token
func (m *JWTValidator) ValidateAndParse(token string) (*jwt.Token, error) {
func (m *JWTValidator) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) {
// If the token is empty...
if token == "" {
// Check if it was required
if m.options.CredentialsOptional {
log.Debugf("no credentials found (CredentialsOptional=true)")
log.WithContext(ctx).Debugf("no credentials found (CredentialsOptional=true)")
// No error, just no token (and that is ok given that CredentialsOptional is true)
return nil, nil //nolint:nilnil
}
// If we get here, the required token is missing
errorMsg := "required authorization token not found"
log.Debugf(" Error: No credentials found (CredentialsOptional=false)")
log.WithContext(ctx).Debugf(" Error: No credentials found (CredentialsOptional=false)")
return nil, fmt.Errorf(errorMsg)
}
@ -157,7 +158,7 @@ func (m *JWTValidator) ValidateAndParse(token string) (*jwt.Token, error) {
// Check if there was an error in parsing...
if err != nil {
log.Errorf("error parsing token: %v", err)
log.WithContext(ctx).Errorf("error parsing token: %v", err)
return nil, fmt.Errorf("Error parsing token: %w", err)
}
@ -165,14 +166,14 @@ func (m *JWTValidator) ValidateAndParse(token string) (*jwt.Token, error) {
errorMsg := fmt.Sprintf("Expected %s signing method but token specified %s",
m.options.SigningMethod.Alg(),
parsedToken.Header["alg"])
log.Debugf("error validating token algorithm: %s", errorMsg)
log.WithContext(ctx).Debugf("error validating token algorithm: %s", errorMsg)
return nil, fmt.Errorf("error validating token algorithm: %s", errorMsg)
}
// Check if the parsed token is valid...
if !parsedToken.Valid {
errorMsg := "token is invalid"
log.Debugf(errorMsg)
log.WithContext(ctx).Debugf(errorMsg)
return nil, errors.New(errorMsg)
}
@ -184,7 +185,7 @@ func (jwks *Jwks) stillValid() bool {
return !jwks.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.expiresInTime)
}
func getPemKeys(keysLocation string) (*Jwks, error) {
func getPemKeys(ctx context.Context, keysLocation string) (*Jwks, error) {
resp, err := http.Get(keysLocation)
if err != nil {
return nil, err
@ -198,13 +199,13 @@ func getPemKeys(keysLocation string) (*Jwks, error) {
}
cacheControlHeader := resp.Header.Get("Cache-Control")
expiresIn := getMaxAgeFromCacheHeader(cacheControlHeader)
expiresIn := getMaxAgeFromCacheHeader(ctx, cacheControlHeader)
jwks.expiresInTime = time.Now().Add(time.Duration(expiresIn) * time.Second)
return jwks, err
}
func getPemCert(token *jwt.Token, jwks *Jwks) (string, error) {
func getPemCert(ctx context.Context, token *jwt.Token, jwks *Jwks) (string, error) {
// todo as we load the jkws when the server is starting, we should build a JKS map with the pem cert at the boot time
cert := ""
@ -217,7 +218,7 @@ func getPemCert(token *jwt.Token, jwks *Jwks) (string, error) {
cert = "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----"
return cert, nil
}
log.Debugf("generating validation pem from JWK")
log.WithContext(ctx).Debugf("generating validation pem from JWK")
return generatePemFromJWK(jwks.Keys[k])
}
@ -284,7 +285,7 @@ func convertExponentStringToInt(stringExponent string) (int, error) {
}
// getMaxAgeFromCacheHeader extracts max-age directive from the Cache-Control header
func getMaxAgeFromCacheHeader(cacheControl string) int {
func getMaxAgeFromCacheHeader(ctx context.Context, cacheControl string) int {
// Split into individual directives
directives := strings.Split(cacheControl, ",")
@ -295,7 +296,7 @@ func getMaxAgeFromCacheHeader(cacheControl string) int {
maxAgeStr := strings.TrimPrefix(directive, "max-age=")
maxAge, err := strconv.Atoi(maxAgeStr)
if err != nil {
log.Debugf("error parsing max-age: %v", err)
log.WithContext(ctx).Debugf("error parsing max-age: %v", err)
return 0
}

View File

@ -406,7 +406,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error)
return nil, "", err
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, cleanUp, err := NewTestStoreFromJson(config.Datadir)
store, cleanUp, err := NewTestStoreFromJson(context.Background(), config.Datadir)
if err != nil {
return nil, "", err
}
@ -414,7 +414,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error)
peersUpdateManager := NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
accountManager, err := BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted",
accountManager, err := BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted",
eventStore, nil, false, MocIntegratedValidator{})
if err != nil {
return nil, "", err
@ -422,7 +422,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error)
turnManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
ephemeralMgr := NewEphemeralManager(store, accountManager)
mgmtServer, err := NewServer(config, accountManager, peersUpdateManager, turnManager, nil, ephemeralMgr)
mgmtServer, err := NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, ephemeralMgr)
if err != nil {
return nil, "", err
}

View File

@ -451,11 +451,11 @@ var _ = Describe("Management service", func() {
type MocIntegratedValidator struct {
}
func (a MocIntegratedValidator) ValidateExtraSettings(newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
return nil
}
func (a MocIntegratedValidator) ValidatePeer(update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) {
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) {
return update, nil
}
@ -467,15 +467,15 @@ func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[s
return validatedPeers, nil
}
func (MocIntegratedValidator) PreparePeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer {
func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer {
return peer
}
func (MocIntegratedValidator) IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) {
func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) {
return false, false, nil
}
func (MocIntegratedValidator) PeerDeleted(_, _ string) error {
func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error {
return nil
}
@ -483,7 +483,7 @@ func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)
}
func (MocIntegratedValidator) Stop() {}
func (MocIntegratedValidator) Stop(_ context.Context) {}
func loginPeerWithValidSetupKey(serverPubKey wgtypes.Key, key wgtypes.Key, client mgmtProto.ManagementServiceClient) *mgmtProto.LoginResponse {
defer GinkgoRecover()
@ -534,20 +534,20 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) {
Expect(err).NotTo(HaveOccurred())
s := grpc.NewServer()
store, _, err := server.NewTestStoreFromJson(config.Datadir)
store, _, err := server.NewTestStoreFromJson(context.Background(), config.Datadir)
if err != nil {
log.Fatalf("failed creating a store: %s: %v", config.Datadir, err)
}
peersUpdateManager := server.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted",
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted",
eventStore, nil, false, MocIntegratedValidator{})
if err != nil {
log.Fatalf("failed creating a manager: %v", err)
}
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
Expect(err).NotTo(HaveOccurred())
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
go func() {

View File

@ -46,7 +46,7 @@ type properties map[string]interface{}
// DataSource metric data source
type DataSource interface {
GetAllAccounts() []*server.Account
GetAllAccounts(ctx context.Context) []*server.Account
GetStoreEngine() server.StoreEngine
}
@ -81,29 +81,29 @@ func NewWorker(ctx context.Context, id string, dataSource DataSource, connManage
}
// Run runs the metrics worker
func (w *Worker) Run() {
func (w *Worker) Run(ctx context.Context) {
pushTicker := time.NewTicker(defaultPushInterval)
for {
select {
case <-w.ctx.Done():
return
case <-pushTicker.C:
err := w.sendMetrics()
err := w.sendMetrics(ctx)
if err != nil {
log.Error(err)
log.WithContext(ctx).Error(err)
}
w.lastRun = time.Now()
}
}
}
func (w *Worker) sendMetrics() error {
func (w *Worker) sendMetrics(ctx context.Context) error {
apiKey, err := getAPIKey(w.ctx)
if err != nil {
return err
}
payload := w.generatePayload(apiKey)
payload := w.generatePayload(ctx, apiKey)
payloadString, err := buildMetricsPayload(payload)
if err != nil {
@ -125,7 +125,7 @@ func (w *Worker) sendMetrics() error {
defer func() {
err = jobResp.Body.Close()
if err != nil {
log.Errorf("error while closing update metrics response body: %v", err)
log.WithContext(ctx).Errorf("error while closing update metrics response body: %v", err)
}
}()
@ -133,15 +133,15 @@ func (w *Worker) sendMetrics() error {
return fmt.Errorf("unable to push anonymous metrics, got statusCode %d", jobResp.StatusCode)
}
log.Infof("sent anonymous metrics, next push will happen in %s. "+
log.WithContext(ctx).Infof("sent anonymous metrics, next push will happen in %s. "+
"You can disable these metrics by running with flag --disable-anonymous-metrics,"+
" see more information at https://netbird.io/docs/FAQ/metrics-collection", defaultPushInterval)
return nil
}
func (w *Worker) generatePayload(apiKey string) pushPayload {
properties := w.generateProperties()
func (w *Worker) generatePayload(ctx context.Context, apiKey string) pushPayload {
properties := w.generateProperties(ctx)
return pushPayload{
APIKey: apiKey,
@ -152,7 +152,7 @@ func (w *Worker) generatePayload(apiKey string) pushPayload {
}
}
func (w *Worker) generateProperties() properties {
func (w *Worker) generateProperties(ctx context.Context) properties {
var (
uptime float64
accounts int
@ -192,7 +192,7 @@ func (w *Worker) generateProperties() properties {
connections := w.connManager.GetAllConnectedPeers()
version = nbversion.NetbirdVersion()
for _, account := range w.dataSource.GetAllAccounts() {
for _, account := range w.dataSource.GetAllAccounts(ctx) {
accounts++
if account.Settings.PeerLoginExpirationEnabled {
@ -342,7 +342,7 @@ func getAPIKey(ctx context.Context) (string, error) {
defer func() {
err = response.Body.Close()
if err != nil {
log.Errorf("error while closing metrics token response body: %v", err)
log.WithContext(ctx).Errorf("error while closing metrics token response body: %v", err)
}
}()

View File

@ -1,6 +1,7 @@
package metrics
import (
"context"
"testing"
nbdns "github.com/netbirdio/netbird/dns"
@ -21,7 +22,7 @@ func (mockDatasource) GetAllConnectedPeers() map[string]struct{} {
}
// GetAllAccounts returns a list of *server.Account for use in tests with predefined information
func (mockDatasource) GetAllAccounts() []*server.Account {
func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account {
return []*server.Account{
{
Id: "1",
@ -188,7 +189,7 @@ func TestGenerateProperties(t *testing.T) {
connManager: ds,
}
properties := worker.generateProperties()
properties := worker.generateProperties(context.Background())
if properties["accounts"] != 2 {
t.Errorf("expected 2 accounts, got %d", properties["accounts"])

View File

@ -1,6 +1,7 @@
package migration
import (
"context"
"database/sql"
"encoding/gob"
"encoding/json"
@ -16,7 +17,7 @@ import (
// MigrateFieldFromGobToJSON migrates a column from Gob encoding to JSON encoding.
// T is the type of the model that contains the field to be migrated.
// S is the type of the field to be migrated.
func MigrateFieldFromGobToJSON[T any, S any](db *gorm.DB, fieldName string) error {
func MigrateFieldFromGobToJSON[T any, S any](ctx context.Context, db *gorm.DB, fieldName string) error {
oldColumnName := fieldName
newColumnName := fieldName + "_tmp"
@ -24,7 +25,7 @@ func MigrateFieldFromGobToJSON[T any, S any](db *gorm.DB, fieldName string) erro
var model T
if !db.Migrator().HasTable(&model) {
log.Debugf("Table for %T does not exist, no migration needed", model)
log.WithContext(ctx).Debugf("Table for %T does not exist, no migration needed", model)
return nil
}
@ -38,7 +39,7 @@ func MigrateFieldFromGobToJSON[T any, S any](db *gorm.DB, fieldName string) erro
var sqliteItem sql.NullString
if err := db.Model(model).Select(oldColumnName).First(&sqliteItem).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Debugf("No records in table %s, no migration needed", tableName)
log.WithContext(ctx).Debugf("No records in table %s, no migration needed", tableName)
return nil
}
return fmt.Errorf("fetch first record: %w", err)
@ -51,7 +52,7 @@ func MigrateFieldFromGobToJSON[T any, S any](db *gorm.DB, fieldName string) erro
err = json.Unmarshal([]byte(item), &js)
// if the item is JSON parsable or an empty string it can not be gob encoded
if err == nil || !errors.As(err, &syntaxError) || item == "" {
log.Debugf("No migration needed for %s, %s", tableName, fieldName)
log.WithContext(ctx).Debugf("No migration needed for %s, %s", tableName, fieldName)
return nil
}
@ -99,14 +100,14 @@ func MigrateFieldFromGobToJSON[T any, S any](db *gorm.DB, fieldName string) erro
return err
}
log.Infof("Migration of %s.%s from gob to json completed", tableName, fieldName)
log.WithContext(ctx).Infof("Migration of %s.%s from gob to json completed", tableName, fieldName)
return nil
}
// MigrateNetIPFieldFromBlobToJSON migrates a Net IP column from Blob encoding to JSON encoding.
// T is the type of the model that contains the field to be migrated.
func MigrateNetIPFieldFromBlobToJSON[T any](db *gorm.DB, fieldName string, indexName string) error {
func MigrateNetIPFieldFromBlobToJSON[T any](ctx context.Context, db *gorm.DB, fieldName string, indexName string) error {
oldColumnName := fieldName
newColumnName := fieldName + "_tmp"
@ -138,7 +139,7 @@ func MigrateNetIPFieldFromBlobToJSON[T any](db *gorm.DB, fieldName string, index
var syntaxError *json.SyntaxError
err = json.Unmarshal([]byte(item.String), &js)
if err == nil || !errors.As(err, &syntaxError) {
log.Debugf("No migration needed for %s, %s", tableName, fieldName)
log.WithContext(ctx).Debugf("No migration needed for %s, %s", tableName, fieldName)
return nil
}
}
@ -169,7 +170,7 @@ func MigrateNetIPFieldFromBlobToJSON[T any](db *gorm.DB, fieldName string, index
columnIpValue := net.IP(blobValue)
if net.ParseIP(columnIpValue.String()) == nil {
log.Debugf("failed to parse %s as ip, fallback to ipv6 loopback", oldColumnName)
log.WithContext(ctx).Debugf("failed to parse %s as ip, fallback to ipv6 loopback", oldColumnName)
columnIpValue = net.IPv6loopback
}

View File

@ -1,6 +1,7 @@
package migration_test
import (
"context"
"encoding/gob"
"net"
"strings"
@ -30,7 +31,7 @@ func setupDatabase(t *testing.T) *gorm.DB {
func TestMigrateFieldFromGobToJSON_EmptyDB(t *testing.T) {
db := setupDatabase(t)
err := migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](db, "network_net")
err := migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](context.Background(), db, "network_net")
require.NoError(t, err, "Migration should not fail for an empty database")
}
@ -63,7 +64,7 @@ func TestMigrateFieldFromGobToJSON_WithGobData(t *testing.T) {
err = gob.NewDecoder(strings.NewReader(gobStr)).Decode(&ipnet)
require.NoError(t, err, "Failed to decode Gob data")
err = migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](db, "network_net")
err = migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](context.Background(), db, "network_net")
require.NoError(t, err, "Migration should not fail with Gob data")
var jsonStr string
@ -83,7 +84,7 @@ func TestMigrateFieldFromGobToJSON_WithJSONData(t *testing.T) {
err = db.Save(&server.Account{Network: &server.Network{Net: *ipnet}}).Error
require.NoError(t, err, "Failed to insert JSON data")
err = migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](db, "network_net")
err = migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](context.Background(), db, "network_net")
require.NoError(t, err, "Migration should not fail with JSON data")
var jsonStr string
@ -93,7 +94,7 @@ func TestMigrateFieldFromGobToJSON_WithJSONData(t *testing.T) {
func TestMigrateNetIPFieldFromBlobToJSON_EmptyDB(t *testing.T) {
db := setupDatabase(t)
err := migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "ip", "idx_peers_account_id_ip")
err := migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](context.Background(), db, "ip", "idx_peers_account_id_ip")
require.NoError(t, err, "Migration should not fail for an empty database")
}
@ -130,7 +131,7 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) {
err = db.Model(&nbpeer.Peer{}).Select("location_connection_ip").First(&blobValue).Error
assert.NoError(t, err, "Failed to fetch blob data")
err = migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "location_connection_ip", "")
err = migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](context.Background(), db, "location_connection_ip", "")
require.NoError(t, err, "Migration should not fail with net.IP blob data")
var jsonStr string
@ -152,7 +153,7 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) {
).Error
require.NoError(t, err, "Failed to insert JSON data")
err = migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "location_connection_ip", "")
err = migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](context.Background(), db, "location_connection_ip", "")
require.NoError(t, err, "Migration should not fail with net.IP JSON data")
var jsonStr string

View File

@ -1,6 +1,7 @@
package mock_server
import (
"context"
"net"
"net/netip"
"time"
@ -21,94 +22,95 @@ import (
)
type MockAccountManager struct {
GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error)
CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType,
GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*server.Account, error)
CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType,
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
ListUsersFunc func(accountID string) ([]*server.User, error)
GetPeersFunc func(accountID, userID string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(peerKey string, connected bool, realIP net.IP) error
SyncAndMarkPeerFunc func(peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
DeletePeerFunc func(accountID, peerKey, userID string) error
GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error)
GetPeerNetworkFunc func(peerKey string) (*server.Network, error)
AddPeerFunc func(setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
GetGroupFunc func(accountID, groupID, userID string) (*group.Group, error)
GetAllGroupsFunc func(accountID, userID string) ([]*group.Group, error)
GetGroupByNameFunc func(accountID, groupName string) (*group.Group, error)
SaveGroupFunc func(accountID, userID string, group *group.Group) error
DeleteGroupFunc func(accountID, userId, groupID string) error
ListGroupsFunc func(accountID string) ([]*group.Group, error)
GroupAddPeerFunc func(accountID, groupID, peerID string) error
GroupDeletePeerFunc func(accountID, groupID, peerID string) error
DeleteRuleFunc func(accountID, ruleID, userID string) error
GetPolicyFunc func(accountID, policyID, userID string) (*server.Policy, error)
SavePolicyFunc func(accountID, userID string, policy *server.Policy) error
DeletePolicyFunc func(accountID, policyID, userID string) error
ListPoliciesFunc func(accountID, userID string) ([]*server.Policy, error)
GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error)
GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
MarkPATUsedFunc func(pat string) error
UpdatePeerMetaFunc func(peerID string, meta nbpeer.PeerSystemMeta) error
UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error
UpdatePeerFunc func(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
CreateRouteFunc func(accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
GetRouteFunc func(accountID string, routeID route.ID, userID string) (*route.Route, error)
SaveRouteFunc func(accountID string, userID string, route *route.Route) error
DeleteRouteFunc func(accountID string, routeID route.ID, userID string) error
ListRoutesFunc func(accountID, userID string) ([]*route.Route, error)
SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error)
SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error)
SaveOrAddUserFunc func(accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error)
DeleteUserFunc func(accountID string, initiatorUserID string, targetUserID string) error
CreatePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
DeletePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) error
GetPATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
GetAllPATsFunc func(accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error)
GetNameServerGroupFunc func(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
DeleteNameServerGroupFunc func(accountID, nsGroupID, userID string) error
ListNameServerGroupsFunc func(accountID string, userID string) ([]*nbdns.NameServerGroup, error)
CreateUserFunc func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)
CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error
DeleteAccountFunc func(accountID, userID string) error
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error)
GetAccountByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (*server.Account, error)
GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
SyncAndMarkPeerFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*server.NetworkMap, error)
GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*server.Network, error)
AddPeerFunc func(ctx context.Context, setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*group.Group, error)
GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*group.Group, error)
GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*group.Group, error)
SaveGroupFunc func(ctx context.Context, accountID, userID string, group *group.Group) error
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error)
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error)
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) error
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error)
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error)
GetAccountFromPATFunc func(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
MarkPATUsedFunc func(ctx context.Context, pat string) error
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
UpdatePeerSSHKeyFunc func(ctx context.Context, peerID string, sshKey string) error
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error
DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error
ListRoutesFunc func(ctx context.Context, accountID, userID string) ([]*route.Route, error)
SaveSetupKeyFunc func(ctx context.Context, accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
ListSetupKeysFunc func(ctx context.Context, accountID, userID string) ([]*server.SetupKey, error)
SaveUserFunc func(ctx context.Context, accountID, userID string, user *server.User) (*server.UserInfo, error)
SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error)
DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error
GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
GetAllPATsFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error)
GetNameServerGroupFunc func(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
CreateNameServerGroupFunc func(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
SaveNameServerGroupFunc func(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error
ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
CreateUserFunc func(ctx context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
GetAccountFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)
CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
DeleteAccountFunc func(ctx context.Context, accountID, userID string) error
GetDNSDomainFunc func() string
StoreEventFunc func(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
GetEventsFunc func(accountID, userID string) ([]*activity.Event, error)
GetDNSSettingsFunc func(accountID, userID string) (*server.DNSSettings, error)
SaveDNSSettingsFunc func(accountID, userID string, dnsSettingsToSave *server.DNSSettings) error
GetPeerFunc func(accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettingsFunc func(accountID, userID string, newSettings *server.Settings) (*server.Account, error)
LoginPeerFunc func(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
SyncPeerFunc func(sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error
StoreEventFunc func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
GetEventsFunc func(ctx context.Context, accountID, userID string) ([]*activity.Event, error)
GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*server.DNSSettings, error)
SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *server.DNSSettings) error
GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error)
LoginPeerFunc func(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
SyncPeerFunc func(ctx context.Context, sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error
GetAllConnectedPeersFunc func() (map[string]struct{}, error)
HasConnectedChannelFunc func(peerID string) bool
GetExternalCacheManagerFunc func() server.ExternalCacheManager
GetPostureChecksFunc func(accountID, postureChecksID, userID string) (*posture.Checks, error)
SavePostureChecksFunc func(accountID, userID string, postureChecks *posture.Checks) error
DeletePostureChecksFunc func(accountID, postureChecksID, userID string) error
ListPostureChecksFunc func(accountID, userID string) ([]*posture.Checks, error)
GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error
DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error
ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
GetIdpManagerFunc func() idp.Manager
UpdateIntegratedValidatorGroupsFunc func(accountID string, userID string, groups []string) error
GroupValidationFunc func(accountId string, groups []string) (bool, error)
SyncPeerMetaFunc func(peerPubKey string, meta nbpeer.PeerSystemMeta) error
UpdateIntegratedValidatorGroupsFunc func(ctx context.Context, accountID string, userID string, groups []string) error
GroupValidationFunc func(ctx context.Context, accountId string, groups []string) (bool, error)
SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error)
}
func (am *MockAccountManager) SyncAndMarkPeer(peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
if am.SyncAndMarkPeerFunc != nil {
return am.SyncAndMarkPeerFunc(peerPubKey, meta, realIP)
return am.SyncAndMarkPeerFunc(ctx, peerPubKey, meta, realIP)
}
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
}
func (am *MockAccountManager) CancelPeerRoutines(peer *nbpeer.Peer) error {
func (am *MockAccountManager) CancelPeerRoutines(_ context.Context, peer *nbpeer.Peer) error {
// TODO implement me
panic("implement me")
}
@ -122,43 +124,43 @@ func (am *MockAccountManager) GetValidatedPeers(account *server.Account) (map[st
}
// GetGroup mock implementation of GetGroup from server.AccountManager interface
func (am *MockAccountManager) GetGroup(accountId, groupID, userID string) (*group.Group, error) {
func (am *MockAccountManager) GetGroup(ctx context.Context, accountId, groupID, userID string) (*group.Group, error) {
if am.GetGroupFunc != nil {
return am.GetGroupFunc(accountId, groupID, userID)
return am.GetGroupFunc(ctx, accountId, groupID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetGroup is not implemented")
}
// GetAllGroups mock implementation of GetAllGroups from server.AccountManager interface
func (am *MockAccountManager) GetAllGroups(accountID, userID string) ([]*group.Group, error) {
func (am *MockAccountManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*group.Group, error) {
if am.GetAllGroupsFunc != nil {
return am.GetAllGroupsFunc(accountID, userID)
return am.GetAllGroupsFunc(ctx, accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetAllGroups is not implemented")
}
// GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface
func (am *MockAccountManager) GetUsersFromAccount(accountID string, userID string) ([]*server.UserInfo, error) {
func (am *MockAccountManager) GetUsersFromAccount(ctx context.Context, accountID string, userID string) ([]*server.UserInfo, error) {
if am.GetUsersFromAccountFunc != nil {
return am.GetUsersFromAccountFunc(accountID, userID)
return am.GetUsersFromAccountFunc(ctx, accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetUsersFromAccount is not implemented")
}
// DeletePeer mock implementation of DeletePeer from server.AccountManager interface
func (am *MockAccountManager) DeletePeer(accountID, peerID, userID string) error {
func (am *MockAccountManager) DeletePeer(ctx context.Context, accountID, peerID, userID string) error {
if am.DeletePeerFunc != nil {
return am.DeletePeerFunc(accountID, peerID, userID)
return am.DeletePeerFunc(ctx, accountID, peerID, userID)
}
return status.Errorf(codes.Unimplemented, "method DeletePeer is not implemented")
}
// GetOrCreateAccountByUser mock implementation of GetOrCreateAccountByUser from server.AccountManager interface
func (am *MockAccountManager) GetOrCreateAccountByUser(
userId, domain string,
ctx context.Context, userId, domain string,
) (*server.Account, error) {
if am.GetOrCreateAccountByUserFunc != nil {
return am.GetOrCreateAccountByUserFunc(userId, domain)
return am.GetOrCreateAccountByUserFunc(ctx, userId, domain)
}
return nil, status.Errorf(
codes.Unimplemented,
@ -168,6 +170,7 @@ func (am *MockAccountManager) GetOrCreateAccountByUser(
// CreateSetupKey mock implementation of CreateSetupKey from server.AccountManager interface
func (am *MockAccountManager) CreateSetupKey(
ctx context.Context,
accountID string,
keyName string,
keyType server.SetupKeyType,
@ -178,17 +181,17 @@ func (am *MockAccountManager) CreateSetupKey(
ephemeral bool,
) (*server.SetupKey, error) {
if am.CreateSetupKeyFunc != nil {
return am.CreateSetupKeyFunc(accountID, keyName, keyType, expiresIn, autoGroups, usageLimit, userID, ephemeral)
return am.CreateSetupKeyFunc(ctx, accountID, keyName, keyType, expiresIn, autoGroups, usageLimit, userID, ephemeral)
}
return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented")
}
// GetAccountByUserOrAccountID mock implementation of GetAccountByUserOrAccountID from server.AccountManager interface
func (am *MockAccountManager) GetAccountByUserOrAccountID(
userId, accountId, domain string,
ctx context.Context, userId, accountId, domain string,
) (*server.Account, error) {
if am.GetAccountByUserOrAccountIdFunc != nil {
return am.GetAccountByUserOrAccountIdFunc(userId, accountId, domain)
return am.GetAccountByUserOrAccountIdFunc(ctx, userId, accountId, domain)
}
return nil, status.Errorf(
codes.Unimplemented,
@ -197,391 +200,392 @@ func (am *MockAccountManager) GetAccountByUserOrAccountID(
}
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
func (am *MockAccountManager) MarkPeerConnected(peerKey string, connected bool, realIP net.IP, account *server.Account) error {
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, account *server.Account) error {
if am.MarkPeerConnectedFunc != nil {
return am.MarkPeerConnectedFunc(peerKey, connected, realIP)
return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP)
}
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
}
// GetAccountFromPAT mock implementation of GetAccountFromPAT from server.AccountManager interface
func (am *MockAccountManager) GetAccountFromPAT(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
func (am *MockAccountManager) GetAccountFromPAT(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
if am.GetAccountFromPATFunc != nil {
return am.GetAccountFromPATFunc(pat)
return am.GetAccountFromPATFunc(ctx, pat)
}
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromPAT is not implemented")
}
// DeleteAccount mock implementation of DeleteAccount from server.AccountManager interface
func (am *MockAccountManager) DeleteAccount(accountID, userID string) error {
func (am *MockAccountManager) DeleteAccount(ctx context.Context, accountID, userID string) error {
if am.DeleteAccountFunc != nil {
return am.DeleteAccountFunc(accountID, userID)
return am.DeleteAccountFunc(ctx, accountID, userID)
}
return status.Errorf(codes.Unimplemented, "method DeleteAccount is not implemented")
}
// MarkPATUsed mock implementation of MarkPATUsed from server.AccountManager interface
func (am *MockAccountManager) MarkPATUsed(pat string) error {
func (am *MockAccountManager) MarkPATUsed(ctx context.Context, pat string) error {
if am.MarkPATUsedFunc != nil {
return am.MarkPATUsedFunc(pat)
return am.MarkPATUsedFunc(ctx, pat)
}
return status.Errorf(codes.Unimplemented, "method MarkPATUsed is not implemented")
}
// CreatePAT mock implementation of GetPAT from server.AccountManager interface
func (am *MockAccountManager) CreatePAT(accountID string, initiatorUserID string, targetUserID string, name string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
func (am *MockAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, name string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
if am.CreatePATFunc != nil {
return am.CreatePATFunc(accountID, initiatorUserID, targetUserID, name, expiresIn)
return am.CreatePATFunc(ctx, accountID, initiatorUserID, targetUserID, name, expiresIn)
}
return nil, status.Errorf(codes.Unimplemented, "method CreatePAT is not implemented")
}
// DeletePAT mock implementation of DeletePAT from server.AccountManager interface
func (am *MockAccountManager) DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
func (am *MockAccountManager) DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
if am.DeletePATFunc != nil {
return am.DeletePATFunc(accountID, initiatorUserID, targetUserID, tokenID)
return am.DeletePATFunc(ctx, accountID, initiatorUserID, targetUserID, tokenID)
}
return status.Errorf(codes.Unimplemented, "method DeletePAT is not implemented")
}
// GetPAT mock implementation of GetPAT from server.AccountManager interface
func (am *MockAccountManager) GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
func (am *MockAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
if am.GetPATFunc != nil {
return am.GetPATFunc(accountID, initiatorUserID, targetUserID, tokenID)
return am.GetPATFunc(ctx, accountID, initiatorUserID, targetUserID, tokenID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetPAT is not implemented")
}
// GetAllPATs mock implementation of GetAllPATs from server.AccountManager interface
func (am *MockAccountManager) GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
func (am *MockAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
if am.GetAllPATsFunc != nil {
return am.GetAllPATsFunc(accountID, initiatorUserID, targetUserID)
return am.GetAllPATsFunc(ctx, accountID, initiatorUserID, targetUserID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetAllPATs is not implemented")
}
// GetNetworkMap mock implementation of GetNetworkMap from server.AccountManager interface
func (am *MockAccountManager) GetNetworkMap(peerKey string) (*server.NetworkMap, error) {
func (am *MockAccountManager) GetNetworkMap(ctx context.Context, peerKey string) (*server.NetworkMap, error) {
if am.GetNetworkMapFunc != nil {
return am.GetNetworkMapFunc(peerKey)
return am.GetNetworkMapFunc(ctx, peerKey)
}
return nil, status.Errorf(codes.Unimplemented, "method GetNetworkMap is not implemented")
}
// GetPeerNetwork mock implementation of GetPeerNetwork from server.AccountManager interface
func (am *MockAccountManager) GetPeerNetwork(peerKey string) (*server.Network, error) {
func (am *MockAccountManager) GetPeerNetwork(ctx context.Context, peerKey string) (*server.Network, error) {
if am.GetPeerNetworkFunc != nil {
return am.GetPeerNetworkFunc(peerKey)
return am.GetPeerNetworkFunc(ctx, peerKey)
}
return nil, status.Errorf(codes.Unimplemented, "method GetPeerNetwork is not implemented")
}
// AddPeer mock implementation of AddPeer from server.AccountManager interface
func (am *MockAccountManager) AddPeer(
ctx context.Context,
setupKey string,
userId string,
peer *nbpeer.Peer,
) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
if am.AddPeerFunc != nil {
return am.AddPeerFunc(setupKey, userId, peer)
return am.AddPeerFunc(ctx, setupKey, userId, peer)
}
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method AddPeer is not implemented")
}
// GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface
func (am *MockAccountManager) GetGroupByName(accountID, groupName string) (*group.Group, error) {
func (am *MockAccountManager) GetGroupByName(ctx context.Context, accountID, groupName string) (*group.Group, error) {
if am.GetGroupFunc != nil {
return am.GetGroupByNameFunc(accountID, groupName)
return am.GetGroupByNameFunc(ctx, accountID, groupName)
}
return nil, status.Errorf(codes.Unimplemented, "method GetGroupByName is not implemented")
}
// SaveGroup mock implementation of SaveGroup from server.AccountManager interface
func (am *MockAccountManager) SaveGroup(accountID, userID string, group *group.Group) error {
func (am *MockAccountManager) SaveGroup(ctx context.Context, accountID, userID string, group *group.Group) error {
if am.SaveGroupFunc != nil {
return am.SaveGroupFunc(accountID, userID, group)
return am.SaveGroupFunc(ctx, accountID, userID, group)
}
return status.Errorf(codes.Unimplemented, "method SaveGroup is not implemented")
}
// DeleteGroup mock implementation of DeleteGroup from server.AccountManager interface
func (am *MockAccountManager) DeleteGroup(accountId, userId, groupID string) error {
func (am *MockAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error {
if am.DeleteGroupFunc != nil {
return am.DeleteGroupFunc(accountId, userId, groupID)
return am.DeleteGroupFunc(ctx, accountId, userId, groupID)
}
return status.Errorf(codes.Unimplemented, "method DeleteGroup is not implemented")
}
// ListGroups mock implementation of ListGroups from server.AccountManager interface
func (am *MockAccountManager) ListGroups(accountID string) ([]*group.Group, error) {
func (am *MockAccountManager) ListGroups(ctx context.Context, accountID string) ([]*group.Group, error) {
if am.ListGroupsFunc != nil {
return am.ListGroupsFunc(accountID)
return am.ListGroupsFunc(ctx, accountID)
}
return nil, status.Errorf(codes.Unimplemented, "method ListGroups is not implemented")
}
// GroupAddPeer mock implementation of GroupAddPeer from server.AccountManager interface
func (am *MockAccountManager) GroupAddPeer(accountID, groupID, peerID string) error {
func (am *MockAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
if am.GroupAddPeerFunc != nil {
return am.GroupAddPeerFunc(accountID, groupID, peerID)
return am.GroupAddPeerFunc(ctx, accountID, groupID, peerID)
}
return status.Errorf(codes.Unimplemented, "method GroupAddPeer is not implemented")
}
// GroupDeletePeer mock implementation of GroupDeletePeer from server.AccountManager interface
func (am *MockAccountManager) GroupDeletePeer(accountID, groupID, peerID string) error {
func (am *MockAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
if am.GroupDeletePeerFunc != nil {
return am.GroupDeletePeerFunc(accountID, groupID, peerID)
return am.GroupDeletePeerFunc(ctx, accountID, groupID, peerID)
}
return status.Errorf(codes.Unimplemented, "method GroupDeletePeer is not implemented")
}
// DeleteRule mock implementation of DeleteRule from server.AccountManager interface
func (am *MockAccountManager) DeleteRule(accountID, ruleID, userID string) error {
func (am *MockAccountManager) DeleteRule(ctx context.Context, accountID, ruleID, userID string) error {
if am.DeleteRuleFunc != nil {
return am.DeleteRuleFunc(accountID, ruleID, userID)
return am.DeleteRuleFunc(ctx, accountID, ruleID, userID)
}
return status.Errorf(codes.Unimplemented, "method DeleteRule is not implemented")
}
// GetPolicy mock implementation of GetPolicy from server.AccountManager interface
func (am *MockAccountManager) GetPolicy(accountID, policyID, userID string) (*server.Policy, error) {
func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error) {
if am.GetPolicyFunc != nil {
return am.GetPolicyFunc(accountID, policyID, userID)
return am.GetPolicyFunc(ctx, accountID, policyID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetPolicy is not implemented")
}
// SavePolicy mock implementation of SavePolicy from server.AccountManager interface
func (am *MockAccountManager) SavePolicy(accountID, userID string, policy *server.Policy) error {
func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy) error {
if am.SavePolicyFunc != nil {
return am.SavePolicyFunc(accountID, userID, policy)
return am.SavePolicyFunc(ctx, accountID, userID, policy)
}
return status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented")
}
// DeletePolicy mock implementation of DeletePolicy from server.AccountManager interface
func (am *MockAccountManager) DeletePolicy(accountID, policyID, userID string) error {
func (am *MockAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error {
if am.DeletePolicyFunc != nil {
return am.DeletePolicyFunc(accountID, policyID, userID)
return am.DeletePolicyFunc(ctx, accountID, policyID, userID)
}
return status.Errorf(codes.Unimplemented, "method DeletePolicy is not implemented")
}
// ListPolicies mock implementation of ListPolicies from server.AccountManager interface
func (am *MockAccountManager) ListPolicies(accountID, userID string) ([]*server.Policy, error) {
func (am *MockAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*server.Policy, error) {
if am.ListPoliciesFunc != nil {
return am.ListPoliciesFunc(accountID, userID)
return am.ListPoliciesFunc(ctx, accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method ListPolicies is not implemented")
}
// UpdatePeerMeta mock implementation of UpdatePeerMeta from server.AccountManager interface
func (am *MockAccountManager) UpdatePeerMeta(peerID string, meta nbpeer.PeerSystemMeta) error {
func (am *MockAccountManager) UpdatePeerMeta(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error {
if am.UpdatePeerMetaFunc != nil {
return am.UpdatePeerMetaFunc(peerID, meta)
return am.UpdatePeerMetaFunc(ctx, peerID, meta)
}
return status.Errorf(codes.Unimplemented, "method UpdatePeerMeta is not implemented")
}
// GetUser mock implementation of GetUser from server.AccountManager interface
func (am *MockAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) (*server.User, error) {
func (am *MockAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) {
if am.GetUserFunc != nil {
return am.GetUserFunc(claims)
return am.GetUserFunc(ctx, claims)
}
return nil, status.Errorf(codes.Unimplemented, "method GetUser is not implemented")
}
func (am *MockAccountManager) ListUsers(accountID string) ([]*server.User, error) {
func (am *MockAccountManager) ListUsers(ctx context.Context, accountID string) ([]*server.User, error) {
if am.ListUsersFunc != nil {
return am.ListUsersFunc(accountID)
return am.ListUsersFunc(ctx, accountID)
}
return nil, status.Errorf(codes.Unimplemented, "method ListUsers is not implemented")
}
// UpdatePeerSSHKey mocks UpdatePeerSSHKey function of the account manager
func (am *MockAccountManager) UpdatePeerSSHKey(peerID string, sshKey string) error {
func (am *MockAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error {
if am.UpdatePeerSSHKeyFunc != nil {
return am.UpdatePeerSSHKeyFunc(peerID, sshKey)
return am.UpdatePeerSSHKeyFunc(ctx, peerID, sshKey)
}
return status.Errorf(codes.Unimplemented, "method UpdatePeerSSHKey is not implemented")
}
// UpdatePeer mocks UpdatePeerFunc function of the account manager
func (am *MockAccountManager) UpdatePeer(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) {
func (am *MockAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) {
if am.UpdatePeerFunc != nil {
return am.UpdatePeerFunc(accountID, userID, peer)
return am.UpdatePeerFunc(ctx, accountID, userID, peer)
}
return nil, status.Errorf(codes.Unimplemented, "method UpdatePeer is not implemented")
}
// CreateRoute mock implementation of CreateRoute from server.AccountManager interface
func (am *MockAccountManager) CreateRoute(accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
if am.CreateRouteFunc != nil {
return am.CreateRouteFunc(accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, enabled, userID, keepRoute)
return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, enabled, userID, keepRoute)
}
return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented")
}
// GetRoute mock implementation of GetRoute from server.AccountManager interface
func (am *MockAccountManager) GetRoute(accountID string, routeID route.ID, userID string) (*route.Route, error) {
func (am *MockAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) {
if am.GetRouteFunc != nil {
return am.GetRouteFunc(accountID, routeID, userID)
return am.GetRouteFunc(ctx, accountID, routeID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetRoute is not implemented")
}
// SaveRoute mock implementation of SaveRoute from server.AccountManager interface
func (am *MockAccountManager) SaveRoute(accountID string, userID string, route *route.Route) error {
func (am *MockAccountManager) SaveRoute(ctx context.Context, accountID string, userID string, route *route.Route) error {
if am.SaveRouteFunc != nil {
return am.SaveRouteFunc(accountID, userID, route)
return am.SaveRouteFunc(ctx, accountID, userID, route)
}
return status.Errorf(codes.Unimplemented, "method SaveRoute is not implemented")
}
// DeleteRoute mock implementation of DeleteRoute from server.AccountManager interface
func (am *MockAccountManager) DeleteRoute(accountID string, routeID route.ID, userID string) error {
func (am *MockAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error {
if am.DeleteRouteFunc != nil {
return am.DeleteRouteFunc(accountID, routeID, userID)
return am.DeleteRouteFunc(ctx, accountID, routeID, userID)
}
return status.Errorf(codes.Unimplemented, "method DeleteRoute is not implemented")
}
// ListRoutes mock implementation of ListRoutes from server.AccountManager interface
func (am *MockAccountManager) ListRoutes(accountID, userID string) ([]*route.Route, error) {
func (am *MockAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) {
if am.ListRoutesFunc != nil {
return am.ListRoutesFunc(accountID, userID)
return am.ListRoutesFunc(ctx, accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method ListRoutes is not implemented")
}
// SaveSetupKey mocks SaveSetupKey of the AccountManager interface
func (am *MockAccountManager) SaveSetupKey(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) {
func (am *MockAccountManager) SaveSetupKey(ctx context.Context, accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) {
if am.SaveSetupKeyFunc != nil {
return am.SaveSetupKeyFunc(accountID, key, userID)
return am.SaveSetupKeyFunc(ctx, accountID, key, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method SaveSetupKey is not implemented")
}
// GetSetupKey mocks GetSetupKey of the AccountManager interface
func (am *MockAccountManager) GetSetupKey(accountID, userID, keyID string) (*server.SetupKey, error) {
func (am *MockAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error) {
if am.GetSetupKeyFunc != nil {
return am.GetSetupKeyFunc(accountID, userID, keyID)
return am.GetSetupKeyFunc(ctx, accountID, userID, keyID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetSetupKey is not implemented")
}
// ListSetupKeys mocks ListSetupKeys of the AccountManager interface
func (am *MockAccountManager) ListSetupKeys(accountID, userID string) ([]*server.SetupKey, error) {
func (am *MockAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*server.SetupKey, error) {
if am.ListSetupKeysFunc != nil {
return am.ListSetupKeysFunc(accountID, userID)
return am.ListSetupKeysFunc(ctx, accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method ListSetupKeys is not implemented")
}
// SaveUser mocks SaveUser of the AccountManager interface
func (am *MockAccountManager) SaveUser(accountID, userID string, user *server.User) (*server.UserInfo, error) {
func (am *MockAccountManager) SaveUser(ctx context.Context, accountID, userID string, user *server.User) (*server.UserInfo, error) {
if am.SaveUserFunc != nil {
return am.SaveUserFunc(accountID, userID, user)
return am.SaveUserFunc(ctx, accountID, userID, user)
}
return nil, status.Errorf(codes.Unimplemented, "method SaveUser is not implemented")
}
// SaveOrAddUser mocks SaveOrAddUser of the AccountManager interface
func (am *MockAccountManager) SaveOrAddUser(accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error) {
func (am *MockAccountManager) SaveOrAddUser(ctx context.Context, accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error) {
if am.SaveOrAddUserFunc != nil {
return am.SaveOrAddUserFunc(accountID, userID, user, addIfNotExists)
return am.SaveOrAddUserFunc(ctx, accountID, userID, user, addIfNotExists)
}
return nil, status.Errorf(codes.Unimplemented, "method SaveOrAddUser is not implemented")
}
// DeleteUser mocks DeleteUser of the AccountManager interface
func (am *MockAccountManager) DeleteUser(accountID string, initiatorUserID string, targetUserID string) error {
func (am *MockAccountManager) DeleteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error {
if am.DeleteUserFunc != nil {
return am.DeleteUserFunc(accountID, initiatorUserID, targetUserID)
return am.DeleteUserFunc(ctx, accountID, initiatorUserID, targetUserID)
}
return status.Errorf(codes.Unimplemented, "method DeleteUser is not implemented")
}
func (am *MockAccountManager) InviteUser(accountID string, initiatorUserID string, targetUserID string) error {
func (am *MockAccountManager) InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error {
if am.InviteUserFunc != nil {
return am.InviteUserFunc(accountID, initiatorUserID, targetUserID)
return am.InviteUserFunc(ctx, accountID, initiatorUserID, targetUserID)
}
return status.Errorf(codes.Unimplemented, "method InviteUser is not implemented")
}
// GetNameServerGroup mocks GetNameServerGroup of the AccountManager interface
func (am *MockAccountManager) GetNameServerGroup(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
func (am *MockAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
if am.GetNameServerGroupFunc != nil {
return am.GetNameServerGroupFunc(accountID, userID, nsGroupID)
return am.GetNameServerGroupFunc(ctx, accountID, userID, nsGroupID)
}
return nil, nil
}
// CreateNameServerGroup mocks CreateNameServerGroup of the AccountManager interface
func (am *MockAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) {
func (am *MockAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) {
if am.CreateNameServerGroupFunc != nil {
return am.CreateNameServerGroupFunc(accountID, name, description, nameServerList, groups, primary, domains, enabled, userID, searchDomainsEnabled)
return am.CreateNameServerGroupFunc(ctx, accountID, name, description, nameServerList, groups, primary, domains, enabled, userID, searchDomainsEnabled)
}
return nil, nil
}
// SaveNameServerGroup mocks SaveNameServerGroup of the AccountManager interface
func (am *MockAccountManager) SaveNameServerGroup(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error {
func (am *MockAccountManager) SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error {
if am.SaveNameServerGroupFunc != nil {
return am.SaveNameServerGroupFunc(accountID, userID, nsGroupToSave)
return am.SaveNameServerGroupFunc(ctx, accountID, userID, nsGroupToSave)
}
return nil
}
// DeleteNameServerGroup mocks DeleteNameServerGroup of the AccountManager interface
func (am *MockAccountManager) DeleteNameServerGroup(accountID, nsGroupID, userID string) error {
func (am *MockAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error {
if am.DeleteNameServerGroupFunc != nil {
return am.DeleteNameServerGroupFunc(accountID, nsGroupID, userID)
return am.DeleteNameServerGroupFunc(ctx, accountID, nsGroupID, userID)
}
return nil
}
// ListNameServerGroups mocks ListNameServerGroups of the AccountManager interface
func (am *MockAccountManager) ListNameServerGroups(accountID string, userID string) ([]*nbdns.NameServerGroup, error) {
func (am *MockAccountManager) ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) {
if am.ListNameServerGroupsFunc != nil {
return am.ListNameServerGroupsFunc(accountID, userID)
return am.ListNameServerGroupsFunc(ctx, accountID, userID)
}
return nil, nil
}
// CreateUser mocks CreateUser of the AccountManager interface
func (am *MockAccountManager) CreateUser(accountID, userID string, invite *server.UserInfo) (*server.UserInfo, error) {
func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID string, invite *server.UserInfo) (*server.UserInfo, error) {
if am.CreateUserFunc != nil {
return am.CreateUserFunc(accountID, userID, invite)
return am.CreateUserFunc(ctx, accountID, userID, invite)
}
return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented")
}
// GetAccountFromToken mocks GetAccountFromToken of the AccountManager interface
func (am *MockAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User,
func (am *MockAccountManager) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User,
error,
) {
if am.GetAccountFromTokenFunc != nil {
return am.GetAccountFromTokenFunc(claims)
return am.GetAccountFromTokenFunc(ctx, claims)
}
return nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented")
}
func (am *MockAccountManager) CheckUserAccessByJWTGroups(claims jwtclaims.AuthorizationClaims) error {
func (am *MockAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error {
if am.CheckUserAccessByJWTGroupsFunc != nil {
return am.CheckUserAccessByJWTGroupsFunc(claims)
return am.CheckUserAccessByJWTGroupsFunc(ctx, claims)
}
return status.Errorf(codes.Unimplemented, "method CheckUserAccessByJWTGroups is not implemented")
}
// GetPeers mocks GetPeers of the AccountManager interface
func (am *MockAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.Peer, error) {
func (am *MockAccountManager) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
if am.GetPeersFunc != nil {
return am.GetPeersFunc(accountID, userID)
return am.GetPeersFunc(ctx, accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetPeers is not implemented")
}
@ -595,57 +599,57 @@ func (am *MockAccountManager) GetDNSDomain() string {
}
// GetEvents mocks GetEvents of the AccountManager interface
func (am *MockAccountManager) GetEvents(accountID, userID string) ([]*activity.Event, error) {
func (am *MockAccountManager) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) {
if am.GetEventsFunc != nil {
return am.GetEventsFunc(accountID, userID)
return am.GetEventsFunc(ctx, accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetEvents is not implemented")
}
// GetDNSSettings mocks GetDNSSettings of the AccountManager interface
func (am *MockAccountManager) GetDNSSettings(accountID string, userID string) (*server.DNSSettings, error) {
func (am *MockAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*server.DNSSettings, error) {
if am.GetDNSSettingsFunc != nil {
return am.GetDNSSettingsFunc(accountID, userID)
return am.GetDNSSettingsFunc(ctx, accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetDNSSettings is not implemented")
}
// SaveDNSSettings mocks SaveDNSSettings of the AccountManager interface
func (am *MockAccountManager) SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *server.DNSSettings) error {
func (am *MockAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *server.DNSSettings) error {
if am.SaveDNSSettingsFunc != nil {
return am.SaveDNSSettingsFunc(accountID, userID, dnsSettingsToSave)
return am.SaveDNSSettingsFunc(ctx, accountID, userID, dnsSettingsToSave)
}
return status.Errorf(codes.Unimplemented, "method SaveDNSSettings is not implemented")
}
// GetPeer mocks GetPeer of the AccountManager interface
func (am *MockAccountManager) GetPeer(accountID, peerID, userID string) (*nbpeer.Peer, error) {
func (am *MockAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
if am.GetPeerFunc != nil {
return am.GetPeerFunc(accountID, peerID, userID)
return am.GetPeerFunc(ctx, accountID, peerID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetPeer is not implemented")
}
// UpdateAccountSettings mocks UpdateAccountSettings of the AccountManager interface
func (am *MockAccountManager) UpdateAccountSettings(accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
if am.UpdateAccountSettingsFunc != nil {
return am.UpdateAccountSettingsFunc(accountID, userID, newSettings)
return am.UpdateAccountSettingsFunc(ctx, accountID, userID, newSettings)
}
return nil, status.Errorf(codes.Unimplemented, "method UpdateAccountSettings is not implemented")
}
// LoginPeer mocks LoginPeer of the AccountManager interface
func (am *MockAccountManager) LoginPeer(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
func (am *MockAccountManager) LoginPeer(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
if am.LoginPeerFunc != nil {
return am.LoginPeerFunc(login)
return am.LoginPeerFunc(ctx, login)
}
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method LoginPeer is not implemented")
}
// SyncPeer mocks SyncPeer of the AccountManager interface
func (am *MockAccountManager) SyncPeer(sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
func (am *MockAccountManager) SyncPeer(ctx context.Context, sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
if am.SyncPeerFunc != nil {
return am.SyncPeerFunc(sync, account)
return am.SyncPeerFunc(ctx, sync, account)
}
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented")
}
@ -667,9 +671,9 @@ func (am *MockAccountManager) HasConnectedChannel(peerID string) bool {
}
// StoreEvent mocks StoreEvent of the AccountManager interface
func (am *MockAccountManager) StoreEvent(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
func (am *MockAccountManager) StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
if am.StoreEventFunc != nil {
am.StoreEventFunc(initiatorID, targetID, accountID, activityID, meta)
am.StoreEventFunc(ctx, initiatorID, targetID, accountID, activityID, meta)
}
}
@ -682,35 +686,35 @@ func (am *MockAccountManager) GetExternalCacheManager() server.ExternalCacheMana
}
// GetPostureChecks mocks GetPostureChecks of the AccountManager interface
func (am *MockAccountManager) GetPostureChecks(accountID, postureChecksID, userID string) (*posture.Checks, error) {
func (am *MockAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
if am.GetPostureChecksFunc != nil {
return am.GetPostureChecksFunc(accountID, postureChecksID, userID)
return am.GetPostureChecksFunc(ctx, accountID, postureChecksID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetPostureChecks is not implemented")
}
// SavePostureChecks mocks SavePostureChecks of the AccountManager interface
func (am *MockAccountManager) SavePostureChecks(accountID, userID string, postureChecks *posture.Checks) error {
func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
if am.SavePostureChecksFunc != nil {
return am.SavePostureChecksFunc(accountID, userID, postureChecks)
return am.SavePostureChecksFunc(ctx, accountID, userID, postureChecks)
}
return status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented")
}
// DeletePostureChecks mocks DeletePostureChecks of the AccountManager interface
func (am *MockAccountManager) DeletePostureChecks(accountID, postureChecksID, userID string) error {
func (am *MockAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error {
if am.DeletePostureChecksFunc != nil {
return am.DeletePostureChecksFunc(accountID, postureChecksID, userID)
return am.DeletePostureChecksFunc(ctx, accountID, postureChecksID, userID)
}
return status.Errorf(codes.Unimplemented, "method DeletePostureChecks is not implemented")
}
// ListPostureChecks mocks ListPostureChecks of the AccountManager interface
func (am *MockAccountManager) ListPostureChecks(accountID, userID string) ([]*posture.Checks, error) {
func (am *MockAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) {
if am.ListPostureChecksFunc != nil {
return am.ListPostureChecksFunc(accountID, userID)
return am.ListPostureChecksFunc(ctx, accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method ListPostureChecks is not implemented")
}
@ -724,25 +728,25 @@ func (am *MockAccountManager) GetIdpManager() idp.Manager {
}
// UpdateIntegratedValidatorGroups mocks UpdateIntegratedApprovalGroups of the AccountManager interface
func (am *MockAccountManager) UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error {
func (am *MockAccountManager) UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error {
if am.UpdateIntegratedValidatorGroupsFunc != nil {
return am.UpdateIntegratedValidatorGroupsFunc(accountID, userID, groups)
return am.UpdateIntegratedValidatorGroupsFunc(ctx, accountID, userID, groups)
}
return status.Errorf(codes.Unimplemented, "method UpdateIntegratedValidatorGroups is not implemented")
}
// GroupValidation mocks GroupValidation of the AccountManager interface
func (am *MockAccountManager) GroupValidation(accountId string, groups []string) (bool, error) {
func (am *MockAccountManager) GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) {
if am.GroupValidationFunc != nil {
return am.GroupValidationFunc(accountId, groups)
return am.GroupValidationFunc(ctx, accountId, groups)
}
return false, status.Errorf(codes.Unimplemented, "method GroupValidation is not implemented")
}
// SyncPeerMeta mocks SyncPeerMeta of the AccountManager interface
func (am *MockAccountManager) SyncPeerMeta(peerPubKey string, meta nbpeer.PeerSystemMeta) error {
func (am *MockAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error {
if am.SyncPeerMetaFunc != nil {
return am.SyncPeerMetaFunc(peerPubKey, meta)
return am.SyncPeerMetaFunc(ctx, peerPubKey, meta)
}
return status.Errorf(codes.Unimplemented, "method SyncPeerMeta is not implemented")
}
@ -754,3 +758,11 @@ func (am *MockAccountManager) FindExistingPostureCheck(accountID string, checks
}
return nil, status.Errorf(codes.Unimplemented, "method FindExistingPostureCheck is not implemented")
}
// GetAccountIDForPeerKey mocks GetAccountIDForPeerKey of the AccountManager interface
func (am *MockAccountManager) GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error) {
if am.GetAccountIDForPeerKeyFunc != nil {
return am.GetAccountIDForPeerKeyFunc(ctx, peerKey)
}
return "", status.Errorf(codes.Unimplemented, "method GetAccountIDForPeerKey is not implemented")
}

View File

@ -1,6 +1,7 @@
package server
import (
"context"
"errors"
"regexp"
"unicode/utf8"
@ -17,12 +18,12 @@ import (
const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$`
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
func (am *DefaultAccountManager) GetNameServerGroup(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
@ -45,12 +46,12 @@ func (am *DefaultAccountManager) GetNameServerGroup(accountID, userID, nsGroupID
}
// CreateNameServerGroup creates and saves a new nameserver group
func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) {
func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
@ -79,29 +80,29 @@ func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, d
account.NameServerGroups[newNSGroup.ID] = newNSGroup
account.Network.IncSerial()
err = am.Store.SaveAccount(account)
err = am.Store.SaveAccount(ctx, account)
if err != nil {
return nil, err
}
am.updateAccountPeers(account)
am.updateAccountPeers(ctx, account)
am.StoreEvent(userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
return newNSGroup.Copy(), nil
}
// SaveNameServerGroup saves nameserver group
func (am *DefaultAccountManager) SaveNameServerGroup(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error {
func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
if nsGroupToSave == nil {
return status.Errorf(status.InvalidArgument, "nameserver group provided is nil")
}
account, err := am.Store.GetAccount(accountID)
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
@ -114,25 +115,25 @@ func (am *DefaultAccountManager) SaveNameServerGroup(accountID, userID string, n
account.NameServerGroups[nsGroupToSave.ID] = nsGroupToSave
account.Network.IncSerial()
err = am.Store.SaveAccount(account)
err = am.Store.SaveAccount(ctx, account)
if err != nil {
return err
}
am.updateAccountPeers(account)
am.updateAccountPeers(ctx, account)
am.StoreEvent(userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
return nil
}
// DeleteNameServerGroup deletes nameserver group with nsGroupID
func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID, userID string) error {
func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
@ -144,25 +145,25 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID, use
delete(account.NameServerGroups, nsGroupID)
account.Network.IncSerial()
err = am.Store.SaveAccount(account)
err = am.Store.SaveAccount(ctx, account)
if err != nil {
return err
}
am.updateAccountPeers(account)
am.updateAccountPeers(ctx, account)
am.StoreEvent(userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
return nil
}
// ListNameServerGroups returns a list of nameserver groups from account
func (am *DefaultAccountManager) ListNameServerGroups(accountID string, userID string) ([]*nbdns.NameServerGroup, error) {
func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}

View File

@ -1,6 +1,7 @@
package server
import (
"context"
"net/netip"
"testing"
@ -383,6 +384,7 @@ func TestCreateNameServerGroup(t *testing.T) {
}
outNSGroup, err := am.CreateNameServerGroup(
context.Background(),
account.Id,
testCase.inputArgs.name,
testCase.inputArgs.description,
@ -611,7 +613,7 @@ func TestSaveNameServerGroup(t *testing.T) {
account.NameServerGroups[testCase.existingNSGroup.ID] = testCase.existingNSGroup
err = am.Store.SaveAccount(account)
err = am.Store.SaveAccount(context.Background(), account)
if err != nil {
t.Error("account should be saved")
}
@ -646,7 +648,7 @@ func TestSaveNameServerGroup(t *testing.T) {
}
}
err = am.SaveNameServerGroup(account.Id, userID, nsGroupToSave)
err = am.SaveNameServerGroup(context.Background(), account.Id, userID, nsGroupToSave)
testCase.errFunc(t, err)
@ -654,7 +656,7 @@ func TestSaveNameServerGroup(t *testing.T) {
return
}
account, err = am.Store.GetAccount(account.Id)
account, err = am.Store.GetAccount(context.Background(), account.Id)
if err != nil {
t.Fatal(err)
}
@ -705,17 +707,17 @@ func TestDeleteNameServerGroup(t *testing.T) {
account.NameServerGroups[testingNSGroup.ID] = testingNSGroup
err = am.Store.SaveAccount(account)
err = am.Store.SaveAccount(context.Background(), account)
if err != nil {
t.Error("failed to save account")
}
err = am.DeleteNameServerGroup(account.Id, testingNSGroup.ID, userID)
err = am.DeleteNameServerGroup(context.Background(), account.Id, testingNSGroup.ID, userID)
if err != nil {
t.Error("deleting nameserver group failed with error: ", err)
}
savedAccount, err := am.Store.GetAccount(account.Id)
savedAccount, err := am.Store.GetAccount(context.Background(), account.Id)
if err != nil {
t.Error("failed to retrieve saved account with error: ", err)
}
@ -738,7 +740,7 @@ func TestGetNameServerGroup(t *testing.T) {
t.Error("failed to init testing account")
}
foundGroup, err := am.GetNameServerGroup(account.Id, testUserID, existingNSGroupID)
foundGroup, err := am.GetNameServerGroup(context.Background(), account.Id, testUserID, existingNSGroupID)
if err != nil {
t.Error("getting existing nameserver group failed with error: ", err)
}
@ -747,7 +749,7 @@ func TestGetNameServerGroup(t *testing.T) {
t.Error("got a nil group while getting nameserver group with ID")
}
_, err = am.GetNameServerGroup(account.Id, testUserID, "not existing")
_, err = am.GetNameServerGroup(context.Background(), account.Id, testUserID, "not existing")
if err == nil {
t.Error("getting not existing nameserver group should return error, got nil")
}
@ -760,13 +762,13 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
return nil, err
}
eventStore := &activity.InMemoryEventStore{}
return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{})
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{})
}
func createNSStore(t *testing.T) (Store, error) {
t.Helper()
dataDir := t.TempDir()
store, cleanUp, err := NewTestStoreFromJson(dataDir)
store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir)
if err != nil {
return nil, err
}
@ -829,7 +831,7 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error
userID := testUserID
domain := "example.com"
account := newAccountWithId(accountID, userID, domain)
account := newAccountWithId(context.Background(), accountID, userID, domain)
account.NameServerGroups[existingNSGroup.ID] = &existingNSGroup
@ -846,16 +848,16 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error
account.Groups[newGroup1.ID] = newGroup1
account.Groups[newGroup2.ID] = newGroup2
err := am.Store.SaveAccount(account)
err := am.Store.SaveAccount(context.Background(), account)
if err != nil {
return nil, err
}
_, _, _, err = am.AddPeer("", userID, peer1)
_, _, _, err = am.AddPeer(context.Background(), "", userID, peer1)
if err != nil {
return nil, err
}
_, _, _, err = am.AddPeer("", userID, peer2)
_, _, _, err = am.AddPeer(context.Background(), "", userID, peer2)
if err != nil {
return nil, err
}

View File

@ -1,6 +1,7 @@
package server
import (
"context"
"fmt"
"net"
"strings"
@ -45,8 +46,8 @@ type PeerLogin struct {
// GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if
// the current user is not an admin.
func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.Peer, error) {
account, err := am.Store.GetAccount(accountID)
func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
@ -79,7 +80,7 @@ func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.P
// fetch all the peers that have access to the user's peers
for _, peer := range peers {
aclPeers, _ := account.getPeerConnectionResources(peer.ID, approvedPeersMap)
aclPeers, _ := account.getPeerConnectionResources(ctx, peer.ID, approvedPeersMap)
for _, p := range aclPeers {
peersMap[p.ID] = p
}
@ -94,7 +95,7 @@ func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.P
}
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected bool, realIP net.IP, account *Account) error {
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, account *Account) error {
peer, err := account.FindPeerByPubKey(peerPubKey)
if err != nil {
return err
@ -113,7 +114,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected
if am.geo != nil && realIP != nil {
location, err := am.geo.Lookup(realIP)
if err != nil {
log.Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err)
log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err)
} else {
peer.Location.ConnectionIP = realIP
peer.Location.CountryCode = location.Country.ISOCode
@ -121,7 +122,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected
peer.Location.GeoNameID = location.City.GeonameID
err = am.Store.SavePeerLocation(account.Id, peer)
if err != nil {
log.Warnf("could not store location for peer %s: %s", peer.ID, err)
log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err)
}
}
}
@ -134,24 +135,24 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected
}
if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled {
am.checkAndSchedulePeerLoginExpiration(account)
am.checkAndSchedulePeerLoginExpiration(ctx, account)
}
if oldStatus.LoginExpired {
// we need to update other peers because when peer login expires all other peers are notified to disconnect from
// the expired one. Here we notify them that connection is now allowed again.
am.updateAccountPeers(account)
am.updateAccountPeers(ctx, account)
}
return nil
}
// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, and Peer.LoginExpirationEnabled can be updated.
func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
@ -161,7 +162,7 @@ func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *nb
return nil, status.Errorf(status.NotFound, "peer %s not found", update.ID)
}
update, err = am.integratedPeerValidator.ValidatePeer(update, peer, userID, accountID, am.GetDNSDomain(), account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
update, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
if err != nil {
return nil, err
}
@ -172,7 +173,7 @@ func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *nb
if !update.SSHEnabled {
event = activity.PeerSSHDisabled
}
am.StoreEvent(userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
}
if peer.Name != update.Name {
@ -187,7 +188,7 @@ func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *nb
peer.DNSLabel = newLabel
am.StoreEvent(userID, peer.ID, accountID, activity.PeerRenamed, peer.EventMeta(am.GetDNSDomain()))
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRenamed, peer.EventMeta(am.GetDNSDomain()))
}
if peer.LoginExpirationEnabled != update.LoginExpirationEnabled {
@ -202,27 +203,27 @@ func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *nb
if !update.LoginExpirationEnabled {
event = activity.PeerLoginExpirationDisabled
}
am.StoreEvent(userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled {
am.checkAndSchedulePeerLoginExpiration(account)
am.checkAndSchedulePeerLoginExpiration(ctx, account)
}
}
account.UpdatePeer(peer)
err = am.Store.SaveAccount(account)
err = am.Store.SaveAccount(ctx, account)
if err != nil {
return nil, err
}
am.updateAccountPeers(account)
am.updateAccountPeers(ctx, account)
return peer, nil
}
// deletePeers will delete all specified peers and send updates to the remote peers. Don't call without acquiring account lock
func (am *DefaultAccountManager) deletePeers(account *Account, peerIDs []string, userID string) error {
func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Account, peerIDs []string, userID string) error {
// the first loop is needed to ensure all peers present under the account before modifying, otherwise
// we might have some inconsistencies
@ -239,13 +240,13 @@ func (am *DefaultAccountManager) deletePeers(account *Account, peerIDs []string,
// the 2nd loop performs the actual modification
for _, peer := range peers {
err := am.integratedPeerValidator.PeerDeleted(account.Id, peer.ID)
err := am.integratedPeerValidator.PeerDeleted(ctx, account.Id, peer.ID)
if err != nil {
return err
}
account.DeletePeer(peer.ID)
am.peersUpdateManager.SendUpdate(peer.ID,
am.peersUpdateManager.SendUpdate(ctx, peer.ID,
&UpdateMessage{
Update: &proto.SyncResponse{
// fill those field for backward compatibility
@ -261,41 +262,41 @@ func (am *DefaultAccountManager) deletePeers(account *Account, peerIDs []string,
},
},
})
am.peersUpdateManager.CloseChannel(peer.ID)
am.StoreEvent(userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain()))
am.peersUpdateManager.CloseChannel(ctx, peer.ID)
am.StoreEvent(ctx, userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain()))
}
return nil
}
// DeletePeer removes peer from the account by its IP
func (am *DefaultAccountManager) DeletePeer(accountID, peerID, userID string) error {
unlock := am.Store.AcquireAccountWriteLock(accountID)
func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peerID, userID string) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
err = am.deletePeers(account, []string{peerID}, userID)
err = am.deletePeers(ctx, account, []string{peerID}, userID)
if err != nil {
return err
}
err = am.Store.SaveAccount(account)
err = am.Store.SaveAccount(ctx, account)
if err != nil {
return err
}
am.updateAccountPeers(account)
am.updateAccountPeers(ctx, account)
return nil
}
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)
func (am *DefaultAccountManager) GetNetworkMap(peerID string) (*NetworkMap, error) {
account, err := am.Store.GetAccountByPeerID(peerID)
func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID string) (*NetworkMap, error) {
account, err := am.Store.GetAccountByPeerID(ctx, peerID)
if err != nil {
return nil, err
}
@ -314,12 +315,12 @@ func (am *DefaultAccountManager) GetNetworkMap(peerID string) (*NetworkMap, erro
if err != nil {
return nil, err
}
return account.GetPeerNetworkMap(peer.ID, am.dnsDomain, validatedPeers), nil
return account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, validatedPeers), nil
}
// GetPeerNetwork returns the Network for a given peer
func (am *DefaultAccountManager) GetPeerNetwork(peerID string) (*Network, error) {
account, err := am.Store.GetAccountByPeerID(peerID)
func (am *DefaultAccountManager) GetPeerNetwork(ctx context.Context, peerID string) (*Network, error) {
account, err := am.Store.GetAccountByPeerID(ctx, peerID)
if err != nil {
return nil, err
}
@ -334,7 +335,7 @@ func (am *DefaultAccountManager) GetPeerNetwork(peerID string) (*Network, error)
// to it. We also add the User ID to the peer metadata to identify registrant. If no userID provided, then fail with status.PermissionDenied
// Each new Peer will be assigned a new next net.IP from the Account.Network and Account.Network.LastIP will be updated (IP's are not reused).
// The peer property is just a placeholder for the Peer properties to pass further
func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
if setupKey == "" && userID == "" {
// no auth method provided => reject access
return nil, nil, nil, status.Errorf(status.Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login")
@ -348,13 +349,13 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
addedByUser = true
accountID, err = am.Store.GetAccountIDByUserID(userID)
} else {
accountID, err = am.Store.GetAccountIDBySetupKey(setupKey)
accountID, err = am.Store.GetAccountIDBySetupKey(ctx, setupKey)
}
if err != nil {
return nil, nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found")
}
unlock := am.Store.AcquireAccountWriteLock(accountID)
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer func() {
if unlock != nil {
unlock()
@ -363,14 +364,14 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
var account *Account
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
account, err = am.Store.GetAccount(accountID)
account, err = am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, nil, nil, err
}
if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" {
if am.idpManager != nil {
userdata, err := am.lookupUserInCache(userID, account)
userdata, err := am.lookupUserInCache(ctx, userID, account)
if err == nil && userdata != nil {
peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0])
}
@ -479,7 +480,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
}
}
newPeer = am.integratedPeerValidator.PreparePeer(account.Id, newPeer, account.GetPeerGroupsList(newPeer.ID), account.Settings.Extra)
newPeer = am.integratedPeerValidator.PreparePeer(ctx, account.Id, newPeer, account.GetPeerGroupsList(newPeer.ID), account.Settings.Extra)
if addedByUser {
user, err := account.FindUser(userID)
@ -491,7 +492,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
account.Peers[newPeer.ID] = newPeer
account.Network.IncSerial()
err = am.Store.SaveAccount(account)
err = am.Store.SaveAccount(ctx, account)
if err != nil {
return nil, nil, nil, err
}
@ -506,9 +507,9 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
opEvent.Meta["setup_key_name"] = setupKeyName
}
am.StoreEvent(opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
am.updateAccountPeers(account)
am.updateAccountPeers(ctx, account)
approvedPeersMap, err := am.GetValidatedPeers(account)
if err != nil {
@ -516,12 +517,12 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
}
postureChecks := am.getPeerPostureChecks(account, peer)
networkMap := account.GetPeerNetworkMap(newPeer.ID, am.dnsDomain, approvedPeersMap)
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, am.dnsDomain, approvedPeersMap)
return newPeer, networkMap, postureChecks, nil
}
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey)
if err != nil {
return nil, nil, nil, status.NewPeerNotRegisteredError()
@ -532,23 +533,23 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbp
return nil, nil, nil, err
}
if peerLoginExpired(peer, account.Settings) {
if peerLoginExpired(ctx, peer, account.Settings) {
return nil, nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
}
peer, updated := updatePeerMeta(peer, sync.Meta, account)
if updated {
err = am.Store.SaveAccount(account)
err = am.Store.SaveAccount(ctx, account)
if err != nil {
return nil, nil, nil, err
}
if sync.UpdateAccountPeers {
am.updateAccountPeers(account)
am.updateAccountPeers(ctx, account)
}
}
peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
if err != nil {
return nil, nil, nil, err
}
@ -563,7 +564,7 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbp
}
if isStatusChanged {
am.updateAccountPeers(account)
am.updateAccountPeers(ctx, account)
}
validPeersMap, err := am.GetValidatedPeers(account)
@ -572,13 +573,13 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbp
}
postureChecks = am.getPeerPostureChecks(account, peer)
return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, validPeersMap), postureChecks, nil
return peer, account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, validPeersMap), postureChecks, nil
}
// LoginPeer logs in or registers a peer.
// If peer doesn't exist the function checks whether a setup key or a user is present and registers a new peer if so.
func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
accountID, err := am.Store.GetAccountIDByPeerPubKey(login.WireGuardPubKey)
func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, login.WireGuardPubKey)
if err != nil {
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
// we couldn't find this peer by its public key which can mean that peer hasn't been registered yet.
@ -591,7 +592,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
if am.geo != nil && login.ConnectionIP != nil {
location, err := am.geo.Lookup(login.ConnectionIP)
if err != nil {
log.Warnf("failed to get location for new peer realip: [%s]: %v", login.ConnectionIP.String(), err)
log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", login.ConnectionIP.String(), err)
} else {
newPeer.Location.ConnectionIP = login.ConnectionIP
newPeer.Location.CountryCode = location.Country.ISOCode
@ -601,19 +602,19 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
}
}
return am.AddPeer(login.SetupKey, login.UserID, newPeer)
return am.AddPeer(ctx, login.SetupKey, login.UserID, newPeer)
}
log.Errorf("failed while logging in peer %s: %v", login.WireGuardPubKey, err)
log.WithContext(ctx).Errorf("failed while logging in peer %s: %v", login.WireGuardPubKey, err)
return nil, nil, nil, status.Errorf(status.Internal, "failed while logging in peer")
}
peer, err := am.Store.GetPeerByPeerPubKey(login.WireGuardPubKey)
peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey)
if err != nil {
return nil, nil, nil, status.NewPeerNotRegisteredError()
}
accSettings, err := am.Store.GetAccountSettings(accountID)
accSettings, err := am.Store.GetAccountSettings(ctx, accountID)
if err != nil {
return nil, nil, nil, status.Errorf(status.Internal, "failed to get account settings: %s", err)
}
@ -621,30 +622,30 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
var isWriteLock bool
// duplicated logic from after the lock to have an early exit
expired := peerLoginExpired(peer, accSettings)
expired := peerLoginExpired(ctx, peer, accSettings)
switch {
case expired:
if err := checkAuth(login.UserID, peer); err != nil {
if err := checkAuth(ctx, login.UserID, peer); err != nil {
return nil, nil, nil, err
}
isWriteLock = true
log.Debugf("peer login expired, acquiring write lock")
log.WithContext(ctx).Debugf("peer login expired, acquiring write lock")
case peer.UpdateMetaIfNew(login.Meta):
isWriteLock = true
log.Debugf("peer changed meta, acquiring write lock")
log.WithContext(ctx).Debugf("peer changed meta, acquiring write lock")
default:
isWriteLock = false
log.Debugf("peer meta is the same, acquiring read lock")
log.WithContext(ctx).Debugf("peer meta is the same, acquiring read lock")
}
var unlock func()
if isWriteLock {
unlock = am.Store.AcquireAccountWriteLock(accountID)
unlock = am.Store.AcquireAccountWriteLock(ctx, accountID)
} else {
unlock = am.Store.AcquireAccountReadLock(accountID)
unlock = am.Store.AcquireAccountReadLock(ctx, accountID)
}
defer func() {
if unlock != nil {
@ -653,7 +654,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
}()
// fetch the account from the store once more after acquiring lock to avoid concurrent updates inconsistencies
account, err := am.Store.GetAccount(accountID)
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, nil, nil, err
}
@ -671,8 +672,8 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
// this flag prevents unnecessary calls to the persistent store.
shouldStoreAccount := false
updateRemotePeers := false
if peerLoginExpired(peer, account.Settings) {
err = checkAuth(login.UserID, peer)
if peerLoginExpired(ctx, peer, account.Settings) {
err = checkAuth(ctx, login.UserID, peer)
if err != nil {
return nil, nil, nil, err
}
@ -689,10 +690,10 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
}
user.updateLastLogin(peer.LastLogin)
am.StoreEvent(login.UserID, peer.ID, account.Id, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain()))
am.StoreEvent(ctx, login.UserID, peer.ID, account.Id, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain()))
}
isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
if err != nil {
return nil, nil, nil, err
}
@ -701,17 +702,17 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
shouldStoreAccount = true
}
peer, err = am.checkAndUpdatePeerSSHKey(peer, account, login.SSHKey)
peer, err = am.checkAndUpdatePeerSSHKey(ctx, peer, account, login.SSHKey)
if err != nil {
return nil, nil, nil, err
}
if shouldStoreAccount {
if !isWriteLock {
log.Errorf("account %s should be stored but is not write locked", accountID)
log.WithContext(ctx).Errorf("account %s should be stored but is not write locked", accountID)
return nil, nil, nil, status.Errorf(status.Internal, "account should be stored but is not write locked")
}
err = am.Store.SaveAccount(account)
err = am.Store.SaveAccount(ctx, account)
if err != nil {
return nil, nil, nil, err
}
@ -720,7 +721,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
unlock = nil
if updateRemotePeers || isStatusChanged {
am.updateAccountPeers(account)
am.updateAccountPeers(ctx, account)
}
var postureChecks []*posture.Checks
@ -738,7 +739,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
}
postureChecks = am.getPeerPostureChecks(account, peer)
return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap), postureChecks, nil
return peer, account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, approvedPeersMap), postureChecks, nil
}
func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, account *Account) error {
@ -754,23 +755,23 @@ func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, account *Account) error {
return nil
}
func checkAuth(loginUserID string, peer *nbpeer.Peer) error {
func checkAuth(ctx context.Context, loginUserID string, peer *nbpeer.Peer) error {
if loginUserID == "" {
// absence of a user ID indicates that JWT wasn't provided.
return status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
}
if peer.UserID != loginUserID {
log.Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, loginUserID)
log.WithContext(ctx).Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, loginUserID)
return status.Errorf(status.Unauthenticated, "can't login")
}
return nil
}
func peerLoginExpired(peer *nbpeer.Peer, settings *Settings) bool {
func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *Settings) bool {
expired, expiresIn := peer.LoginExpired(settings.PeerLoginExpiration)
expired = settings.PeerLoginExpirationEnabled && expired
if expired || peer.Status.LoginExpired {
log.Debugf("peer's %s login expired %v ago", peer.ID, expiresIn)
log.WithContext(ctx).Debugf("peer's %s login expired %v ago", peer.ID, expiresIn)
return true
}
return false
@ -781,48 +782,48 @@ func updatePeerLastLogin(peer *nbpeer.Peer, account *Account) {
account.UpdatePeer(peer)
}
func (am *DefaultAccountManager) checkAndUpdatePeerSSHKey(peer *nbpeer.Peer, account *Account, newSSHKey string) (*nbpeer.Peer, error) {
func (am *DefaultAccountManager) checkAndUpdatePeerSSHKey(ctx context.Context, peer *nbpeer.Peer, account *Account, newSSHKey string) (*nbpeer.Peer, error) {
if len(newSSHKey) == 0 {
log.Debugf("no new SSH key provided for peer %s, skipping update", peer.ID)
log.WithContext(ctx).Debugf("no new SSH key provided for peer %s, skipping update", peer.ID)
return peer, nil
}
if peer.SSHKey == newSSHKey {
log.Debugf("same SSH key provided for peer %s, skipping update", peer.ID)
log.WithContext(ctx).Debugf("same SSH key provided for peer %s, skipping update", peer.ID)
return peer, nil
}
peer.SSHKey = newSSHKey
account.UpdatePeer(peer)
err := am.Store.SaveAccount(account)
err := am.Store.SaveAccount(ctx, account)
if err != nil {
return nil, err
}
// trigger network map update
am.updateAccountPeers(account)
am.updateAccountPeers(ctx, account)
return peer, nil
}
// UpdatePeerSSHKey updates peer's public SSH key
func (am *DefaultAccountManager) UpdatePeerSSHKey(peerID string, sshKey string) error {
func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error {
if sshKey == "" {
log.Debugf("empty SSH key provided for peer %s, skipping update", peerID)
log.WithContext(ctx).Debugf("empty SSH key provided for peer %s, skipping update", peerID)
return nil
}
account, err := am.Store.GetAccountByPeerID(peerID)
account, err := am.Store.GetAccountByPeerID(ctx, peerID)
if err != nil {
return err
}
unlock := am.Store.AcquireAccountWriteLock(account.Id)
unlock := am.Store.AcquireAccountWriteLock(ctx, account.Id)
defer unlock()
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
account, err = am.Store.GetAccount(account.Id)
account, err = am.Store.GetAccount(ctx, account.Id)
if err != nil {
return err
}
@ -833,30 +834,30 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(peerID string, sshKey string)
}
if peer.SSHKey == sshKey {
log.Debugf("same SSH key provided for peer %s, skipping update", peerID)
log.WithContext(ctx).Debugf("same SSH key provided for peer %s, skipping update", peerID)
return nil
}
peer.SSHKey = sshKey
account.UpdatePeer(peer)
err = am.Store.SaveAccount(account)
err = am.Store.SaveAccount(ctx, account)
if err != nil {
return err
}
// trigger network map update
am.updateAccountPeers(account)
am.updateAccountPeers(ctx, account)
return nil
}
// GetPeer for a given accountID, peerID and userID error if not found.
func (am *DefaultAccountManager) GetPeer(accountID, peerID, userID string) (*nbpeer.Peer, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
@ -893,7 +894,7 @@ func (am *DefaultAccountManager) GetPeer(accountID, peerID, userID string) (*nbp
}
for _, p := range userPeers {
aclPeers, _ := account.getPeerConnectionResources(p.ID, approvedPeersMap)
aclPeers, _ := account.getPeerConnectionResources(ctx, p.ID, approvedPeersMap)
for _, aclPeer := range aclPeers {
if aclPeer.ID == peerID {
return peer, nil
@ -914,23 +915,23 @@ func updatePeerMeta(peer *nbpeer.Peer, meta nbpeer.PeerSystemMeta, account *Acco
// updateAccountPeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers.
func (am *DefaultAccountManager) updateAccountPeers(account *Account) {
func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account *Account) {
peers := account.GetPeers()
approvedPeersMap, err := am.GetValidatedPeers(account)
if err != nil {
log.Errorf("failed send out updates to peers, failed to validate peer: %v", err)
log.WithContext(ctx).Errorf("failed send out updates to peers, failed to validate peer: %v", err)
return
}
for _, peer := range peers {
if !am.peersUpdateManager.HasChannel(peer.ID) {
log.Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
continue
}
postureChecks := am.getPeerPostureChecks(account, peer)
remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap)
update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks)
am.peersUpdateManager.SendUpdate(peer.ID, &UpdateMessage{Update: update})
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, approvedPeersMap)
update := toSyncResponse(ctx, nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks)
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update})
}
}

View File

@ -1,6 +1,7 @@
package server
import (
"context"
"testing"
"time"
@ -80,7 +81,7 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
t.Fatal(err)
}
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userId, false)
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userId, false)
if err != nil {
t.Fatal("error creating setup key")
return
@ -92,7 +93,7 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
return
}
peer1, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
peer1, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
Key: peerKey1.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
})
@ -106,7 +107,7 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
t.Fatal(err)
return
}
_, _, _, err = manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
_, _, _, err = manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
Key: peerKey2.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
})
@ -116,7 +117,7 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
return
}
networkMap, err := manager.GetNetworkMap(peer1.ID)
networkMap, err := manager.GetNetworkMap(context.Background(), peer1.ID)
if err != nil {
t.Fatal(err)
return
@ -165,7 +166,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
return
}
peer1, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
peer1, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
Key: peerKey1.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
})
@ -179,7 +180,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
t.Fatal(err)
return
}
peer2, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
peer2, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
Key: peerKey2.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
})
@ -188,13 +189,13 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
return
}
policies, err := manager.ListPolicies(account.Id, userID)
policies, err := manager.ListPolicies(context.Background(), account.Id, userID)
if err != nil {
t.Errorf("expecting to get a list of rules, got failure %v", err)
return
}
err = manager.DeletePolicy(account.Id, policies[0].ID, userID)
err = manager.DeletePolicy(context.Background(), account.Id, policies[0].ID, userID)
if err != nil {
t.Errorf("expecting to delete 1 group, got failure %v", err)
return
@ -213,12 +214,12 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
group1.Peers = append(group1.Peers, peer1.ID)
group2.Peers = append(group2.Peers, peer2.ID)
err = manager.SaveGroup(account.Id, userID, &group1)
err = manager.SaveGroup(context.Background(), account.Id, userID, &group1)
if err != nil {
t.Errorf("expecting group1 to be added, got failure %v", err)
return
}
err = manager.SaveGroup(account.Id, userID, &group2)
err = manager.SaveGroup(context.Background(), account.Id, userID, &group2)
if err != nil {
t.Errorf("expecting group2 to be added, got failure %v", err)
return
@ -235,13 +236,13 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
}
err = manager.SavePolicy(account.Id, userID, &policy)
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy)
if err != nil {
t.Errorf("expecting rule to be added, got failure %v", err)
return
}
networkMap1, err := manager.GetNetworkMap(peer1.ID)
networkMap1, err := manager.GetNetworkMap(context.Background(), peer1.ID)
if err != nil {
t.Fatal(err)
return
@ -264,7 +265,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
)
}
networkMap2, err := manager.GetNetworkMap(peer2.ID)
networkMap2, err := manager.GetNetworkMap(context.Background(), peer2.ID)
if err != nil {
t.Fatal(err)
return
@ -283,13 +284,13 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
}
policy.Enabled = false
err = manager.SavePolicy(account.Id, userID, &policy)
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy)
if err != nil {
t.Errorf("expecting rule to be added, got failure %v", err)
return
}
networkMap1, err = manager.GetNetworkMap(peer1.ID)
networkMap1, err = manager.GetNetworkMap(context.Background(), peer1.ID)
if err != nil {
t.Fatal(err)
return
@ -304,7 +305,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
return
}
networkMap2, err = manager.GetNetworkMap(peer2.ID)
networkMap2, err = manager.GetNetworkMap(context.Background(), peer2.ID)
if err != nil {
t.Fatal(err)
return
@ -329,7 +330,7 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) {
t.Fatal(err)
}
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userId, false)
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userId, false)
if err != nil {
t.Fatal("error creating setup key")
return
@ -341,7 +342,7 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) {
return
}
peer1, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
peer1, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
Key: peerKey1.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
})
@ -355,7 +356,7 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) {
t.Fatal(err)
return
}
_, _, _, err = manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
_, _, _, err = manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
Key: peerKey2.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
})
@ -365,7 +366,7 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) {
return
}
network, err := manager.GetPeerNetwork(peer1.ID)
network, err := manager.GetPeerNetwork(context.Background(), peer1.ID)
if err != nil {
t.Fatal(err)
return
@ -387,21 +388,21 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
accountID := "test_account"
adminUser := "account_creator"
someUser := "some_user"
account := newAccountWithId(accountID, adminUser, "")
account := newAccountWithId(context.Background(), accountID, adminUser, "")
account.Users[someUser] = &User{
Id: someUser,
Role: UserRoleUser,
}
account.Settings.RegularUsersViewBlocked = false
err = manager.Store.SaveAccount(account)
err = manager.Store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatal(err)
return
}
// two peers one added by a regular user and one with a setup key
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, adminUser, false)
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, adminUser, false)
if err != nil {
t.Fatal("error creating setup key")
return
@ -413,7 +414,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
return
}
peer1, _, _, err := manager.AddPeer("", someUser, &nbpeer.Peer{
peer1, _, _, err := manager.AddPeer(context.Background(), "", someUser, &nbpeer.Peer{
Key: peerKey1.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
})
@ -429,7 +430,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
}
// the second peer added with a setup key
peer2, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
peer2, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
Key: peerKey2.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
})
@ -439,7 +440,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
}
// the user can see its own peer
peer, err := manager.GetPeer(accountID, peer1.ID, someUser)
peer, err := manager.GetPeer(context.Background(), accountID, peer1.ID, someUser)
if err != nil {
t.Fatal(err)
return
@ -447,7 +448,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
assert.NotNil(t, peer)
// the user can see peer2 because peer1 of the user has access to peer2 due to the All group and the default rule 0 all-to-all access
peer, err = manager.GetPeer(accountID, peer2.ID, someUser)
peer, err = manager.GetPeer(context.Background(), accountID, peer2.ID, someUser)
if err != nil {
t.Fatal(err)
return
@ -456,7 +457,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
// delete the all-to-all policy so that user's peer1 has no access to peer2
for _, policy := range account.Policies {
err = manager.DeletePolicy(accountID, policy.ID, adminUser)
err = manager.DeletePolicy(context.Background(), accountID, policy.ID, adminUser)
if err != nil {
t.Fatal(err)
return
@ -464,18 +465,18 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
}
// at this point the user can't see the details of peer2
peer, err = manager.GetPeer(accountID, peer2.ID, someUser) //nolint
peer, err = manager.GetPeer(context.Background(), accountID, peer2.ID, someUser) //nolint
assert.Error(t, err)
// admin users can always access all the peers
peer, err = manager.GetPeer(accountID, peer1.ID, adminUser)
peer, err = manager.GetPeer(context.Background(), accountID, peer1.ID, adminUser)
if err != nil {
t.Fatal(err)
return
}
assert.NotNil(t, peer)
peer, err = manager.GetPeer(accountID, peer2.ID, adminUser)
peer, err = manager.GetPeer(context.Background(), accountID, peer2.ID, adminUser)
if err != nil {
t.Fatal(err)
return
@ -574,7 +575,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
accountID := "test_account"
adminUser := "account_creator"
someUser := "some_user"
account := newAccountWithId(accountID, adminUser, "")
account := newAccountWithId(context.Background(), accountID, adminUser, "")
account.Users[someUser] = &User{
Id: someUser,
Role: testCase.role,
@ -583,7 +584,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
account.Policies = []*Policy{}
account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings
err = manager.Store.SaveAccount(account)
err = manager.Store.SaveAccount(context.Background(), account)
if err != nil {
t.Fatal(err)
return
@ -601,7 +602,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
return
}
_, _, _, err = manager.AddPeer("", someUser, &nbpeer.Peer{
_, _, _, err = manager.AddPeer(context.Background(), "", someUser, &nbpeer.Peer{
Key: peerKey1.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
})
@ -610,7 +611,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
return
}
_, _, _, err = manager.AddPeer("", adminUser, &nbpeer.Peer{
_, _, _, err = manager.AddPeer(context.Background(), "", adminUser, &nbpeer.Peer{
Key: peerKey2.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
})
@ -619,7 +620,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
return
}
peers, err := manager.GetPeers(accountID, someUser)
peers, err := manager.GetPeers(context.Background(), accountID, someUser)
if err != nil {
t.Fatal(err)
return

View File

@ -1,6 +1,7 @@
package server
import (
"context"
_ "embed"
"strconv"
"strings"
@ -211,9 +212,9 @@ type FirewallRule struct {
// getPeerConnectionResources for a given peer
//
// This function returns the list of peers and firewall rules that are applicable to a given peer.
func (a *Account) getPeerConnectionResources(peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
func (a *Account) getPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
generateResources, getAccumulatedResources := a.connResourcesGenerator()
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx)
for _, policy := range a.Policies {
if !policy.Enabled {
continue
@ -224,8 +225,8 @@ func (a *Account) getPeerConnectionResources(peerID string, validatedPeersMap ma
continue
}
sourcePeers, peerInSources := getAllPeersFromGroups(a, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap)
destinationPeers, peerInDestinations := getAllPeersFromGroups(a, rule.Destinations, peerID, nil, validatedPeersMap)
sourcePeers, peerInSources := getAllPeersFromGroups(ctx, a, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap)
destinationPeers, peerInDestinations := getAllPeersFromGroups(ctx, a, rule.Destinations, peerID, nil, validatedPeersMap)
if rule.Bidirectional {
if peerInSources {
@ -254,7 +255,7 @@ func (a *Account) getPeerConnectionResources(peerID string, validatedPeersMap ma
// The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer.
// It safe to call the generator function multiple times for same peer and different rules no duplicates will be
// generated. The accumulator function returns the result of all the generator calls.
func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) {
func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) {
rulesExists := make(map[string]struct{})
peersExists := make(map[string]struct{})
rules := make([]*FirewallRule, 0)
@ -262,7 +263,7 @@ func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*nbpeer.Peer, in
all, err := a.GetGroupAll()
if err != nil {
log.Errorf("failed to get group all: %v", err)
log.WithContext(ctx).Errorf("failed to get group all: %v", err)
all = &nbgroup.Group{}
}
@ -313,11 +314,11 @@ func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*nbpeer.Peer, in
}
// GetPolicy from the store
func (am *DefaultAccountManager) GetPolicy(accountID, policyID, userID string) (*Policy, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
@ -341,11 +342,11 @@ func (am *DefaultAccountManager) GetPolicy(accountID, policyID, userID string) (
}
// SavePolicy in the store
func (am *DefaultAccountManager) SavePolicy(accountID, userID string, policy *Policy) error {
unlock := am.Store.AcquireAccountWriteLock(accountID)
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
@ -353,7 +354,7 @@ func (am *DefaultAccountManager) SavePolicy(accountID, userID string, policy *Po
exists := am.savePolicy(account, policy)
account.Network.IncSerial()
if err = am.Store.SaveAccount(account); err != nil {
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
@ -361,19 +362,19 @@ func (am *DefaultAccountManager) SavePolicy(accountID, userID string, policy *Po
if exists {
action = activity.PolicyUpdated
}
am.StoreEvent(userID, policy.ID, accountID, action, policy.EventMeta())
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
am.updateAccountPeers(account)
am.updateAccountPeers(ctx, account)
return nil
}
// DeletePolicy from the store
func (am *DefaultAccountManager) DeletePolicy(accountID, policyID, userID string) error {
unlock := am.Store.AcquireAccountWriteLock(accountID)
func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
@ -384,23 +385,23 @@ func (am *DefaultAccountManager) DeletePolicy(accountID, policyID, userID string
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(account); err != nil {
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
am.StoreEvent(userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
am.updateAccountPeers(account)
am.updateAccountPeers(ctx, account)
return nil
}
// ListPolicies from the store
func (am *DefaultAccountManager) ListPolicies(accountID, userID string) ([]*Policy, error) {
unlock := am.Store.AcquireAccountWriteLock(accountID)
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) {
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(accountID)
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
@ -490,7 +491,7 @@ func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule {
//
// Important: Posture checks are applicable only to source group peers,
// for destination group peers, call this method with an empty list of sourcePostureChecksIDs
func getAllPeersFromGroups(account *Account, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) {
func getAllPeersFromGroups(ctx context.Context, account *Account, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) {
peerInGroups := false
filteredPeers := make([]*nbpeer.Peer, 0, len(groups))
for _, g := range groups {
@ -506,7 +507,7 @@ func getAllPeersFromGroups(account *Account, groups []string, peerID string, sou
}
// validate the peer based on policy posture checks applied
isValid := account.validatePostureChecksOnPeer(sourcePostureChecksIDs, peer.ID)
isValid := account.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
if !isValid {
continue
}
@ -527,7 +528,7 @@ func getAllPeersFromGroups(account *Account, groups []string, peerID string, sou
}
// validatePostureChecksOnPeer validates the posture checks on a peer
func (a *Account) validatePostureChecksOnPeer(sourcePostureChecksID []string, peerID string) bool {
func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePostureChecksID []string, peerID string) bool {
peer, ok := a.Peers[peerID]
if !ok && peer == nil {
return false
@ -540,9 +541,9 @@ func (a *Account) validatePostureChecksOnPeer(sourcePostureChecksID []string, pe
}
for _, check := range postureChecks.GetChecks() {
isValid, err := check.Check(*peer)
isValid, err := check.Check(ctx, *peer)
if err != nil {
log.Debugf("an error occurred check %s: on peer: %s :%s", check.Name(), peer.ID, err.Error())
log.WithContext(ctx).Debugf("an error occurred check %s: on peer: %s :%s", check.Name(), peer.ID, err.Error())
}
if !isValid {
return false

View File

@ -1,6 +1,7 @@
package server
import (
"context"
"fmt"
"net"
"testing"
@ -143,14 +144,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
t.Run("check that all peers get map", func(t *testing.T) {
for _, p := range account.Peers {
peers, firewallRules := account.getPeerConnectionResources(p.ID, validatedPeers)
peers, firewallRules := account.getPeerConnectionResources(context.Background(), p.ID, validatedPeers)
assert.GreaterOrEqual(t, len(peers), 2, "minimum number peers should present")
assert.GreaterOrEqual(t, len(firewallRules), 2, "minimum number of firewall rules should present")
}
})
t.Run("check first peer map details", func(t *testing.T) {
peers, firewallRules := account.getPeerConnectionResources("peerB", validatedPeers)
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", validatedPeers)
assert.Len(t, peers, 7)
assert.Contains(t, peers, account.Peers["peerA"])
assert.Contains(t, peers, account.Peers["peerC"])
@ -386,7 +387,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
}
t.Run("check first peer map", func(t *testing.T) {
peers, firewallRules := account.getPeerConnectionResources("peerB", approvedPeers)
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers)
assert.Contains(t, peers, account.Peers["peerC"])
epectedFirewallRules := []*FirewallRule{
@ -414,7 +415,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
})
t.Run("check second peer map", func(t *testing.T) {
peers, firewallRules := account.getPeerConnectionResources("peerC", approvedPeers)
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers)
assert.Contains(t, peers, account.Peers["peerB"])
epectedFirewallRules := []*FirewallRule{
@ -444,7 +445,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
account.Policies[1].Rules[0].Bidirectional = false
t.Run("check first peer map directional only", func(t *testing.T) {
peers, firewallRules := account.getPeerConnectionResources("peerB", approvedPeers)
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers)
assert.Contains(t, peers, account.Peers["peerC"])
epectedFirewallRules := []*FirewallRule{
@ -465,7 +466,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
})
t.Run("check second peer map directional only", func(t *testing.T) {
peers, firewallRules := account.getPeerConnectionResources("peerC", approvedPeers)
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers)
assert.Contains(t, peers, account.Peers["peerB"])
epectedFirewallRules := []*FirewallRule{
@ -662,7 +663,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
t.Run("verify peer's network map with default group peer list", func(t *testing.T) {
// peerB doesn't fulfill the NB posture check but is included in the destination group Swarm,
// will establish a connection with all source peers satisfying the NB posture check.
peers, firewallRules := account.getPeerConnectionResources("peerB", approvedPeers)
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers)
assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"])
@ -672,7 +673,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
// We expect a single permissive firewall rule which all outgoing connections
peers, firewallRules = account.getPeerConnectionResources("peerC", approvedPeers)
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers)
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
assert.Len(t, firewallRules, 1)
expectedFirewallRules := []*FirewallRule{
@ -688,7 +689,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.getPeerConnectionResources("peerE", approvedPeers)
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerE", approvedPeers)
assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"])
@ -698,7 +699,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.getPeerConnectionResources("peerI", approvedPeers)
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerI", approvedPeers)
assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"])
@ -713,19 +714,19 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's
// no connection should be established to any peer of destination group
peers, firewallRules := account.getPeerConnectionResources("peerB", approvedPeers)
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers)
assert.Len(t, peers, 0)
assert.Len(t, firewallRules, 0)
// peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's
// no connection should be established to any peer of destination group
peers, firewallRules = account.getPeerConnectionResources("peerI", approvedPeers)
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerI", approvedPeers)
assert.Len(t, peers, 0)
assert.Len(t, firewallRules, 0)
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
// We expect a single permissive firewall rule which all outgoing connections
peers, firewallRules = account.getPeerConnectionResources("peerC", approvedPeers)
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers)
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers))
@ -740,14 +741,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.getPeerConnectionResources("peerE", approvedPeers)
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerE", approvedPeers)
assert.Len(t, peers, 3)
assert.Len(t, firewallRules, 3)
assert.Contains(t, peers, account.Peers["peerA"])
assert.Contains(t, peers, account.Peers["peerC"])
assert.Contains(t, peers, account.Peers["peerD"])
peers, firewallRules = account.getPeerConnectionResources("peerA", approvedPeers)
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerA", approvedPeers)
assert.Len(t, peers, 5)
// assert peers from Group Swarm
assert.Contains(t, peers, account.Peers["peerD"])

View File

@ -1,6 +1,7 @@
package posture
import (
"context"
"errors"
"net/netip"
"regexp"
@ -31,7 +32,7 @@ var (
// Check represents an interface for performing a check on a peer.
type Check interface {
Name() string
Check(peer nbpeer.Peer) (bool, error)
Check(ctx context.Context, peer nbpeer.Peer) (bool, error)
Validate() error
}

View File

@ -1,6 +1,7 @@
package posture
import (
"context"
"fmt"
"slices"
@ -25,7 +26,7 @@ type GeoLocationCheck struct {
Action string
}
func (g *GeoLocationCheck) Check(peer nbpeer.Peer) (bool, error) {
func (g *GeoLocationCheck) Check(_ context.Context, peer nbpeer.Peer) (bool, error) {
// deny if the peer location is not evaluated
if peer.Location.CountryCode == "" && peer.Location.CityName == "" {
return false, fmt.Errorf("peer's location is not set")

View File

@ -1,6 +1,7 @@
package posture
import (
"context"
"testing"
"github.com/netbirdio/netbird/management/server/peer"
@ -226,7 +227,7 @@ func TestGeoLocationCheck_Check(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid, err := tt.check.Check(tt.input)
isValid, err := tt.check.Check(context.Background(), tt.input)
if tt.wantErr {
assert.Error(t, err)
} else {

View File

@ -1,6 +1,7 @@
package posture
import (
"context"
"fmt"
"github.com/hashicorp/go-version"
@ -15,7 +16,7 @@ type NBVersionCheck struct {
var _ Check = (*NBVersionCheck)(nil)
func (n *NBVersionCheck) Check(peer nbpeer.Peer) (bool, error) {
func (n *NBVersionCheck) Check(ctx context.Context, peer nbpeer.Peer) (bool, error) {
peerNBVersion, err := version.NewVersion(peer.Meta.WtVersion)
if err != nil {
return false, err
@ -30,7 +31,7 @@ func (n *NBVersionCheck) Check(peer nbpeer.Peer) (bool, error) {
return true, nil
}
log.Debugf("peer %s NB version %s is older than minimum allowed version %s",
log.WithContext(ctx).Debugf("peer %s NB version %s is older than minimum allowed version %s",
peer.ID, peer.Meta.WtVersion, n.MinVersion)
return false, nil

View File

@ -1,6 +1,7 @@
package posture
import (
"context"
"testing"
"github.com/netbirdio/netbird/management/server/peer"
@ -98,7 +99,7 @@ func TestNBVersionCheck_Check(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid, err := tt.check.Check(tt.input)
isValid, err := tt.check.Check(context.Background(), tt.input)
if tt.wantErr {
assert.Error(t, err)
} else {

View File

@ -1,6 +1,7 @@
package posture
import (
"context"
"fmt"
"net/netip"
"slices"
@ -16,7 +17,7 @@ type PeerNetworkRangeCheck struct {
var _ Check = (*PeerNetworkRangeCheck)(nil)
func (p *PeerNetworkRangeCheck) Check(peer nbpeer.Peer) (bool, error) {
func (p *PeerNetworkRangeCheck) Check(ctx context.Context, peer nbpeer.Peer) (bool, error) {
if len(peer.Meta.NetworkAddresses) == 0 {
return false, fmt.Errorf("peer's does not contain peer network range addresses")
}

View File

@ -1,6 +1,7 @@
package posture
import (
"context"
"net/netip"
"testing"
@ -137,7 +138,7 @@ func TestPeerNetworkRangeCheck_Check(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid, err := tt.check.Check(tt.peer)
isValid, err := tt.check.Check(context.Background(), tt.peer)
if tt.wantErr {
assert.Error(t, err)
} else {

View File

@ -1,6 +1,7 @@
package posture
import (
"context"
"fmt"
"strings"
@ -28,20 +29,20 @@ type OSVersionCheck struct {
var _ Check = (*OSVersionCheck)(nil)
func (c *OSVersionCheck) Check(peer nbpeer.Peer) (bool, error) {
func (c *OSVersionCheck) Check(ctx context.Context, peer nbpeer.Peer) (bool, error) {
peerGoOS := peer.Meta.GoOS
switch peerGoOS {
case "android":
return checkMinVersion(peerGoOS, peer.Meta.OSVersion, c.Android)
return checkMinVersion(ctx, peerGoOS, peer.Meta.OSVersion, c.Android)
case "darwin":
return checkMinVersion(peerGoOS, peer.Meta.OSVersion, c.Darwin)
return checkMinVersion(ctx, peerGoOS, peer.Meta.OSVersion, c.Darwin)
case "ios":
return checkMinVersion(peerGoOS, peer.Meta.OSVersion, c.Ios)
return checkMinVersion(ctx, peerGoOS, peer.Meta.OSVersion, c.Ios)
case "linux":
kernelVersion := strings.Split(peer.Meta.KernelVersion, "-")[0]
return checkMinKernelVersion(peerGoOS, kernelVersion, c.Linux)
return checkMinKernelVersion(ctx, peerGoOS, kernelVersion, c.Linux)
case "windows":
return checkMinKernelVersion(peerGoOS, peer.Meta.KernelVersion, c.Windows)
return checkMinKernelVersion(ctx, peerGoOS, peer.Meta.KernelVersion, c.Windows)
}
return true, nil
}
@ -79,9 +80,9 @@ func (c *OSVersionCheck) Validate() error {
return nil
}
func checkMinVersion(peerGoOS, peerVersion string, check *MinVersionCheck) (bool, error) {
func checkMinVersion(ctx context.Context, peerGoOS, peerVersion string, check *MinVersionCheck) (bool, error) {
if check == nil {
log.Debugf("peer %s OS is not allowed in the check", peerGoOS)
log.WithContext(ctx).Debugf("peer %s OS is not allowed in the check", peerGoOS)
return false, nil
}
@ -99,14 +100,14 @@ func checkMinVersion(peerGoOS, peerVersion string, check *MinVersionCheck) (bool
return true, nil
}
log.Debugf("peer %s OS version %s is older than minimum allowed version %s", peerGoOS, peerVersion, check.MinVersion)
log.WithContext(ctx).Debugf("peer %s OS version %s is older than minimum allowed version %s", peerGoOS, peerVersion, check.MinVersion)
return false, nil
}
func checkMinKernelVersion(peerGoOS, peerVersion string, check *MinKernelVersionCheck) (bool, error) {
func checkMinKernelVersion(ctx context.Context, peerGoOS, peerVersion string, check *MinKernelVersionCheck) (bool, error) {
if check == nil {
log.Debugf("peer %s OS is not allowed in the check", peerGoOS)
log.WithContext(ctx).Debugf("peer %s OS is not allowed in the check", peerGoOS)
return false, nil
}
@ -124,7 +125,7 @@ func checkMinKernelVersion(peerGoOS, peerVersion string, check *MinKernelVersion
return true, nil
}
log.Debugf("peer %s kernel version %s is older than minimum allowed version %s", peerGoOS, peerVersion, check.MinKernelVersion)
log.WithContext(ctx).Debugf("peer %s kernel version %s is older than minimum allowed version %s", peerGoOS, peerVersion, check.MinKernelVersion)
return false, nil
}

Some files were not shown because too many files have changed in this diff Show More