mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-07 08:44:07 +01:00
Add context to throughout the project and update logging (#2209)
propagate context from all the API calls and log request ID, account ID and peer ID --------- Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com>
This commit is contained in:
parent
7cb81f1d70
commit
765aba2c1c
@ -76,7 +76,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
|
||||
t.Fatal(err)
|
||||
}
|
||||
s := grpc.NewServer()
|
||||
store, cleanUp, err := mgmt.NewTestStoreFromJson(config.Datadir)
|
||||
store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -87,13 +87,13 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
iv, _ := integrations.NewIntegratedValidator(eventStore)
|
||||
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv)
|
||||
iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -174,7 +174,7 @@ func TestEngine_SSH(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
//time.Sleep(250 * time.Millisecond)
|
||||
// time.Sleep(250 * time.Millisecond)
|
||||
assert.NotNil(t, engine.sshServer)
|
||||
assert.Contains(t, sshPeersRemoved, "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=")
|
||||
|
||||
@ -1057,7 +1057,7 @@ func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error)
|
||||
}
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
|
||||
store, cleanUp, err := server.NewTestStoreFromJson(config.Datadir)
|
||||
store, cleanUp, err := server.NewTestStoreFromJson(context.Background(), config.Datadir)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
@ -1068,13 +1068,13 @@ func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
ia, _ := integrations.NewIntegratedValidator(eventStore)
|
||||
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
@ -108,7 +108,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
return nil, "", err
|
||||
}
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
store, cleanUp, err := server.NewTestStoreFromJson(config.Datadir)
|
||||
store, cleanUp, err := server.NewTestStoreFromJson(context.Background(), config.Datadir)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
@ -119,13 +119,13 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
ia, _ := integrations.NewIntegratedValidator(eventStore)
|
||||
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
@ -7,6 +7,18 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/context"
|
||||
)
|
||||
|
||||
type ExecutionContext string
|
||||
|
||||
const (
|
||||
ExecutionContextKey = "executionContext"
|
||||
|
||||
HTTPSource ExecutionContext = "HTTP"
|
||||
GRPCSource ExecutionContext = "GRPC"
|
||||
SystemSource ExecutionContext = "SYSTEM"
|
||||
)
|
||||
|
||||
// ContextHook is a custom hook for add the source information for the entry
|
||||
@ -30,6 +42,27 @@ func (hook ContextHook) Levels() []logrus.Level {
|
||||
func (hook ContextHook) Fire(entry *logrus.Entry) error {
|
||||
src := hook.parseSrc(entry.Caller.File)
|
||||
entry.Data["source"] = fmt.Sprintf("%s:%v", src, entry.Caller.Line)
|
||||
|
||||
if entry.Context == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
source, ok := entry.Context.Value(ExecutionContextKey).(ExecutionContext)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
entry.Data["context"] = source
|
||||
|
||||
switch source {
|
||||
case HTTPSource:
|
||||
addHTTPFields(entry)
|
||||
case GRPCSource:
|
||||
addGRPCFields(entry)
|
||||
case SystemSource:
|
||||
addSystemFields(entry)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -59,3 +92,42 @@ func (hook ContextHook) parseSrc(filePath string) string {
|
||||
file := path.Base(filePath)
|
||||
return fmt.Sprintf("%s/%s", pkg, file)
|
||||
}
|
||||
|
||||
func addHTTPFields(entry *logrus.Entry) {
|
||||
if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok {
|
||||
entry.Data[context.RequestIDKey] = ctxReqID
|
||||
}
|
||||
if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok {
|
||||
entry.Data[context.AccountIDKey] = ctxAccountID
|
||||
}
|
||||
if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok {
|
||||
entry.Data[context.UserIDKey] = ctxInitiatorID
|
||||
}
|
||||
}
|
||||
|
||||
func addGRPCFields(entry *logrus.Entry) {
|
||||
if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok {
|
||||
entry.Data[context.RequestIDKey] = ctxReqID
|
||||
}
|
||||
if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok {
|
||||
entry.Data[context.AccountIDKey] = ctxAccountID
|
||||
}
|
||||
if ctxDeviceID, ok := entry.Context.Value(context.PeerIDKey).(string); ok {
|
||||
entry.Data[context.PeerIDKey] = ctxDeviceID
|
||||
}
|
||||
}
|
||||
|
||||
func addSystemFields(entry *logrus.Entry) {
|
||||
if ctxReqID, ok := entry.Context.Value(context.RequestIDKey).(string); ok {
|
||||
entry.Data[context.RequestIDKey] = ctxReqID
|
||||
}
|
||||
if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok {
|
||||
entry.Data[context.UserIDKey] = ctxInitiatorID
|
||||
}
|
||||
if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok {
|
||||
entry.Data[context.AccountIDKey] = ctxAccountID
|
||||
}
|
||||
if ctxDeviceID, ok := entry.Context.Value(context.PeerIDKey).(string); ok {
|
||||
entry.Data[context.PeerIDKey] = ctxDeviceID
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,8 @@
|
||||
package formatter
|
||||
|
||||
import "github.com/sirupsen/logrus"
|
||||
import (
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// SetTextFormatter set the text formatter for given logger.
|
||||
func SetTextFormatter(logger *logrus.Logger) {
|
||||
@ -9,6 +11,13 @@ func SetTextFormatter(logger *logrus.Logger) {
|
||||
logger.AddHook(NewContextHook())
|
||||
}
|
||||
|
||||
// SetJSONFormatter set the JSON formatter for given logger.
|
||||
func SetJSONFormatter(logger *logrus.Logger) {
|
||||
logger.Formatter = &logrus.JSONFormatter{}
|
||||
logger.ReportCaller = true
|
||||
logger.AddHook(NewContextHook())
|
||||
}
|
||||
|
||||
// SetLogcatFormatter set the logcat formatter for given logger.
|
||||
func SetLogcatFormatter(logger *logrus.Logger) {
|
||||
logger.Formatter = NewLogcatFormatter()
|
||||
|
3
go.mod
3
go.mod
@ -44,7 +44,6 @@ require (
|
||||
github.com/golang/mock v1.6.0
|
||||
github.com/google/go-cmp v0.6.0
|
||||
github.com/google/gopacket v1.1.19
|
||||
github.com/google/martian/v3 v3.0.0
|
||||
github.com/google/nftables v0.0.0-20220808154552-2eca00135732
|
||||
github.com/gopacket/gopacket v1.1.1
|
||||
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357
|
||||
@ -58,7 +57,7 @@ require (
|
||||
github.com/miekg/dns v1.1.43
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||
github.com/nadoo/ipset v0.5.0
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240524104853-69c6d89826cd
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e
|
||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||
github.com/oschwald/maxminddb-golang v1.12.0
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||
|
7
go.sum
7
go.sum
@ -209,8 +209,6 @@ github.com/google/gofuzz v0.0.0-20161122191042-44d81051d367/go.mod h1:HP5RmnzzSN
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
|
||||
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
|
||||
github.com/google/martian/v3 v3.0.0 h1:pMen7vLs8nvgEYhywH3KDWJIJTeEr2ULsVWHWYHQyBs=
|
||||
github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0=
|
||||
github.com/google/nftables v0.0.0-20220808154552-2eca00135732 h1:csc7dT82JiSLvq4aMyQMIQDL7986NH6Wxf/QrvOj55A=
|
||||
github.com/google/nftables v0.0.0-20220808154552-2eca00135732/go.mod h1:b97ulCCFipUC+kSin+zygkvUVpx0vyIAwxXFdY3PlNc=
|
||||
github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o=
|
||||
@ -335,8 +333,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
|
||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
||||
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
|
||||
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240524104853-69c6d89826cd h1:IzGGIJMpz07aPs3R6/4sxZv63JoCMddftLpVodUK+Ec=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240524104853-69c6d89826cd/go.mod h1:kxks50DrZnhW+oRTdHOkVOJbcTcyo766am8RBugo+Yc=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e h1:LYxhAmiEzSldLELHSMVoUnRPq3ztTNQImrD27frrGsI=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
|
||||
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g=
|
||||
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM=
|
||||
@ -565,7 +563,6 @@ golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20190628185345-da137c7871d7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20191004110552-13f9640d40b9/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20191105084925-a882066a44e0/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
|
@ -62,7 +62,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s := grpc.NewServer()
|
||||
store, cleanUp, err := mgmt.NewTestStoreFromJson(config.Datadir)
|
||||
store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -70,13 +70,13 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
|
||||
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
ia, _ := integrations.NewIntegratedValidator(eventStore)
|
||||
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||
mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -20,6 +20,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
@ -35,8 +36,10 @@ import (
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
httpapi "github.com/netbirdio/netbird/management/server/http"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
@ -77,6 +80,10 @@ var (
|
||||
Short: "start NetBird Management Server",
|
||||
PreRunE: func(cmd *cobra.Command, args []string) error {
|
||||
flag.Parse()
|
||||
|
||||
//nolint
|
||||
ctx := context.WithValue(cmd.Context(), formatter.ExecutionContextKey, formatter.SystemSource)
|
||||
|
||||
err := util.InitLog(logLevel, logFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed initializing log %v", err)
|
||||
@ -85,7 +92,7 @@ var (
|
||||
// detect whether user specified a port
|
||||
userPort := cmd.Flag("port").Changed
|
||||
|
||||
config, err = loadMgmtConfig(mgmtConfig)
|
||||
config, err = loadMgmtConfig(ctx, mgmtConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading provided config file: %s: %v", mgmtConfig, err)
|
||||
}
|
||||
@ -116,6 +123,11 @@ var (
|
||||
return nil
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
defer cancel()
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, formatter.ExecutionContextKey, formatter.SystemSource)
|
||||
|
||||
err := handleRebrand(cmd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to migrate files %v", err)
|
||||
@ -131,11 +143,11 @@ var (
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = appMetrics.Expose(mgmtMetricsPort, "/metrics")
|
||||
err = appMetrics.Expose(ctx, mgmtMetricsPort, "/metrics")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
store, err := server.NewStore(config.StoreConfig.Engine, config.Datadir, appMetrics)
|
||||
store, err := server.NewStore(ctx, config.StoreConfig.Engine, config.Datadir, appMetrics)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed creating Store: %s: %v", config.Datadir, err)
|
||||
}
|
||||
@ -143,7 +155,7 @@ var (
|
||||
|
||||
var idpManager idp.Manager
|
||||
if config.IdpManagerConfig != nil {
|
||||
idpManager, err = idp.NewManager(*config.IdpManagerConfig, appMetrics)
|
||||
idpManager, err = idp.NewManager(ctx, *config.IdpManagerConfig, appMetrics)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed retrieving a new idp manager with err: %v", err)
|
||||
}
|
||||
@ -152,32 +164,32 @@ var (
|
||||
if disableSingleAccMode {
|
||||
mgmtSingleAccModeDomain = ""
|
||||
}
|
||||
eventStore, key, err := integrations.InitEventStore(config.Datadir, config.DataStoreEncryptionKey)
|
||||
eventStore, key, err := integrations.InitEventStore(ctx, config.Datadir, config.DataStoreEncryptionKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize database: %s", err)
|
||||
}
|
||||
|
||||
if config.DataStoreEncryptionKey != key {
|
||||
log.Infof("update config with activity store key")
|
||||
log.WithContext(ctx).Infof("update config with activity store key")
|
||||
config.DataStoreEncryptionKey = key
|
||||
err := updateMgmtConfig(mgmtConfig, config)
|
||||
err := updateMgmtConfig(ctx, mgmtConfig, config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write out store encryption key: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
geo, err := geolocation.NewGeolocation(config.Datadir)
|
||||
geo, err := geolocation.NewGeolocation(ctx, config.Datadir)
|
||||
if err != nil {
|
||||
log.Warnf("could not initialize geo location service: %v, we proceed without geo support", err)
|
||||
log.WithContext(ctx).Warnf("could not initialize geo location service: %v, we proceed without geo support", err)
|
||||
} else {
|
||||
log.Infof("geo location service has been initialized from %s", config.Datadir)
|
||||
log.WithContext(ctx).Infof("geo location service has been initialized from %s", config.Datadir)
|
||||
}
|
||||
|
||||
integratedPeerValidator, err := integrations.NewIntegratedValidator(eventStore)
|
||||
integratedPeerValidator, err := integrations.NewIntegratedValidator(ctx, eventStore)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize integrated peer validator: %v", err)
|
||||
}
|
||||
accountManager, err := server.BuildManager(store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
|
||||
accountManager, err := server.BuildManager(ctx, store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
|
||||
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to build default manager: %v", err)
|
||||
@ -188,13 +200,13 @@ var (
|
||||
trustedPeers := config.ReverseProxy.TrustedPeers
|
||||
defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")}
|
||||
if len(trustedPeers) == 0 || slices.Equal[[]netip.Prefix](trustedPeers, defaultTrustedPeers) {
|
||||
log.Warn("TrustedPeers are configured to default value '0.0.0.0/0', '::/0'. This allows connection IP spoofing.")
|
||||
log.WithContext(ctx).Warn("TrustedPeers are configured to default value '0.0.0.0/0', '::/0'. This allows connection IP spoofing.")
|
||||
trustedPeers = defaultTrustedPeers
|
||||
}
|
||||
trustedHTTPProxies := config.ReverseProxy.TrustedHTTPProxies
|
||||
trustedProxiesCount := config.ReverseProxy.TrustedHTTPProxiesCount
|
||||
if len(trustedHTTPProxies) > 0 && trustedProxiesCount > 0 {
|
||||
log.Warn("TrustedHTTPProxies and TrustedHTTPProxiesCount both are configured. " +
|
||||
log.WithContext(ctx).Warn("TrustedHTTPProxies and TrustedHTTPProxiesCount both are configured. " +
|
||||
"This is not recommended way to extract X-Forwarded-For. Consider using one of these options.")
|
||||
}
|
||||
realipOpts := []realip.Option{
|
||||
@ -206,8 +218,8 @@ var (
|
||||
gRPCOpts := []grpc.ServerOption{
|
||||
grpc.KeepaliveEnforcementPolicy(kaep),
|
||||
grpc.KeepaliveParams(kasp),
|
||||
grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...)),
|
||||
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...)),
|
||||
grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...), unaryInterceptor),
|
||||
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor),
|
||||
}
|
||||
|
||||
var certManager *autocert.Manager
|
||||
@ -224,7 +236,7 @@ var (
|
||||
} else if config.HttpConfig.CertFile != "" && config.HttpConfig.CertKey != "" {
|
||||
tlsConfig, err = loadTLSConfig(config.HttpConfig.CertFile, config.HttpConfig.CertKey)
|
||||
if err != nil {
|
||||
log.Errorf("cannot load TLS credentials: %v", err)
|
||||
log.WithContext(ctx).Errorf("cannot load TLS credentials: %v", err)
|
||||
return err
|
||||
}
|
||||
transportCredentials := credentials.NewTLS(tlsConfig)
|
||||
@ -233,6 +245,7 @@ var (
|
||||
}
|
||||
|
||||
jwtValidator, err := jwtclaims.NewJWTValidator(
|
||||
ctx,
|
||||
config.HttpConfig.AuthIssuer,
|
||||
config.GetAuthAudiences(),
|
||||
config.HttpConfig.AuthKeysLocation,
|
||||
@ -249,26 +262,24 @@ var (
|
||||
KeysLocation: config.HttpConfig.AuthKeysLocation,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
defer cancel()
|
||||
httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed creating HTTP API handler: %v", err)
|
||||
}
|
||||
|
||||
ephemeralManager := server.NewEphemeralManager(store, accountManager)
|
||||
ephemeralManager.LoadInitialPeers()
|
||||
ephemeralManager.LoadInitialPeers(ctx)
|
||||
|
||||
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
|
||||
srv, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, appMetrics, ephemeralManager)
|
||||
srv, err := server.NewServer(ctx, config, accountManager, peersUpdateManager, turnManager, appMetrics, ephemeralManager)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed creating gRPC API handler: %v", err)
|
||||
}
|
||||
mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv)
|
||||
|
||||
installationID, err := getInstallationID(store)
|
||||
installationID, err := getInstallationID(ctx, store)
|
||||
if err != nil {
|
||||
log.Errorf("cannot load TLS credentials: %v", err)
|
||||
log.WithContext(ctx).Errorf("cannot load TLS credentials: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
@ -278,18 +289,18 @@ var (
|
||||
idpManager = config.IdpManagerConfig.ManagerType
|
||||
}
|
||||
metricsWorker := metrics.NewWorker(ctx, installationID, store, peersUpdateManager, idpManager)
|
||||
go metricsWorker.Run()
|
||||
go metricsWorker.Run(ctx)
|
||||
}
|
||||
|
||||
var compatListener net.Listener
|
||||
if mgmtPort != ManagementLegacyPort {
|
||||
// The Management gRPC server was running on port 33073 previously. Old agents that are already connected to it
|
||||
// are using port 33073. For compatibility purposes we keep running a 2nd gRPC server on port 33073.
|
||||
compatListener, err = serveGRPC(gRPCAPIHandler, ManagementLegacyPort)
|
||||
compatListener, err = serveGRPC(ctx, gRPCAPIHandler, ManagementLegacyPort)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String())
|
||||
log.WithContext(ctx).Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String())
|
||||
}
|
||||
|
||||
rootHandler := handlerFunc(gRPCAPIHandler, httpAPIHandler)
|
||||
@ -306,8 +317,8 @@ var (
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed creating TLS listener on port %d: %v", mgmtPort, err)
|
||||
}
|
||||
log.Infof("running HTTP server (LetsEncrypt challenge handler): %s", cml.Addr().String())
|
||||
serveHTTP(cml, certManager.HTTPHandler(nil))
|
||||
log.WithContext(ctx).Infof("running HTTP server (LetsEncrypt challenge handler): %s", cml.Addr().String())
|
||||
serveHTTP(ctx, cml, certManager.HTTPHandler(nil))
|
||||
}
|
||||
} else if tlsConfig != nil {
|
||||
listener, err = tls.Listen("tcp", fmt.Sprintf(":%d", mgmtPort), tlsConfig)
|
||||
@ -321,14 +332,14 @@ var (
|
||||
}
|
||||
}
|
||||
|
||||
log.Infof("management server version %s", version.NetbirdVersion())
|
||||
log.Infof("running HTTP server and gRPC server on the same port: %s", listener.Addr().String())
|
||||
serveGRPCWithHTTP(listener, rootHandler, tlsEnabled)
|
||||
log.WithContext(ctx).Infof("management server version %s", version.NetbirdVersion())
|
||||
log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", listener.Addr().String())
|
||||
serveGRPCWithHTTP(ctx, listener, rootHandler, tlsEnabled)
|
||||
|
||||
SetupCloseHandler()
|
||||
|
||||
<-stopCh
|
||||
integratedPeerValidator.Stop()
|
||||
integratedPeerValidator.Stop(ctx)
|
||||
if geo != nil {
|
||||
_ = geo.Stop()
|
||||
}
|
||||
@ -339,39 +350,68 @@ var (
|
||||
_ = certManager.Listener().Close()
|
||||
}
|
||||
gRPCAPIHandler.Stop()
|
||||
_ = store.Close()
|
||||
_ = eventStore.Close()
|
||||
log.Infof("stopped Management Service")
|
||||
_ = store.Close(ctx)
|
||||
_ = eventStore.Close(ctx)
|
||||
log.WithContext(ctx).Infof("stopped Management Service")
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
func notifyStop(msg string) {
|
||||
func unaryInterceptor(
|
||||
ctx context.Context,
|
||||
req interface{},
|
||||
info *grpc.UnaryServerInfo,
|
||||
handler grpc.UnaryHandler,
|
||||
) (interface{}, error) {
|
||||
reqID := uuid.New().String()
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, formatter.ExecutionContextKey, formatter.GRPCSource)
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.RequestIDKey, reqID)
|
||||
return handler(ctx, req)
|
||||
}
|
||||
|
||||
func streamInterceptor(
|
||||
srv interface{},
|
||||
ss grpc.ServerStream,
|
||||
info *grpc.StreamServerInfo,
|
||||
handler grpc.StreamHandler,
|
||||
) error {
|
||||
reqID := uuid.New().String()
|
||||
wrapped := grpcMiddleware.WrapServerStream(ss)
|
||||
//nolint
|
||||
ctx := context.WithValue(ss.Context(), formatter.ExecutionContextKey, formatter.GRPCSource)
|
||||
//nolint
|
||||
wrapped.WrappedContext = context.WithValue(ctx, nbContext.RequestIDKey, reqID)
|
||||
return handler(srv, wrapped)
|
||||
}
|
||||
|
||||
func notifyStop(ctx context.Context, msg string) {
|
||||
select {
|
||||
case stopCh <- 1:
|
||||
log.Error(msg)
|
||||
log.WithContext(ctx).Error(msg)
|
||||
default:
|
||||
// stop has been already called, nothing to report
|
||||
}
|
||||
}
|
||||
|
||||
func getInstallationID(store server.Store) (string, error) {
|
||||
func getInstallationID(ctx context.Context, store server.Store) (string, error) {
|
||||
installationID := store.GetInstallationID()
|
||||
if installationID != "" {
|
||||
return installationID, nil
|
||||
}
|
||||
|
||||
installationID = strings.ToUpper(uuid.New().String())
|
||||
err := store.SaveInstallationID(installationID)
|
||||
err := store.SaveInstallationID(ctx, installationID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return installationID, nil
|
||||
}
|
||||
|
||||
func serveGRPC(grpcServer *grpc.Server, port int) (net.Listener, error) {
|
||||
func serveGRPC(ctx context.Context, grpcServer *grpc.Server, port int) (net.Listener, error) {
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -379,22 +419,22 @@ func serveGRPC(grpcServer *grpc.Server, port int) (net.Listener, error) {
|
||||
go func() {
|
||||
err := grpcServer.Serve(listener)
|
||||
if err != nil {
|
||||
notifyStop(fmt.Sprintf("failed running gRPC server on port %d: %v", port, err))
|
||||
notifyStop(ctx, fmt.Sprintf("failed running gRPC server on port %d: %v", port, err))
|
||||
}
|
||||
}()
|
||||
return listener, nil
|
||||
}
|
||||
|
||||
func serveHTTP(httpListener net.Listener, handler http.Handler) {
|
||||
func serveHTTP(ctx context.Context, httpListener net.Listener, handler http.Handler) {
|
||||
go func() {
|
||||
err := http.Serve(httpListener, handler)
|
||||
if err != nil {
|
||||
notifyStop(fmt.Sprintf("failed running HTTP server: %v", err))
|
||||
notifyStop(ctx, fmt.Sprintf("failed running HTTP server: %v", err))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func serveGRPCWithHTTP(listener net.Listener, handler http.Handler, tlsEnabled bool) {
|
||||
func serveGRPCWithHTTP(ctx context.Context, listener net.Listener, handler http.Handler, tlsEnabled bool) {
|
||||
go func() {
|
||||
var err error
|
||||
if tlsEnabled {
|
||||
@ -411,7 +451,7 @@ func serveGRPCWithHTTP(listener net.Listener, handler http.Handler, tlsEnabled b
|
||||
if err != nil {
|
||||
select {
|
||||
case stopCh <- 1:
|
||||
log.Errorf("failed to serve HTTP and gRPC server: %v", err)
|
||||
log.WithContext(ctx).Errorf("failed to serve HTTP and gRPC server: %v", err)
|
||||
default:
|
||||
// stop has been already called, nothing to report
|
||||
}
|
||||
@ -431,7 +471,7 @@ func handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler) http.Handle
|
||||
})
|
||||
}
|
||||
|
||||
func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) {
|
||||
func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config, error) {
|
||||
loadedConfig := &server.Config{}
|
||||
_, err := util.ReadJson(mgmtConfigPath, loadedConfig)
|
||||
if err != nil {
|
||||
@ -452,26 +492,26 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) {
|
||||
oidcEndpoint := loadedConfig.HttpConfig.OIDCConfigEndpoint
|
||||
if oidcEndpoint != "" {
|
||||
// if OIDCConfigEndpoint is specified, we can load DeviceAuthEndpoint and TokenEndpoint automatically
|
||||
log.Infof("loading OIDC configuration from the provided IDP configuration endpoint %s", oidcEndpoint)
|
||||
oidcConfig, err := fetchOIDCConfig(oidcEndpoint)
|
||||
log.WithContext(ctx).Infof("loading OIDC configuration from the provided IDP configuration endpoint %s", oidcEndpoint)
|
||||
oidcConfig, err := fetchOIDCConfig(ctx, oidcEndpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.Infof("loaded OIDC configuration from the provided IDP configuration endpoint: %s", oidcEndpoint)
|
||||
log.WithContext(ctx).Infof("loaded OIDC configuration from the provided IDP configuration endpoint: %s", oidcEndpoint)
|
||||
|
||||
log.Infof("overriding HttpConfig.AuthIssuer with a new value %s, previously configured value: %s",
|
||||
log.WithContext(ctx).Infof("overriding HttpConfig.AuthIssuer with a new value %s, previously configured value: %s",
|
||||
oidcConfig.Issuer, loadedConfig.HttpConfig.AuthIssuer)
|
||||
loadedConfig.HttpConfig.AuthIssuer = oidcConfig.Issuer
|
||||
|
||||
log.Infof("overriding HttpConfig.AuthKeysLocation (JWT certs) with a new value %s, previously configured value: %s",
|
||||
log.WithContext(ctx).Infof("overriding HttpConfig.AuthKeysLocation (JWT certs) with a new value %s, previously configured value: %s",
|
||||
oidcConfig.JwksURI, loadedConfig.HttpConfig.AuthKeysLocation)
|
||||
loadedConfig.HttpConfig.AuthKeysLocation = oidcConfig.JwksURI
|
||||
|
||||
if !(loadedConfig.DeviceAuthorizationFlow == nil || strings.ToLower(loadedConfig.DeviceAuthorizationFlow.Provider) == string(server.NONE)) {
|
||||
log.Infof("overriding DeviceAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s",
|
||||
log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s",
|
||||
oidcConfig.TokenEndpoint, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint)
|
||||
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint = oidcConfig.TokenEndpoint
|
||||
log.Infof("overriding DeviceAuthorizationFlow.DeviceAuthEndpoint with a new value: %s, previously configured value: %s",
|
||||
log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.DeviceAuthEndpoint with a new value: %s, previously configured value: %s",
|
||||
oidcConfig.DeviceAuthEndpoint, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint)
|
||||
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint = oidcConfig.DeviceAuthEndpoint
|
||||
|
||||
@ -479,7 +519,7 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.Infof("overriding DeviceAuthorizationFlow.ProviderConfig.Domain with a new value: %s, previously configured value: %s",
|
||||
log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.ProviderConfig.Domain with a new value: %s, previously configured value: %s",
|
||||
u.Host, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Domain)
|
||||
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Domain = u.Host
|
||||
|
||||
@ -489,10 +529,10 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) {
|
||||
}
|
||||
|
||||
if loadedConfig.PKCEAuthorizationFlow != nil {
|
||||
log.Infof("overriding PKCEAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s",
|
||||
log.WithContext(ctx).Infof("overriding PKCEAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s",
|
||||
oidcConfig.TokenEndpoint, loadedConfig.PKCEAuthorizationFlow.ProviderConfig.TokenEndpoint)
|
||||
loadedConfig.PKCEAuthorizationFlow.ProviderConfig.TokenEndpoint = oidcConfig.TokenEndpoint
|
||||
log.Infof("overriding PKCEAuthorizationFlow.AuthorizationEndpoint with a new value: %s, previously configured value: %s",
|
||||
log.WithContext(ctx).Infof("overriding PKCEAuthorizationFlow.AuthorizationEndpoint with a new value: %s, previously configured value: %s",
|
||||
oidcConfig.AuthorizationEndpoint, loadedConfig.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint)
|
||||
loadedConfig.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint = oidcConfig.AuthorizationEndpoint
|
||||
}
|
||||
@ -501,8 +541,8 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) {
|
||||
return loadedConfig, err
|
||||
}
|
||||
|
||||
func updateMgmtConfig(path string, config *server.Config) error {
|
||||
return util.DirectWriteJson(path, config)
|
||||
func updateMgmtConfig(ctx context.Context, path string, config *server.Config) error {
|
||||
return util.DirectWriteJson(ctx, path, config)
|
||||
}
|
||||
|
||||
// OIDCConfigResponse used for parsing OIDC config response
|
||||
@ -515,7 +555,7 @@ type OIDCConfigResponse struct {
|
||||
}
|
||||
|
||||
// fetchOIDCConfig fetches OIDC configuration from the IDP
|
||||
func fetchOIDCConfig(oidcEndpoint string) (OIDCConfigResponse, error) {
|
||||
func fetchOIDCConfig(ctx context.Context, oidcEndpoint string) (OIDCConfigResponse, error) {
|
||||
res, err := http.Get(oidcEndpoint)
|
||||
if err != nil {
|
||||
return OIDCConfigResponse{}, fmt.Errorf("failed fetching OIDC configuration from endpoint %s %v", oidcEndpoint, err)
|
||||
@ -524,7 +564,7 @@ func fetchOIDCConfig(oidcEndpoint string) (OIDCConfigResponse, error) {
|
||||
defer func() {
|
||||
err := res.Body.Close()
|
||||
if err != nil {
|
||||
log.Debugf("failed closing response body %v", err)
|
||||
log.WithContext(ctx).Debugf("failed closing response body %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
|
@ -1,13 +1,16 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var shortUp = "Migrate JSON file store to SQLite store. Please make a backup of the JSON file before running this command."
|
||||
@ -26,10 +29,13 @@ var upCmd = &cobra.Command{
|
||||
return fmt.Errorf("failed initializing log %v", err)
|
||||
}
|
||||
|
||||
if err := server.MigrateFileStoreToSqlite(mgmtDataDir); err != nil {
|
||||
//nolint
|
||||
ctx := context.WithValue(cmd.Context(), formatter.ExecutionContextKey, formatter.SystemSource)
|
||||
|
||||
if err := server.MigrateFileStoreToSqlite(ctx, mgmtDataDir); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Info("Migration finished successfully")
|
||||
log.WithContext(ctx).Info("Migration finished successfully")
|
||||
|
||||
return nil
|
||||
},
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
b64 "encoding/base64"
|
||||
"encoding/json"
|
||||
@ -29,11 +30,11 @@ import (
|
||||
type MocIntegratedValidator struct {
|
||||
}
|
||||
|
||||
func (a MocIntegratedValidator) ValidateExtraSettings(newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
|
||||
func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a MocIntegratedValidator) ValidatePeer(update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) {
|
||||
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) {
|
||||
return update, nil
|
||||
}
|
||||
func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) {
|
||||
@ -44,15 +45,15 @@ func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[s
|
||||
return validatedPeers, nil
|
||||
}
|
||||
|
||||
func (MocIntegratedValidator) PreparePeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer {
|
||||
func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer {
|
||||
return peer
|
||||
}
|
||||
|
||||
func (MocIntegratedValidator) IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) {
|
||||
func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) {
|
||||
return false, false, nil
|
||||
}
|
||||
|
||||
func (MocIntegratedValidator) PeerDeleted(_, _ string) error {
|
||||
func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -60,7 +61,7 @@ func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)
|
||||
|
||||
}
|
||||
|
||||
func (MocIntegratedValidator) Stop() {
|
||||
func (MocIntegratedValidator) Stop(_ context.Context) {
|
||||
}
|
||||
|
||||
func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *Account, userID string) {
|
||||
@ -85,7 +86,7 @@ func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *Ac
|
||||
setupKey = key.Key
|
||||
}
|
||||
|
||||
_, _, _, err := manager.AddPeer(setupKey, userID, peer)
|
||||
_, _, _, err := manager.AddPeer(context.Background(), setupKey, userID, peer)
|
||||
if err != nil {
|
||||
t.Error("expected to add new peer successfully after creating new account, but failed", err)
|
||||
}
|
||||
@ -395,7 +396,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, testCase := range tt {
|
||||
account := newAccountWithId("account-1", userID, "netbird.io")
|
||||
account := newAccountWithId(context.Background(), "account-1", userID, "netbird.io")
|
||||
account.UpdateSettings(&testCase.accountSettings)
|
||||
account.Network = network
|
||||
account.Peers = testCase.peers
|
||||
@ -409,7 +410,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
||||
validatedPeers[p] = struct{}{}
|
||||
}
|
||||
|
||||
networkMap := account.GetPeerNetworkMap(testCase.peerID, "netbird.io", validatedPeers)
|
||||
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, "netbird.io", validatedPeers)
|
||||
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
|
||||
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
|
||||
}
|
||||
@ -419,7 +420,7 @@ func TestNewAccount(t *testing.T) {
|
||||
domain := "netbird.io"
|
||||
userId := "account_creator"
|
||||
accountID := "account_id"
|
||||
account := newAccountWithId(accountID, userId, domain)
|
||||
account := newAccountWithId(context.Background(), accountID, userId, domain)
|
||||
verifyNewAccountHasDefaultFields(t, account, userId, domain, []string{userId})
|
||||
}
|
||||
|
||||
@ -430,7 +431,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
account, err := manager.GetOrCreateAccountByUser(userID, "")
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -439,7 +440,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
account, err = manager.Store.GetAccountByUser(userID)
|
||||
account, err = manager.Store.GetAccountByUser(context.Background(), userID)
|
||||
if err != nil {
|
||||
t.Errorf("expected to get existing account after creation, no account was found for a user %s", userID)
|
||||
return
|
||||
@ -630,11 +631,11 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
initAccount, err := manager.GetAccountByUserOrAccountID(testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
|
||||
initAccount, err := manager.GetAccountByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
|
||||
require.NoError(t, err, "create init user failed")
|
||||
|
||||
if testCase.inputUpdateAttrs {
|
||||
err = manager.updateAccountDomainAttributes(initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
|
||||
err = manager.updateAccountDomainAttributes(context.Background(), initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
|
||||
require.NoError(t, err, "update init user failed")
|
||||
}
|
||||
|
||||
@ -642,7 +643,7 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
|
||||
testCase.inputClaims.AccountId = initAccount.Id
|
||||
}
|
||||
|
||||
account, _, err := manager.GetAccountFromToken(testCase.inputClaims)
|
||||
account, _, err := manager.GetAccountFromToken(context.Background(), testCase.inputClaims)
|
||||
require.NoError(t, err, "support function failed")
|
||||
verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers)
|
||||
verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy)
|
||||
@ -661,12 +662,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
||||
userId := "user-id"
|
||||
domain := "test.domain"
|
||||
|
||||
initAccount := newAccountWithId("", userId, domain)
|
||||
initAccount := newAccountWithId(context.Background(), "", userId, domain)
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID := initAccount.Id
|
||||
acc, err := manager.GetAccountByUserOrAccountID(userId, accountID, domain)
|
||||
acc, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, accountID, domain)
|
||||
require.NoError(t, err, "create init user failed")
|
||||
// as initAccount was created without account id we have to take the id after account initialization
|
||||
// that happens inside the GetAccountByUserOrAccountID where the id is getting generated
|
||||
@ -682,18 +683,18 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run("JWT groups disabled", func(t *testing.T) {
|
||||
account, _, err := manager.GetAccountFromToken(claims)
|
||||
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
|
||||
require.NoError(t, err, "get account by token failed")
|
||||
require.Len(t, account.Groups, 1, "only ALL group should exists")
|
||||
})
|
||||
|
||||
t.Run("JWT groups enabled without claim name", func(t *testing.T) {
|
||||
initAccount.Settings.JWTGroupsEnabled = true
|
||||
err := manager.Store.SaveAccount(initAccount)
|
||||
err := manager.Store.SaveAccount(context.Background(), initAccount)
|
||||
require.NoError(t, err, "save account failed")
|
||||
require.Len(t, manager.Store.GetAllAccounts(), 1, "only one account should exist")
|
||||
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
|
||||
|
||||
account, _, err := manager.GetAccountFromToken(claims)
|
||||
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
|
||||
require.NoError(t, err, "get account by token failed")
|
||||
require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT")
|
||||
})
|
||||
@ -701,11 +702,11 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
||||
t.Run("JWT groups enabled", func(t *testing.T) {
|
||||
initAccount.Settings.JWTGroupsEnabled = true
|
||||
initAccount.Settings.JWTGroupsClaimName = "idp-groups"
|
||||
err := manager.Store.SaveAccount(initAccount)
|
||||
err := manager.Store.SaveAccount(context.Background(), initAccount)
|
||||
require.NoError(t, err, "save account failed")
|
||||
require.Len(t, manager.Store.GetAllAccounts(), 1, "only one account should exist")
|
||||
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
|
||||
|
||||
account, _, err := manager.GetAccountFromToken(claims)
|
||||
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
|
||||
require.NoError(t, err, "get account by token failed")
|
||||
require.Len(t, account.Groups, 3, "groups should be added to the account")
|
||||
|
||||
@ -728,7 +729,7 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
||||
|
||||
func TestAccountManager_GetAccountFromPAT(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId("account_id", "testuser", "")
|
||||
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
|
||||
|
||||
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
|
||||
hashedToken := sha256.Sum256([]byte(token))
|
||||
@ -742,7 +743,7 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
err := store.SaveAccount(account)
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
@ -751,7 +752,7 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) {
|
||||
Store: store,
|
||||
}
|
||||
|
||||
account, user, pat, err := am.GetAccountFromPAT(token)
|
||||
account, user, pat, err := am.GetAccountFromPAT(context.Background(), token)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when getting Account from PAT: %s", err)
|
||||
}
|
||||
@ -763,7 +764,7 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) {
|
||||
|
||||
func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId("account_id", "testuser", "")
|
||||
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
|
||||
|
||||
token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W"
|
||||
hashedToken := sha256.Sum256([]byte(token))
|
||||
@ -778,7 +779,7 @@ func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
err := store.SaveAccount(account)
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
@ -787,12 +788,12 @@ func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
|
||||
Store: store,
|
||||
}
|
||||
|
||||
err = am.MarkPATUsed("tokenId")
|
||||
err = am.MarkPATUsed(context.Background(), "tokenId")
|
||||
if err != nil {
|
||||
t.Fatalf("Error when marking PAT used: %s", err)
|
||||
}
|
||||
|
||||
account, err = am.Store.GetAccount("account_id")
|
||||
account, err = am.Store.GetAccount(context.Background(), "account_id")
|
||||
if err != nil {
|
||||
t.Fatalf("Error when getting account: %s", err)
|
||||
}
|
||||
@ -807,7 +808,7 @@ func TestAccountManager_PrivateAccount(t *testing.T) {
|
||||
}
|
||||
|
||||
userId := "test_user"
|
||||
account, err := manager.GetOrCreateAccountByUser(userId, "")
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -815,7 +816,7 @@ func TestAccountManager_PrivateAccount(t *testing.T) {
|
||||
t.Fatalf("expected to create an account for a user %s", userId)
|
||||
}
|
||||
|
||||
account, err = manager.Store.GetAccountByUser(userId)
|
||||
account, err = manager.Store.GetAccountByUser(context.Background(), userId)
|
||||
if err != nil {
|
||||
t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId)
|
||||
}
|
||||
@ -834,7 +835,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
|
||||
|
||||
userId := "test_user"
|
||||
domain := "hotmail.com"
|
||||
account, err := manager.GetOrCreateAccountByUser(userId, domain)
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, domain)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -848,7 +849,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
|
||||
|
||||
domain = "gmail.com"
|
||||
|
||||
account, err = manager.GetOrCreateAccountByUser(userId, domain)
|
||||
account, err = manager.GetOrCreateAccountByUser(context.Background(), userId, domain)
|
||||
if err != nil {
|
||||
t.Fatalf("got the following error while retrieving existing acc: %v", err)
|
||||
}
|
||||
@ -871,7 +872,7 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
|
||||
|
||||
userId := "test_user"
|
||||
|
||||
account, err := manager.GetAccountByUserOrAccountID(userId, "", "")
|
||||
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, "", "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -880,20 +881,20 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = manager.GetAccountByUserOrAccountID("", account.Id, "")
|
||||
_, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
|
||||
if err != nil {
|
||||
t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", account.Id)
|
||||
}
|
||||
|
||||
_, err = manager.GetAccountByUserOrAccountID("", "", "")
|
||||
_, err = manager.GetAccountByUserOrAccountID(context.Background(), "", "", "")
|
||||
if err == nil {
|
||||
t.Errorf("expected an error when user and account IDs are empty")
|
||||
}
|
||||
}
|
||||
|
||||
func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*Account, error) {
|
||||
account := newAccountWithId(accountID, userID, domain)
|
||||
err := am.Store.SaveAccount(account)
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain)
|
||||
err := am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -915,7 +916,7 @@ func TestAccountManager_GetAccount(t *testing.T) {
|
||||
}
|
||||
|
||||
// AddAccount has been already tested so we can assume it is correct and compare results
|
||||
getAccount, err := manager.Store.GetAccount(account.Id)
|
||||
getAccount, err := manager.Store.GetAccount(context.Background(), account.Id)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@ -952,12 +953,12 @@ func TestAccountManager_DeleteAccount(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = manager.DeleteAccount(account.Id, userId)
|
||||
err = manager.DeleteAccount(context.Background(), account.Id, userId)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
getAccount, err := manager.Store.GetAccount(account.Id)
|
||||
getAccount, err := manager.Store.GetAccount(context.Background(), account.Id)
|
||||
if err == nil {
|
||||
t.Fatal(fmt.Errorf("expected to get an error when trying to get deleted account, got %v", getAccount))
|
||||
}
|
||||
@ -978,7 +979,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
|
||||
|
||||
serial := account.Network.CurrentSerial() // should be 0
|
||||
|
||||
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
|
||||
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
|
||||
if err != nil {
|
||||
t.Fatal("error creating setup key")
|
||||
return
|
||||
@ -997,7 +998,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
expectedSetupKey := setupKey.Key
|
||||
|
||||
peer, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
|
||||
peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
|
||||
Key: expectedPeerKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||
})
|
||||
@ -1006,7 +1007,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
account, err = manager.Store.GetAccount(account.Id)
|
||||
account, err = manager.Store.GetAccount(context.Background(), account.Id)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@ -1045,7 +1046,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
account, err := manager.GetOrCreateAccountByUser(userID, "netbird.cloud")
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "netbird.cloud")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -1065,7 +1066,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
expectedUserID := userID
|
||||
|
||||
peer, _, _, err := manager.AddPeer("", userID, &nbpeer.Peer{
|
||||
peer, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
|
||||
Key: expectedPeerKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||
})
|
||||
@ -1074,7 +1075,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
account, err = manager.Store.GetAccount(account.Id)
|
||||
account, err = manager.Store.GetAccount(context.Background(), account.Id)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@ -1121,7 +1122,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
|
||||
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
|
||||
if err != nil {
|
||||
t.Fatal("error creating setup key")
|
||||
return
|
||||
@ -1140,7 +1141,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
|
||||
}
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
|
||||
peer, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
|
||||
peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
|
||||
Key: expectedPeerKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||
})
|
||||
@ -1156,14 +1157,14 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
|
||||
peer2 := getPeer()
|
||||
peer3 := getPeer()
|
||||
|
||||
account, err = manager.Store.GetAccount(account.Id)
|
||||
account, err = manager.Store.GetAccount(context.Background(), account.Id)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(peer1.ID)
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
group := group.Group{
|
||||
ID: "group-id",
|
||||
@ -1197,7 +1198,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
if err := manager.SaveGroup(account.Id, userID, &group); err != nil {
|
||||
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil {
|
||||
t.Errorf("save group: %v", err)
|
||||
return
|
||||
}
|
||||
@ -1217,7 +1218,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
if err := manager.DeletePolicy(account.Id, account.Policies[0].ID, userID); err != nil {
|
||||
if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
|
||||
t.Errorf("delete default rule: %v", err)
|
||||
return
|
||||
}
|
||||
@ -1237,7 +1238,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
if err := manager.SavePolicy(account.Id, userID, &policy); err != nil {
|
||||
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy); err != nil {
|
||||
t.Errorf("delete default rule: %v", err)
|
||||
return
|
||||
}
|
||||
@ -1256,7 +1257,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
if err := manager.DeletePeer(account.Id, peer3.ID, userID); err != nil {
|
||||
if err := manager.DeletePeer(context.Background(), account.Id, peer3.ID, userID); err != nil {
|
||||
t.Errorf("delete peer: %v", err)
|
||||
return
|
||||
}
|
||||
@ -1277,9 +1278,9 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
|
||||
}()
|
||||
|
||||
// clean policy is pre requirement for delete group
|
||||
_ = manager.DeletePolicy(account.Id, policy.ID, userID)
|
||||
_ = manager.DeletePolicy(context.Background(), account.Id, policy.ID, userID)
|
||||
|
||||
if err := manager.DeleteGroup(account.Id, "", group.ID); err != nil {
|
||||
if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil {
|
||||
t.Errorf("delete group: %v", err)
|
||||
return
|
||||
}
|
||||
@ -1301,7 +1302,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
|
||||
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
|
||||
if err != nil {
|
||||
t.Fatal("error creating setup key")
|
||||
return
|
||||
@ -1315,7 +1316,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
|
||||
|
||||
peerKey := key.PublicKey().String()
|
||||
|
||||
peer, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
|
||||
peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
|
||||
Key: peerKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: peerKey},
|
||||
})
|
||||
@ -1324,12 +1325,12 @@ func TestAccountManager_DeletePeer(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
err = manager.DeletePeer(account.Id, peerKey, userID)
|
||||
err = manager.DeletePeer(context.Background(), account.Id, peerKey, userID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
account, err = manager.Store.GetAccount(account.Id)
|
||||
account, err = manager.Store.GetAccount(context.Background(), account.Id)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@ -1357,7 +1358,7 @@ func getEvent(t *testing.T, accountID string, manager AccountManager, eventType
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("no PeerAddedWithSetupKey event was generated")
|
||||
default:
|
||||
events, err := manager.GetEvents(accountID, userID)
|
||||
events, err := manager.GetEvents(context.Background(), accountID, userID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -1389,7 +1390,7 @@ func TestGetUsersFromAccount(t *testing.T) {
|
||||
account.Users[user.Id] = user
|
||||
}
|
||||
|
||||
userInfos, err := manager.GetUsersFromAccount(accountId, "1")
|
||||
userInfos, err := manager.GetUsersFromAccount(context.Background(), accountId, "1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -1500,7 +1501,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
routes := account.getRoutesToSync("peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}})
|
||||
routes := account.getRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}})
|
||||
|
||||
assert.Len(t, routes, 2)
|
||||
routeIDs := make(map[route.ID]struct{}, 2)
|
||||
@ -1510,7 +1511,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
|
||||
assert.Contains(t, routeIDs, route.ID("route-2"))
|
||||
assert.Contains(t, routeIDs, route.ID("route-3"))
|
||||
|
||||
emptyRoutes := account.getRoutesToSync("peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}})
|
||||
emptyRoutes := account.getRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}})
|
||||
|
||||
assert.Len(t, emptyRoutes, 0)
|
||||
}
|
||||
@ -1645,7 +1646,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
assert.NotNil(t, account.Settings)
|
||||
@ -1657,23 +1658,23 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
_, err = manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||
_, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
peer, _, _, err := manager.AddPeer("", userID, &nbpeer.Peer{
|
||||
peer, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
|
||||
Key: key.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
|
||||
LoginExpirationEnabled: true,
|
||||
})
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
|
||||
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account)
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
account, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
|
||||
account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
})
|
||||
@ -1682,10 +1683,10 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
manager.peerLoginExpiry = &MockScheduler{
|
||||
CancelFunc: func(IDs []string) {
|
||||
CancelFunc: func(ctx context.Context, IDs []string) {
|
||||
wg.Done()
|
||||
},
|
||||
ScheduleFunc: func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
|
||||
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
|
||||
wg.Done()
|
||||
},
|
||||
}
|
||||
@ -1693,11 +1694,11 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
// disable expiration first
|
||||
update := peer.Copy()
|
||||
update.LoginExpirationEnabled = false
|
||||
_, err = manager.UpdatePeer(account.Id, userID, update)
|
||||
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, update)
|
||||
require.NoError(t, err, "unable to update peer")
|
||||
// enabling expiration should trigger the routine
|
||||
update.LoginExpirationEnabled = true
|
||||
_, err = manager.UpdatePeer(account.Id, userID, update)
|
||||
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, update)
|
||||
require.NoError(t, err, "unable to update peer")
|
||||
|
||||
failed := waitTimeout(wg, time.Second)
|
||||
@ -1710,18 +1711,18 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
_, _, _, err = manager.AddPeer("", userID, &nbpeer.Peer{
|
||||
_, _, _, err = manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
|
||||
Key: key.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
|
||||
LoginExpirationEnabled: true,
|
||||
})
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
_, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
})
|
||||
@ -1730,18 +1731,18 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
manager.peerLoginExpiry = &MockScheduler{
|
||||
CancelFunc: func(IDs []string) {
|
||||
CancelFunc: func(ctx context.Context, IDs []string) {
|
||||
wg.Done()
|
||||
},
|
||||
ScheduleFunc: func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
|
||||
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
|
||||
wg.Done()
|
||||
},
|
||||
}
|
||||
|
||||
account, err = manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||
account, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
// when we mark peer as connected, the peer login expiration routine should trigger
|
||||
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account)
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
failed := waitTimeout(wg, time.Second)
|
||||
@ -1754,35 +1755,35 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
_, err = manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||
_, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
_, _, _, err = manager.AddPeer("", userID, &nbpeer.Peer{
|
||||
_, _, _, err = manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
|
||||
Key: key.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
|
||||
LoginExpirationEnabled: true,
|
||||
})
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
|
||||
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account)
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
manager.peerLoginExpiry = &MockScheduler{
|
||||
CancelFunc: func(IDs []string) {
|
||||
CancelFunc: func(ctx context.Context, IDs []string) {
|
||||
wg.Done()
|
||||
},
|
||||
ScheduleFunc: func(in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
|
||||
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
|
||||
wg.Done()
|
||||
},
|
||||
}
|
||||
// enabling PeerLoginExpirationEnabled should trigger the expiration job
|
||||
account, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
|
||||
account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
})
|
||||
@ -1795,7 +1796,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
wg.Add(1)
|
||||
|
||||
// disabling PeerLoginExpirationEnabled should trigger cancel
|
||||
_, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
})
|
||||
@ -1810,10 +1811,10 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
account, err := manager.GetAccountByUserOrAccountID(userID, "", "")
|
||||
account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
updated, err := manager.UpdateAccountSettings(account.Id, userID, &Settings{
|
||||
updated, err := manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
})
|
||||
@ -1821,19 +1822,19 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
|
||||
assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
|
||||
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)
|
||||
|
||||
account, err = manager.GetAccountByUserOrAccountID("", account.Id, "")
|
||||
account, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
|
||||
require.NoError(t, err, "unable to get account by ID")
|
||||
|
||||
assert.False(t, account.Settings.PeerLoginExpirationEnabled)
|
||||
assert.Equal(t, account.Settings.PeerLoginExpiration, time.Hour)
|
||||
|
||||
_, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
||||
PeerLoginExpiration: time.Second,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
})
|
||||
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour")
|
||||
|
||||
_, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
||||
PeerLoginExpiration: time.Hour * 24 * 181,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
})
|
||||
@ -2294,7 +2295,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
}
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
|
||||
manager, err := BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{})
|
||||
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -2305,7 +2306,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
func createStore(t *testing.T) (Store, error) {
|
||||
t.Helper()
|
||||
dataDir := t.TempDir()
|
||||
store, cleanUp, err := NewTestStoreFromJson(dataDir)
|
||||
store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@ -86,7 +87,7 @@ type Store struct {
|
||||
}
|
||||
|
||||
// NewSQLiteStore creates a new Store with an event table if not exists.
|
||||
func NewSQLiteStore(dataDir string, encryptionKey string) (*Store, error) {
|
||||
func NewSQLiteStore(ctx context.Context, dataDir string, encryptionKey string) (*Store, error) {
|
||||
dbFile := filepath.Join(dataDir, eventSinkDB)
|
||||
db, err := sql.Open("sqlite3", dbFile)
|
||||
if err != nil {
|
||||
@ -111,7 +112,7 @@ func NewSQLiteStore(dataDir string, encryptionKey string) (*Store, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = updateDeletedUsersTable(db)
|
||||
err = updateDeletedUsersTable(ctx, db)
|
||||
if err != nil {
|
||||
_ = db.Close()
|
||||
return nil, err
|
||||
@ -153,7 +154,7 @@ func NewSQLiteStore(dataDir string, encryptionKey string) (*Store, error) {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) {
|
||||
func (store *Store) processResult(ctx context.Context, result *sql.Rows) ([]*activity.Event, error) {
|
||||
events := make([]*activity.Event, 0)
|
||||
var cryptErr error
|
||||
for result.Next() {
|
||||
@ -235,14 +236,14 @@ func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) {
|
||||
}
|
||||
|
||||
if cryptErr != nil {
|
||||
log.Warnf("%s", cryptErr)
|
||||
log.WithContext(ctx).Warnf("%s", cryptErr)
|
||||
}
|
||||
|
||||
return events, nil
|
||||
}
|
||||
|
||||
// Get returns "limit" number of events from index ordered descending or ascending by a timestamp
|
||||
func (store *Store) Get(accountID string, offset, limit int, descending bool) ([]*activity.Event, error) {
|
||||
func (store *Store) Get(ctx context.Context, accountID string, offset, limit int, descending bool) ([]*activity.Event, error) {
|
||||
stmt := store.selectDescStatement
|
||||
if !descending {
|
||||
stmt = store.selectAscStatement
|
||||
@ -254,11 +255,11 @@ func (store *Store) Get(accountID string, offset, limit int, descending bool) ([
|
||||
}
|
||||
|
||||
defer result.Close() //nolint
|
||||
return store.processResult(result)
|
||||
return store.processResult(ctx, result)
|
||||
}
|
||||
|
||||
// Save an event in the SQLite events table end encrypt the "email" element in meta map
|
||||
func (store *Store) Save(event *activity.Event) (*activity.Event, error) {
|
||||
func (store *Store) Save(_ context.Context, event *activity.Event) (*activity.Event, error) {
|
||||
var jsonMeta string
|
||||
meta, err := store.saveDeletedUserEmailAndNameInEncrypted(event)
|
||||
if err != nil {
|
||||
@ -317,15 +318,15 @@ func (store *Store) saveDeletedUserEmailAndNameInEncrypted(event *activity.Event
|
||||
}
|
||||
|
||||
// Close the Store
|
||||
func (store *Store) Close() error {
|
||||
func (store *Store) Close(_ context.Context) error {
|
||||
if store.db != nil {
|
||||
return store.db.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateDeletedUsersTable(db *sql.DB) error {
|
||||
log.Debugf("check deleted_users table version")
|
||||
func updateDeletedUsersTable(ctx context.Context, db *sql.DB) error {
|
||||
log.WithContext(ctx).Debugf("check deleted_users table version")
|
||||
rows, err := db.Query(`PRAGMA table_info(deleted_users);`)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -360,7 +361,7 @@ func updateDeletedUsersTable(db *sql.DB) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debugf("update delted_users table")
|
||||
log.WithContext(ctx).Debugf("update delted_users table")
|
||||
_, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN name TEXT;`)
|
||||
return err
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
@ -13,17 +14,17 @@ import (
|
||||
func TestNewSQLiteStore(t *testing.T) {
|
||||
dataDir := t.TempDir()
|
||||
key, _ := GenerateKey()
|
||||
store, err := NewSQLiteStore(dataDir, key)
|
||||
store, err := NewSQLiteStore(context.Background(), dataDir, key)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
defer store.Close() //nolint
|
||||
defer store.Close(context.Background()) //nolint
|
||||
|
||||
accountID := "account_1"
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
_, err = store.Save(&activity.Event{
|
||||
_, err = store.Save(context.Background(), &activity.Event{
|
||||
Timestamp: time.Now().UTC(),
|
||||
Activity: activity.PeerAddedByUser,
|
||||
InitiatorID: "user_" + fmt.Sprint(i),
|
||||
@ -36,7 +37,7 @@ func TestNewSQLiteStore(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
result, err := store.Get(accountID, 0, 10, false)
|
||||
result, err := store.Get(context.Background(), accountID, 0, 10, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@ -45,7 +46,7 @@ func TestNewSQLiteStore(t *testing.T) {
|
||||
assert.Len(t, result, 10)
|
||||
assert.True(t, result[0].Timestamp.Before(result[len(result)-1].Timestamp))
|
||||
|
||||
result, err = store.Get(accountID, 0, 5, true)
|
||||
result, err = store.Get(context.Background(), accountID, 0, 5, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
|
@ -1,15 +1,18 @@
|
||||
package activity
|
||||
|
||||
import "sync"
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Store provides an interface to store or stream events.
|
||||
type Store interface {
|
||||
// Save an event in the store
|
||||
Save(event *Event) (*Event, error)
|
||||
Save(ctx context.Context, event *Event) (*Event, error)
|
||||
// Get returns "limit" number of events from the "offset" index ordered descending or ascending by a timestamp
|
||||
Get(accountID string, offset, limit int, descending bool) ([]*Event, error)
|
||||
Get(ctx context.Context, accountID string, offset, limit int, descending bool) ([]*Event, error)
|
||||
// Close the sink flushing events if necessary
|
||||
Close() error
|
||||
Close(ctx context.Context) error
|
||||
}
|
||||
|
||||
// InMemoryEventStore implements the Store interface storing data in-memory
|
||||
@ -20,7 +23,7 @@ type InMemoryEventStore struct {
|
||||
}
|
||||
|
||||
// Save sets the Event.ID to 1
|
||||
func (store *InMemoryEventStore) Save(event *Event) (*Event, error) {
|
||||
func (store *InMemoryEventStore) Save(_ context.Context, event *Event) (*Event, error) {
|
||||
store.mu.Lock()
|
||||
defer store.mu.Unlock()
|
||||
if store.events == nil {
|
||||
@ -33,7 +36,7 @@ func (store *InMemoryEventStore) Save(event *Event) (*Event, error) {
|
||||
}
|
||||
|
||||
// Get returns a list of ALL events that belong to the given accountID without taking offset, limit and order into consideration
|
||||
func (store *InMemoryEventStore) Get(accountID string, offset, limit int, descending bool) ([]*Event, error) {
|
||||
func (store *InMemoryEventStore) Get(_ context.Context, accountID string, offset, limit int, descending bool) ([]*Event, error) {
|
||||
store.mu.Lock()
|
||||
defer store.mu.Unlock()
|
||||
events := make([]*Event, 0)
|
||||
@ -46,7 +49,7 @@ func (store *InMemoryEventStore) Get(accountID string, offset, limit int, descen
|
||||
}
|
||||
|
||||
// Close cleans up the event list
|
||||
func (store *InMemoryEventStore) Close() error {
|
||||
func (store *InMemoryEventStore) Close(_ context.Context) error {
|
||||
store.mu.Lock()
|
||||
defer store.mu.Unlock()
|
||||
store.events = make([]*Event, 0)
|
||||
|
8
management/server/context/keys.go
Normal file
8
management/server/context/keys.go
Normal file
@ -0,0 +1,8 @@
|
||||
package context
|
||||
|
||||
const (
|
||||
RequestIDKey = "requestID"
|
||||
AccountIDKey = "accountID"
|
||||
UserIDKey = "userID"
|
||||
PeerIDKey = "peerID"
|
||||
)
|
@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
@ -34,11 +35,11 @@ func (d DNSSettings) Copy() DNSSettings {
|
||||
}
|
||||
|
||||
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
|
||||
func (am *DefaultAccountManager) GetDNSSettings(accountID string, userID string) (*DNSSettings, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -56,11 +57,11 @@ func (am *DefaultAccountManager) GetDNSSettings(accountID string, userID string)
|
||||
}
|
||||
|
||||
// SaveDNSSettings validates a user role and updates the account's DNS settings
|
||||
func (am *DefaultAccountManager) SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -89,7 +90,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(accountID string, userID string
|
||||
account.DNSSettings = dnsSettingsToSave.Copy()
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(account); err != nil {
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -97,17 +98,17 @@ func (am *DefaultAccountManager) SaveDNSSettings(accountID string, userID string
|
||||
for _, id := range addedGroups {
|
||||
group := account.GetGroup(id)
|
||||
meta := map[string]any{"group": group.Name, "group_id": group.ID}
|
||||
am.StoreEvent(userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
|
||||
}
|
||||
|
||||
removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
|
||||
for _, id := range removedGroups {
|
||||
group := account.GetGroup(id)
|
||||
meta := map[string]any{"group": group.Name, "group_id": group.ID}
|
||||
am.StoreEvent(userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
|
||||
}
|
||||
|
||||
am.updateAccountPeers(account)
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -149,9 +150,9 @@ func toProtocolDNSConfig(update nbdns.Config) *proto.DNSConfig {
|
||||
return protoUpdate
|
||||
}
|
||||
|
||||
func getPeersCustomZone(account *Account, dnsDomain string) nbdns.CustomZone {
|
||||
func getPeersCustomZone(ctx context.Context, account *Account, dnsDomain string) nbdns.CustomZone {
|
||||
if dnsDomain == "" {
|
||||
log.Errorf("no dns domain is set, returning empty zone")
|
||||
log.WithContext(ctx).Errorf("no dns domain is set, returning empty zone")
|
||||
return nbdns.CustomZone{}
|
||||
}
|
||||
|
||||
@ -161,7 +162,7 @@ func getPeersCustomZone(account *Account, dnsDomain string) nbdns.CustomZone {
|
||||
|
||||
for _, peer := range account.Peers {
|
||||
if peer.DNSLabel == "" {
|
||||
log.Errorf("found a peer with empty dns label. It was probably caused by a invalid character in its name. Peer Name: %s", peer.Name)
|
||||
log.WithContext(ctx).Errorf("found a peer with empty dns label. It was probably caused by a invalid character in its name. Peer Name: %s", peer.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
@ -210,14 +211,14 @@ func peerIsNameserver(peer *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func addPeerLabelsToAccount(account *Account, peerLabels lookupMap) {
|
||||
func addPeerLabelsToAccount(ctx context.Context, account *Account, peerLabels lookupMap) {
|
||||
for _, peer := range account.Peers {
|
||||
label, err := getPeerHostLabel(peer.Name, peerLabels)
|
||||
if err != nil {
|
||||
log.Errorf("got an error while generating a peer host label. Peer name %s, error: %v. Trying with the peer's meta hostname", peer.Name, err)
|
||||
log.WithContext(ctx).Errorf("got an error while generating a peer host label. Peer name %s, error: %v. Trying with the peer's meta hostname", peer.Name, err)
|
||||
label, err = getPeerHostLabel(peer.Meta.Hostname, peerLabels)
|
||||
if err != nil {
|
||||
log.Errorf("got another error while generating a peer host label with hostname. Peer hostname %s, error: %v. Skipping", peer.Meta.Hostname, err)
|
||||
log.WithContext(ctx).Errorf("got another error while generating a peer host label with hostname. Peer hostname %s, error: %v. Skipping", peer.Meta.Hostname, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
@ -35,7 +36,7 @@ func TestGetDNSSettings(t *testing.T) {
|
||||
t.Fatal("failed to init testing account")
|
||||
}
|
||||
|
||||
dnsSettings, err := am.GetDNSSettings(account.Id, dnsAdminUserID)
|
||||
dnsSettings, err := am.GetDNSSettings(context.Background(), account.Id, dnsAdminUserID)
|
||||
if err != nil {
|
||||
t.Fatalf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err)
|
||||
}
|
||||
@ -48,12 +49,12 @@ func TestGetDNSSettings(t *testing.T) {
|
||||
DisabledManagementGroups: []string{group1ID},
|
||||
}
|
||||
|
||||
err = am.Store.SaveAccount(account)
|
||||
err = am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Error("failed to save testing account with new DNS settings")
|
||||
}
|
||||
|
||||
dnsSettings, err = am.GetDNSSettings(account.Id, dnsAdminUserID)
|
||||
dnsSettings, err = am.GetDNSSettings(context.Background(), account.Id, dnsAdminUserID)
|
||||
if err != nil {
|
||||
t.Errorf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err)
|
||||
}
|
||||
@ -62,7 +63,7 @@ func TestGetDNSSettings(t *testing.T) {
|
||||
t.Errorf("DNS settings should have one disabled mgmt group, groups: %s", dnsSettings.DisabledManagementGroups)
|
||||
}
|
||||
|
||||
_, err = am.GetDNSSettings(account.Id, dnsRegularUserID)
|
||||
_, err = am.GetDNSSettings(context.Background(), account.Id, dnsRegularUserID)
|
||||
if err == nil {
|
||||
t.Errorf("An error should be returned when getting the DNS settings with a regular user")
|
||||
}
|
||||
@ -122,7 +123,7 @@ func TestSaveDNSSettings(t *testing.T) {
|
||||
t.Error("failed to init testing account")
|
||||
}
|
||||
|
||||
err = am.SaveDNSSettings(account.Id, testCase.userID, testCase.inputSettings)
|
||||
err = am.SaveDNSSettings(context.Background(), account.Id, testCase.userID, testCase.inputSettings)
|
||||
if err != nil {
|
||||
if testCase.shouldFail {
|
||||
return
|
||||
@ -130,7 +131,7 @@ func TestSaveDNSSettings(t *testing.T) {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
updatedAccount, err := am.Store.GetAccount(account.Id)
|
||||
updatedAccount, err := am.Store.GetAccount(context.Background(), account.Id)
|
||||
if err != nil {
|
||||
t.Errorf("should be able to retrieve updated account, got err: %s", err)
|
||||
}
|
||||
@ -164,7 +165,7 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
|
||||
t.Error("failed to init testing account")
|
||||
}
|
||||
|
||||
newAccountDNSConfig, err := am.GetNetworkMap(peer1.ID)
|
||||
newAccountDNSConfig, err := am.GetNetworkMap(context.Background(), peer1.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, newAccountDNSConfig.DNSConfig.CustomZones, 1, "default DNS config should have one custom zone for peers")
|
||||
require.True(t, newAccountDNSConfig.DNSConfig.ServiceEnable, "default DNS config should have local DNS service enabled")
|
||||
@ -173,14 +174,14 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
|
||||
dnsSettings := account.DNSSettings.Copy()
|
||||
dnsSettings.DisabledManagementGroups = append(dnsSettings.DisabledManagementGroups, dnsGroup1ID)
|
||||
account.DNSSettings = dnsSettings
|
||||
err = am.Store.SaveAccount(account)
|
||||
err = am.Store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedAccountDNSConfig, err := am.GetNetworkMap(peer1.ID)
|
||||
updatedAccountDNSConfig, err := am.GetNetworkMap(context.Background(), peer1.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, updatedAccountDNSConfig.DNSConfig.CustomZones, 0, "updated DNS config should have no custom zone when peer belongs to a disabled group")
|
||||
require.False(t, updatedAccountDNSConfig.DNSConfig.ServiceEnable, "updated DNS config should have local DNS service disabled when peer belongs to a disabled group")
|
||||
peer2AccountDNSConfig, err := am.GetNetworkMap(peer2.ID)
|
||||
peer2AccountDNSConfig, err := am.GetNetworkMap(context.Background(), peer2.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, peer2AccountDNSConfig.DNSConfig.CustomZones, 1, "DNS config should have one custom zone for peers not in the disabled group")
|
||||
require.True(t, peer2AccountDNSConfig.DNSConfig.ServiceEnable, "DNS config should have DNS service enabled for peers not in the disabled group")
|
||||
@ -194,13 +195,13 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
return nil, err
|
||||
}
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{})
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{})
|
||||
}
|
||||
|
||||
func createDNSStore(t *testing.T) (Store, error) {
|
||||
t.Helper()
|
||||
dataDir := t.TempDir()
|
||||
store, cleanUp, err := NewTestStoreFromJson(dataDir)
|
||||
store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -244,28 +245,28 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
|
||||
|
||||
domain := "example.com"
|
||||
|
||||
account := newAccountWithId(dnsAccountID, dnsAdminUserID, domain)
|
||||
account := newAccountWithId(context.Background(), dnsAccountID, dnsAdminUserID, domain)
|
||||
|
||||
account.Users[dnsRegularUserID] = &User{
|
||||
Id: dnsRegularUserID,
|
||||
Role: UserRoleUser,
|
||||
}
|
||||
|
||||
err := am.Store.SaveAccount(account)
|
||||
err := am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
savedPeer1, _, _, err := am.AddPeer("", dnsAdminUserID, peer1)
|
||||
savedPeer1, _, _, err := am.AddPeer(context.Background(), "", dnsAdminUserID, peer1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, _, _, err = am.AddPeer("", dnsAdminUserID, peer2)
|
||||
_, _, _, err = am.AddPeer(context.Background(), "", dnsAdminUserID, peer2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
account, err = am.Store.GetAccount(account.Id)
|
||||
account, err = am.Store.GetAccount(context.Background(), account.Id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -312,10 +313,10 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
|
||||
Groups: []string{allGroup.ID},
|
||||
}
|
||||
|
||||
err = am.Store.SaveAccount(account)
|
||||
err = am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return am.Store.GetAccount(account.Id)
|
||||
return am.Store.GetAccount(context.Background(), account.Id)
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@ -51,13 +52,15 @@ func NewEphemeralManager(store Store, accountManager AccountManager) *EphemeralM
|
||||
// LoadInitialPeers load from the database the ephemeral type of peers and schedule a cleanup procedure to the head
|
||||
// of the linked list (to the most deprecated peer). At the end of cleanup it schedules the next cleanup to the new
|
||||
// head.
|
||||
func (e *EphemeralManager) LoadInitialPeers() {
|
||||
func (e *EphemeralManager) LoadInitialPeers(ctx context.Context) {
|
||||
e.peersLock.Lock()
|
||||
defer e.peersLock.Unlock()
|
||||
|
||||
e.loadEphemeralPeers()
|
||||
e.loadEphemeralPeers(ctx)
|
||||
if e.headPeer != nil {
|
||||
e.timer = time.AfterFunc(ephemeralLifeTime, e.cleanup)
|
||||
e.timer = time.AfterFunc(ephemeralLifeTime, func() {
|
||||
e.cleanup(ctx)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -73,12 +76,12 @@ func (e *EphemeralManager) Stop() {
|
||||
|
||||
// OnPeerConnected remove the peer from the linked list of ephemeral peers. Because it has been called when the peer
|
||||
// is active the manager will not delete it while it is active.
|
||||
func (e *EphemeralManager) OnPeerConnected(peer *nbpeer.Peer) {
|
||||
func (e *EphemeralManager) OnPeerConnected(ctx context.Context, peer *nbpeer.Peer) {
|
||||
if !peer.Ephemeral {
|
||||
return
|
||||
}
|
||||
|
||||
log.Tracef("remove peer from ephemeral list: %s", peer.ID)
|
||||
log.WithContext(ctx).Tracef("remove peer from ephemeral list: %s", peer.ID)
|
||||
|
||||
e.peersLock.Lock()
|
||||
defer e.peersLock.Unlock()
|
||||
@ -94,16 +97,16 @@ func (e *EphemeralManager) OnPeerConnected(peer *nbpeer.Peer) {
|
||||
|
||||
// OnPeerDisconnected add the peer to the linked list of ephemeral peers. Because of the peer
|
||||
// is inactive it will be deleted after the ephemeralLifeTime period.
|
||||
func (e *EphemeralManager) OnPeerDisconnected(peer *nbpeer.Peer) {
|
||||
func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.Peer) {
|
||||
if !peer.Ephemeral {
|
||||
return
|
||||
}
|
||||
|
||||
log.Tracef("add peer to ephemeral list: %s", peer.ID)
|
||||
log.WithContext(ctx).Tracef("add peer to ephemeral list: %s", peer.ID)
|
||||
|
||||
a, err := e.store.GetAccountByPeerID(peer.ID)
|
||||
a, err := e.store.GetAccountByPeerID(context.Background(), peer.ID)
|
||||
if err != nil {
|
||||
log.Errorf("failed to add peer to ephemeral list: %s", err)
|
||||
log.WithContext(ctx).Errorf("failed to add peer to ephemeral list: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
@ -116,12 +119,14 @@ func (e *EphemeralManager) OnPeerDisconnected(peer *nbpeer.Peer) {
|
||||
|
||||
e.addPeer(peer.ID, a, newDeadLine())
|
||||
if e.timer == nil {
|
||||
e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), e.cleanup)
|
||||
e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), func() {
|
||||
e.cleanup(ctx)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (e *EphemeralManager) loadEphemeralPeers() {
|
||||
accounts := e.store.GetAllAccounts()
|
||||
func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
|
||||
accounts := e.store.GetAllAccounts(context.Background())
|
||||
t := newDeadLine()
|
||||
count := 0
|
||||
for _, a := range accounts {
|
||||
@ -132,10 +137,10 @@ func (e *EphemeralManager) loadEphemeralPeers() {
|
||||
}
|
||||
}
|
||||
}
|
||||
log.Debugf("loaded ephemeral peer(s): %d", count)
|
||||
log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", count)
|
||||
}
|
||||
|
||||
func (e *EphemeralManager) cleanup() {
|
||||
func (e *EphemeralManager) cleanup(ctx context.Context) {
|
||||
log.Tracef("on ephemeral cleanup")
|
||||
deletePeers := make(map[string]*ephemeralPeer)
|
||||
|
||||
@ -154,7 +159,9 @@ func (e *EphemeralManager) cleanup() {
|
||||
}
|
||||
|
||||
if e.headPeer != nil {
|
||||
e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), e.cleanup)
|
||||
e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), func() {
|
||||
e.cleanup(ctx)
|
||||
})
|
||||
} else {
|
||||
e.timer = nil
|
||||
}
|
||||
@ -162,10 +169,10 @@ func (e *EphemeralManager) cleanup() {
|
||||
e.peersLock.Unlock()
|
||||
|
||||
for id, p := range deletePeers {
|
||||
log.Debugf("delete ephemeral peer: %s", id)
|
||||
err := e.accountManager.DeletePeer(p.account.Id, id, activity.SystemInitiator)
|
||||
log.WithContext(ctx).Debugf("delete ephemeral peer: %s", id)
|
||||
err := e.accountManager.DeletePeer(ctx, p.account.Id, id, activity.SystemInitiator)
|
||||
if err != nil {
|
||||
log.Errorf("failed to delete ephemeral peer: %s", err)
|
||||
log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
@ -13,11 +14,11 @@ type MockStore struct {
|
||||
account *Account
|
||||
}
|
||||
|
||||
func (s *MockStore) GetAllAccounts() []*Account {
|
||||
func (s *MockStore) GetAllAccounts(_ context.Context) []*Account {
|
||||
return []*Account{s.account}
|
||||
}
|
||||
|
||||
func (s *MockStore) GetAccountByPeerID(peerId string) (*Account, error) {
|
||||
func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*Account, error) {
|
||||
_, ok := s.account.Peers[peerId]
|
||||
if ok {
|
||||
return s.account, nil
|
||||
@ -31,7 +32,7 @@ type MocAccountManager struct {
|
||||
store *MockStore
|
||||
}
|
||||
|
||||
func (a MocAccountManager) DeletePeer(accountID, peerID, userID string) error {
|
||||
func (a MocAccountManager) DeletePeer(_ context.Context, accountID, peerID, userID string) error {
|
||||
delete(a.store.account.Peers, peerID)
|
||||
return nil //nolint:nil
|
||||
}
|
||||
@ -52,9 +53,9 @@ func TestNewManager(t *testing.T) {
|
||||
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
|
||||
|
||||
mgr := NewEphemeralManager(store, am)
|
||||
mgr.loadEphemeralPeers()
|
||||
mgr.loadEphemeralPeers(context.Background())
|
||||
startTime = startTime.Add(ephemeralLifeTime + 1)
|
||||
mgr.cleanup()
|
||||
mgr.cleanup(context.Background())
|
||||
|
||||
if len(store.account.Peers) != numberOfPeers {
|
||||
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", numberOfPeers, len(store.account.Peers))
|
||||
@ -77,11 +78,11 @@ func TestNewManagerPeerConnected(t *testing.T) {
|
||||
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
|
||||
|
||||
mgr := NewEphemeralManager(store, am)
|
||||
mgr.loadEphemeralPeers()
|
||||
mgr.OnPeerConnected(store.account.Peers["ephemeral_peer_0"])
|
||||
mgr.loadEphemeralPeers(context.Background())
|
||||
mgr.OnPeerConnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
|
||||
|
||||
startTime = startTime.Add(ephemeralLifeTime + 1)
|
||||
mgr.cleanup()
|
||||
mgr.cleanup(context.Background())
|
||||
|
||||
expected := numberOfPeers + 1
|
||||
if len(store.account.Peers) != expected {
|
||||
@ -105,15 +106,15 @@ func TestNewManagerPeerDisconnected(t *testing.T) {
|
||||
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
|
||||
|
||||
mgr := NewEphemeralManager(store, am)
|
||||
mgr.loadEphemeralPeers()
|
||||
mgr.loadEphemeralPeers(context.Background())
|
||||
for _, v := range store.account.Peers {
|
||||
mgr.OnPeerConnected(v)
|
||||
mgr.OnPeerConnected(context.Background(), v)
|
||||
|
||||
}
|
||||
mgr.OnPeerDisconnected(store.account.Peers["ephemeral_peer_0"])
|
||||
mgr.OnPeerDisconnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
|
||||
|
||||
startTime = startTime.Add(ephemeralLifeTime + 1)
|
||||
mgr.cleanup()
|
||||
mgr.cleanup(context.Background())
|
||||
|
||||
expected := numberOfPeers + numberOfEphemeralPeers - 1
|
||||
if len(store.account.Peers) != expected {
|
||||
@ -122,7 +123,7 @@ func TestNewManagerPeerDisconnected(t *testing.T) {
|
||||
}
|
||||
|
||||
func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) {
|
||||
store.account = newAccountWithId("my account", "", "")
|
||||
store.account = newAccountWithId(context.Background(), "my account", "", "")
|
||||
|
||||
for i := 0; i < numberOfPeers; i++ {
|
||||
peerId := fmt.Sprintf("peer_%d", i)
|
||||
|
@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
@ -11,11 +12,11 @@ import (
|
||||
)
|
||||
|
||||
// GetEvents returns a list of activity events of an account
|
||||
func (am *DefaultAccountManager) GetEvents(accountID, userID string) ([]*activity.Event, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -29,7 +30,7 @@ func (am *DefaultAccountManager) GetEvents(accountID, userID string) ([]*activit
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view events")
|
||||
}
|
||||
|
||||
events, err := am.eventStore.Get(accountID, 0, 10000, true)
|
||||
events, err := am.eventStore.Get(ctx, accountID, 0, 10000, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -54,10 +55,10 @@ func (am *DefaultAccountManager) GetEvents(accountID, userID string) ([]*activit
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) StoreEvent(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
func (am *DefaultAccountManager) StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
|
||||
go func() {
|
||||
_, err := am.eventStore.Save(&activity.Event{
|
||||
_, err := am.eventStore.Save(ctx, &activity.Event{
|
||||
Timestamp: time.Now().UTC(),
|
||||
Activity: activityID,
|
||||
InitiatorID: initiatorID,
|
||||
@ -67,7 +68,7 @@ func (am *DefaultAccountManager) StoreEvent(initiatorID, targetID, accountID str
|
||||
})
|
||||
if err != nil {
|
||||
// todo add metric
|
||||
log.Errorf("received an error while storing an activity event, error: %s", err)
|
||||
log.WithContext(ctx).Errorf("received an error while storing an activity event, error: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -13,7 +14,7 @@ func generateAndStoreEvents(t *testing.T, manager *DefaultAccountManager, typ ac
|
||||
accountID string, count int) {
|
||||
t.Helper()
|
||||
for i := 0; i < count; i++ {
|
||||
_, err := manager.eventStore.Save(&activity.Event{
|
||||
_, err := manager.eventStore.Save(context.Background(), &activity.Event{
|
||||
Timestamp: time.Now().UTC(),
|
||||
Activity: typ,
|
||||
InitiatorID: initiatorID,
|
||||
@ -35,32 +36,32 @@ func TestDefaultAccountManager_GetEvents(t *testing.T) {
|
||||
accountID := "accountID"
|
||||
|
||||
t.Run("get empty events list", func(t *testing.T) {
|
||||
events, err := manager.GetEvents(accountID, userID)
|
||||
events, err := manager.GetEvents(context.Background(), accountID, userID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
assert.Len(t, events, 0)
|
||||
_ = manager.eventStore.Close() //nolint
|
||||
_ = manager.eventStore.Close(context.Background()) //nolint
|
||||
})
|
||||
|
||||
t.Run("get events", func(t *testing.T) {
|
||||
generateAndStoreEvents(t, manager, activity.PeerAddedByUser, userID, "peer", accountID, 10)
|
||||
events, err := manager.GetEvents(accountID, userID)
|
||||
events, err := manager.GetEvents(context.Background(), accountID, userID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
assert.Len(t, events, 10)
|
||||
_ = manager.eventStore.Close() //nolint
|
||||
_ = manager.eventStore.Close(context.Background()) //nolint
|
||||
})
|
||||
|
||||
t.Run("get events without duplicates", func(t *testing.T) {
|
||||
generateAndStoreEvents(t, manager, activity.UserJoined, userID, "", accountID, 10)
|
||||
events, err := manager.GetEvents(accountID, userID)
|
||||
events, err := manager.GetEvents(context.Background(), accountID, userID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
assert.Len(t, events, 1)
|
||||
_ = manager.eventStore.Close() //nolint
|
||||
_ = manager.eventStore.Close(context.Background()) //nolint
|
||||
})
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@ -48,8 +49,8 @@ type FileStore struct {
|
||||
type StoredAccount struct{}
|
||||
|
||||
// NewFileStore restores a store from the file located in the datadir
|
||||
func NewFileStore(dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) {
|
||||
fs, err := restore(filepath.Join(dataDir, storeFileName))
|
||||
func NewFileStore(ctx context.Context, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) {
|
||||
fs, err := restore(ctx, filepath.Join(dataDir, storeFileName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -58,27 +59,27 @@ func NewFileStore(dataDir string, metrics telemetry.AppMetrics) (*FileStore, err
|
||||
}
|
||||
|
||||
// NewFilestoreFromSqliteStore restores a store from Sqlite and stores to Filestore json in the file located in datadir
|
||||
func NewFilestoreFromSqliteStore(sqlStore *SqlStore, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) {
|
||||
store, err := NewFileStore(dataDir, metrics)
|
||||
func NewFilestoreFromSqliteStore(ctx context.Context, sqlStore *SqlStore, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) {
|
||||
store, err := NewFileStore(ctx, dataDir, metrics)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = store.SaveInstallationID(sqlStore.GetInstallationID())
|
||||
err = store.SaveInstallationID(ctx, sqlStore.GetInstallationID())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, account := range sqlStore.GetAllAccounts() {
|
||||
for _, account := range sqlStore.GetAllAccounts(ctx) {
|
||||
store.Accounts[account.Id] = account
|
||||
}
|
||||
|
||||
return store, store.persist(store.storeFile)
|
||||
return store, store.persist(ctx, store.storeFile)
|
||||
}
|
||||
|
||||
// restore the state of the store from the file.
|
||||
// Creates a new empty store file if doesn't exist
|
||||
func restore(file string) (*FileStore, error) {
|
||||
func restore(ctx context.Context, file string) (*FileStore, error) {
|
||||
if _, err := os.Stat(file); os.IsNotExist(err) {
|
||||
// create a new FileStore if previously didn't exist (e.g. first run)
|
||||
s := &FileStore{
|
||||
@ -95,7 +96,7 @@ func restore(file string) (*FileStore, error) {
|
||||
storeFile: file,
|
||||
}
|
||||
|
||||
err = s.persist(file)
|
||||
err = s.persist(ctx, file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -165,7 +166,7 @@ func restore(file string) (*FileStore, error) {
|
||||
// for data migration. Can be removed once most base will be with labels
|
||||
existingLabels := account.getPeerDNSLabels()
|
||||
if len(existingLabels) != len(account.Peers) {
|
||||
addPeerLabelsToAccount(account, existingLabels)
|
||||
addPeerLabelsToAccount(ctx, account, existingLabels)
|
||||
}
|
||||
|
||||
// TODO: delete this block after migration
|
||||
@ -178,7 +179,7 @@ func restore(file string) (*FileStore, error) {
|
||||
|
||||
allGroup, err := account.GetGroupAll()
|
||||
if err != nil {
|
||||
log.Errorf("unable to find the All group, this should happen only when migrate from a version that didn't support groups. Error: %v", err)
|
||||
log.WithContext(ctx).Errorf("unable to find the All group, this should happen only when migrate from a version that didn't support groups. Error: %v", err)
|
||||
// if the All group didn't exist we probably don't have routes to update
|
||||
continue
|
||||
}
|
||||
@ -236,7 +237,7 @@ func restore(file string) (*FileStore, error) {
|
||||
}
|
||||
|
||||
// we need this persist to apply changes we made to account.Peers (we set them to Disconnected)
|
||||
err = store.persist(store.storeFile)
|
||||
err = store.persist(ctx, store.storeFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -246,7 +247,7 @@ func restore(file string) (*FileStore, error) {
|
||||
|
||||
// persist account data to a file
|
||||
// It is recommended to call it with locking FileStore.mux
|
||||
func (s *FileStore) persist(file string) error {
|
||||
func (s *FileStore) persist(ctx context.Context, file string) error {
|
||||
start := time.Now()
|
||||
err := util.WriteJson(file, s)
|
||||
if err != nil {
|
||||
@ -256,23 +257,23 @@ func (s *FileStore) persist(file string) error {
|
||||
if s.metrics != nil {
|
||||
s.metrics.StoreMetrics().CountPersistenceDuration(took)
|
||||
}
|
||||
log.Debugf("took %d ms to persist the FileStore", took.Milliseconds())
|
||||
log.WithContext(ctx).Debugf("took %d ms to persist the FileStore", took.Milliseconds())
|
||||
return nil
|
||||
}
|
||||
|
||||
// AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock
|
||||
func (s *FileStore) AcquireGlobalLock() (unlock func()) {
|
||||
log.Debugf("acquiring global lock")
|
||||
func (s *FileStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
|
||||
log.WithContext(ctx).Debugf("acquiring global lock")
|
||||
start := time.Now()
|
||||
s.globalAccountLock.Lock()
|
||||
|
||||
unlock = func() {
|
||||
s.globalAccountLock.Unlock()
|
||||
log.Debugf("released global lock in %v", time.Since(start))
|
||||
log.WithContext(ctx).Debugf("released global lock in %v", time.Since(start))
|
||||
}
|
||||
|
||||
took := time.Since(start)
|
||||
log.Debugf("took %v to acquire global lock", took)
|
||||
log.WithContext(ctx).Debugf("took %v to acquire global lock", took)
|
||||
if s.metrics != nil {
|
||||
s.metrics.StoreMetrics().CountGlobalLockAcquisitionDuration(took)
|
||||
}
|
||||
@ -281,8 +282,8 @@ func (s *FileStore) AcquireGlobalLock() (unlock func()) {
|
||||
}
|
||||
|
||||
// AcquireAccountWriteLock acquires account lock for writing to a resource and returns a function that releases the lock
|
||||
func (s *FileStore) AcquireAccountWriteLock(accountID string) (unlock func()) {
|
||||
log.Debugf("acquiring lock for account %s", accountID)
|
||||
func (s *FileStore) AcquireAccountWriteLock(ctx context.Context, accountID string) (unlock func()) {
|
||||
log.WithContext(ctx).Debugf("acquiring lock for account %s", accountID)
|
||||
start := time.Now()
|
||||
value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{})
|
||||
mtx := value.(*sync.Mutex)
|
||||
@ -290,7 +291,7 @@ func (s *FileStore) AcquireAccountWriteLock(accountID string) (unlock func()) {
|
||||
|
||||
unlock = func() {
|
||||
mtx.Unlock()
|
||||
log.Debugf("released lock for account %s in %v", accountID, time.Since(start))
|
||||
log.WithContext(ctx).Debugf("released lock for account %s in %v", accountID, time.Since(start))
|
||||
}
|
||||
|
||||
return unlock
|
||||
@ -298,11 +299,11 @@ func (s *FileStore) AcquireAccountWriteLock(accountID string) (unlock func()) {
|
||||
|
||||
// AcquireAccountReadLock AcquireAccountWriteLock acquires account lock for reading a resource and returns a function that releases the lock
|
||||
// This method is still returns a write lock as file store can't handle read locks
|
||||
func (s *FileStore) AcquireAccountReadLock(accountID string) (unlock func()) {
|
||||
return s.AcquireAccountWriteLock(accountID)
|
||||
func (s *FileStore) AcquireAccountReadLock(ctx context.Context, accountID string) (unlock func()) {
|
||||
return s.AcquireAccountWriteLock(ctx, accountID)
|
||||
}
|
||||
|
||||
func (s *FileStore) SaveAccount(account *Account) error {
|
||||
func (s *FileStore) SaveAccount(ctx context.Context, account *Account) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
@ -338,10 +339,10 @@ func (s *FileStore) SaveAccount(account *Account) error {
|
||||
s.PrivateDomain2AccountID[accountCopy.Domain] = accountCopy.Id
|
||||
}
|
||||
|
||||
return s.persist(s.storeFile)
|
||||
return s.persist(ctx, s.storeFile)
|
||||
}
|
||||
|
||||
func (s *FileStore) DeleteAccount(account *Account) error {
|
||||
func (s *FileStore) DeleteAccount(ctx context.Context, account *Account) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
@ -373,7 +374,7 @@ func (s *FileStore) DeleteAccount(account *Account) error {
|
||||
|
||||
delete(s.Accounts, account.Id)
|
||||
|
||||
return s.persist(s.storeFile)
|
||||
return s.persist(ctx, s.storeFile)
|
||||
}
|
||||
|
||||
// DeleteHashedPAT2TokenIDIndex removes an entry from the indexing map HashedPAT2TokenID
|
||||
@ -397,7 +398,7 @@ func (s *FileStore) DeleteTokenID2UserIDIndex(tokenID string) error {
|
||||
}
|
||||
|
||||
// GetAccountByPrivateDomain returns account by private domain
|
||||
func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) {
|
||||
func (s *FileStore) GetAccountByPrivateDomain(_ context.Context, domain string) (*Account, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
@ -415,7 +416,7 @@ func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) {
|
||||
}
|
||||
|
||||
// GetAccountBySetupKey returns account by setup key id
|
||||
func (s *FileStore) GetAccountBySetupKey(setupKey string) (*Account, error) {
|
||||
func (s *FileStore) GetAccountBySetupKey(_ context.Context, setupKey string) (*Account, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
@ -433,7 +434,7 @@ func (s *FileStore) GetAccountBySetupKey(setupKey string) (*Account, error) {
|
||||
}
|
||||
|
||||
// GetTokenIDByHashedToken returns the id of a personal access token by its hashed secret
|
||||
func (s *FileStore) GetTokenIDByHashedToken(token string) (string, error) {
|
||||
func (s *FileStore) GetTokenIDByHashedToken(_ context.Context, token string) (string, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
@ -446,7 +447,7 @@ func (s *FileStore) GetTokenIDByHashedToken(token string) (string, error) {
|
||||
}
|
||||
|
||||
// GetUserByTokenID returns a User object a tokenID belongs to
|
||||
func (s *FileStore) GetUserByTokenID(tokenID string) (*User, error) {
|
||||
func (s *FileStore) GetUserByTokenID(_ context.Context, tokenID string) (*User, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
@ -469,7 +470,7 @@ func (s *FileStore) GetUserByTokenID(tokenID string) (*User, error) {
|
||||
}
|
||||
|
||||
// GetAllAccounts returns all accounts
|
||||
func (s *FileStore) GetAllAccounts() (all []*Account) {
|
||||
func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
for _, a := range s.Accounts {
|
||||
@ -490,7 +491,7 @@ func (s *FileStore) getAccount(accountID string) (*Account, error) {
|
||||
}
|
||||
|
||||
// GetAccount returns an account for ID
|
||||
func (s *FileStore) GetAccount(accountID string) (*Account, error) {
|
||||
func (s *FileStore) GetAccount(_ context.Context, accountID string) (*Account, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
@ -503,7 +504,7 @@ func (s *FileStore) GetAccount(accountID string) (*Account, error) {
|
||||
}
|
||||
|
||||
// GetAccountByUser returns a user account
|
||||
func (s *FileStore) GetAccountByUser(userID string) (*Account, error) {
|
||||
func (s *FileStore) GetAccountByUser(_ context.Context, userID string) (*Account, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
@ -521,7 +522,7 @@ func (s *FileStore) GetAccountByUser(userID string) (*Account, error) {
|
||||
}
|
||||
|
||||
// GetAccountByPeerID returns an account for a given peer ID
|
||||
func (s *FileStore) GetAccountByPeerID(peerID string) (*Account, error) {
|
||||
func (s *FileStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
@ -539,7 +540,7 @@ func (s *FileStore) GetAccountByPeerID(peerID string) (*Account, error) {
|
||||
// check Account.Peers for a match
|
||||
if _, ok := account.Peers[peerID]; !ok {
|
||||
delete(s.PeerID2AccountID, peerID)
|
||||
log.Warnf("removed stale peerID %s to accountID %s index", peerID, accountID)
|
||||
log.WithContext(ctx).Warnf("removed stale peerID %s to accountID %s index", peerID, accountID)
|
||||
return nil, status.NewPeerNotFoundError(peerID)
|
||||
}
|
||||
|
||||
@ -547,7 +548,7 @@ func (s *FileStore) GetAccountByPeerID(peerID string) (*Account, error) {
|
||||
}
|
||||
|
||||
// GetAccountByPeerPubKey returns an account for a given peer WireGuard public key
|
||||
func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
|
||||
func (s *FileStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
@ -572,14 +573,14 @@ func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
|
||||
}
|
||||
if stale {
|
||||
delete(s.PeerKeyID2AccountID, peerKey)
|
||||
log.Warnf("removed stale peerKey %s to accountID %s index", peerKey, accountID)
|
||||
log.WithContext(ctx).Warnf("removed stale peerKey %s to accountID %s index", peerKey, accountID)
|
||||
return nil, status.NewPeerNotFoundError(peerKey)
|
||||
}
|
||||
|
||||
return account.Copy(), nil
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) {
|
||||
func (s *FileStore) GetAccountIDByPeerPubKey(_ context.Context, peerKey string) (string, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
@ -603,7 +604,7 @@ func (s *FileStore) GetAccountIDByUserID(userID string) (string, error) {
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountIDBySetupKey(setupKey string) (string, error) {
|
||||
func (s *FileStore) GetAccountIDBySetupKey(_ context.Context, setupKey string) (string, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
@ -615,7 +616,7 @@ func (s *FileStore) GetAccountIDBySetupKey(setupKey string) (string, error) {
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
func (s *FileStore) GetPeerByPeerPubKey(peerKey string) (*nbpeer.Peer, error) {
|
||||
func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, peerKey string) (*nbpeer.Peer, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
@ -638,7 +639,7 @@ func (s *FileStore) GetPeerByPeerPubKey(peerKey string) (*nbpeer.Peer, error) {
|
||||
return nil, status.NewPeerNotFoundError(peerKey)
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountSettings(accountID string) (*Settings, error) {
|
||||
func (s *FileStore) GetAccountSettings(_ context.Context, accountID string) (*Settings, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
@ -656,13 +657,13 @@ func (s *FileStore) GetInstallationID() string {
|
||||
}
|
||||
|
||||
// SaveInstallationID saves the installation ID
|
||||
func (s *FileStore) SaveInstallationID(ID string) error {
|
||||
func (s *FileStore) SaveInstallationID(ctx context.Context, ID string) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
s.InstallationID = ID
|
||||
|
||||
return s.persist(s.storeFile)
|
||||
return s.persist(ctx, s.storeFile)
|
||||
}
|
||||
|
||||
// SavePeerStatus stores the PeerStatus in memory. It doesn't attempt to persist data to speed up things.
|
||||
@ -732,13 +733,13 @@ func (s *FileStore) GetPostureCheckByChecksDefinition(accountID string, checks *
|
||||
}
|
||||
|
||||
// Close the FileStore persisting data to disk
|
||||
func (s *FileStore) Close() error {
|
||||
func (s *FileStore) Close(ctx context.Context) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
log.Infof("closing FileStore")
|
||||
log.WithContext(ctx).Infof("closing FileStore")
|
||||
|
||||
return s.persist(s.storeFile)
|
||||
return s.persist(ctx, s.storeFile)
|
||||
}
|
||||
|
||||
// GetStoreEngine returns FileStoreEngine
|
||||
|
@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"net"
|
||||
"path/filepath"
|
||||
@ -27,12 +28,12 @@ func TestStalePeerIndices(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
store, err := NewFileStore(storeDir, nil)
|
||||
store, err := NewFileStore(context.Background(), storeDir, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
account, err := store.GetAccount("bf1c8084-ba50-4ce7-9439-34653001fc3b")
|
||||
account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
|
||||
require.NoError(t, err)
|
||||
|
||||
peerID := "some_peer"
|
||||
@ -42,24 +43,24 @@ func TestStalePeerIndices(t *testing.T) {
|
||||
Key: peerKey,
|
||||
}
|
||||
|
||||
err = store.SaveAccount(account)
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
account.DeletePeer(peerID)
|
||||
|
||||
err = store.SaveAccount(account)
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = store.GetAccountByPeerID(peerID)
|
||||
_, err = store.GetAccountByPeerID(context.Background(), peerID)
|
||||
require.Error(t, err, "expecting to get an error when found stale index")
|
||||
|
||||
_, err = store.GetAccountByPeerPubKey(peerKey)
|
||||
_, err = store.GetAccountByPeerPubKey(context.Background(), peerKey)
|
||||
require.Error(t, err, "expecting to get an error when found stale index")
|
||||
}
|
||||
|
||||
func TestNewStore(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
defer store.Close(context.Background())
|
||||
|
||||
if store.Accounts == nil || len(store.Accounts) != 0 {
|
||||
t.Errorf("expected to create a new empty Accounts map when creating a new FileStore")
|
||||
@ -88,9 +89,9 @@ func TestNewStore(t *testing.T) {
|
||||
|
||||
func TestSaveAccount(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
defer store.Close(context.Background())
|
||||
|
||||
account := newAccountWithId("account_id", "testuser", "")
|
||||
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
|
||||
setupKey := GenerateDefaultSetupKey()
|
||||
account.SetupKeys[setupKey.Key] = setupKey
|
||||
account.Peers["testpeer"] = &nbpeer.Peer{
|
||||
@ -103,7 +104,7 @@ func TestSaveAccount(t *testing.T) {
|
||||
}
|
||||
|
||||
// SaveAccount should trigger persist
|
||||
err := store.SaveAccount(account)
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -133,11 +134,11 @@ func TestDeleteAccount(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
store, err := NewFileStore(storeDir, nil)
|
||||
store, err := NewFileStore(context.Background(), storeDir, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer store.Close()
|
||||
defer store.Close(context.Background())
|
||||
|
||||
var account *Account
|
||||
for _, a := range store.Accounts {
|
||||
@ -147,7 +148,7 @@ func TestDeleteAccount(t *testing.T) {
|
||||
|
||||
require.NotNil(t, account, "failed to restore a FileStore file and get at least one account")
|
||||
|
||||
err = store.DeleteAccount(account)
|
||||
err = store.DeleteAccount(context.Background(), account)
|
||||
require.NoError(t, err, "failed to delete account, error: %v", err)
|
||||
|
||||
_, ok := store.Accounts[account.Id]
|
||||
@ -183,9 +184,9 @@ func TestDeleteAccount(t *testing.T) {
|
||||
|
||||
func TestStore(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
defer store.Close(context.Background())
|
||||
|
||||
account := newAccountWithId("account_id", "testuser", "")
|
||||
account := newAccountWithId(context.Background(), "account_id", "testuser", "")
|
||||
account.Peers["testpeer"] = &nbpeer.Peer{
|
||||
Key: "peerkey",
|
||||
SetupKey: "peerkeysetupkey",
|
||||
@ -228,12 +229,12 @@ func TestStore(t *testing.T) {
|
||||
})
|
||||
|
||||
// SaveAccount should trigger persist
|
||||
err := store.SaveAccount(account)
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
restored, err := NewFileStore(store.storeFile, nil)
|
||||
restored, err := NewFileStore(context.Background(), store.storeFile, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -281,7 +282,7 @@ func TestRestore(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
store, err := NewFileStore(storeDir, nil)
|
||||
store, err := NewFileStore(context.Background(), storeDir, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -319,7 +320,7 @@ func TestRestoreGroups_Migration(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
store, err := NewFileStore(storeDir, nil)
|
||||
store, err := NewFileStore(context.Background(), storeDir, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -332,11 +333,11 @@ func TestRestoreGroups_Migration(t *testing.T) {
|
||||
Name: "All",
|
||||
},
|
||||
}
|
||||
err = store.SaveAccount(account)
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err, "failed to save account")
|
||||
|
||||
// restore account with default group with empty Issue field
|
||||
if store, err = NewFileStore(storeDir, nil); err != nil {
|
||||
if store, err = NewFileStore(context.Background(), storeDir, nil); err != nil {
|
||||
return
|
||||
}
|
||||
account = store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"]
|
||||
@ -353,18 +354,18 @@ func TestGetAccountByPrivateDomain(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
store, err := NewFileStore(storeDir, nil)
|
||||
store, err := NewFileStore(context.Background(), storeDir, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
existingDomain := "test.com"
|
||||
|
||||
account, err := store.GetAccountByPrivateDomain(existingDomain)
|
||||
account, err := store.GetAccountByPrivateDomain(context.Background(), existingDomain)
|
||||
require.NoError(t, err, "should found account")
|
||||
require.Equal(t, existingDomain, account.Domain, "domains should match")
|
||||
|
||||
_, err = store.GetAccountByPrivateDomain("missing-domain.com")
|
||||
_, err = store.GetAccountByPrivateDomain(context.Background(), "missing-domain.com")
|
||||
require.Error(t, err, "should return error on domain lookup")
|
||||
}
|
||||
|
||||
@ -382,7 +383,7 @@ func TestFileStore_GetAccount(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
store, err := NewFileStore(storeDir, nil)
|
||||
store, err := NewFileStore(context.Background(), storeDir, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -393,7 +394,7 @@ func TestFileStore_GetAccount(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
account, err := store.GetAccount(expected.Id)
|
||||
account, err := store.GetAccount(context.Background(), expected.Id)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -424,13 +425,13 @@ func TestFileStore_GetTokenIDByHashedToken(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
store, err := NewFileStore(storeDir, nil)
|
||||
store, err := NewFileStore(context.Background(), storeDir, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
hashedToken := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].HashedToken
|
||||
tokenID, err := store.GetTokenIDByHashedToken(hashedToken)
|
||||
tokenID, err := store.GetTokenIDByHashedToken(context.Background(), hashedToken)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -441,7 +442,7 @@ func TestFileStore_GetTokenIDByHashedToken(t *testing.T) {
|
||||
|
||||
func TestFileStore_DeleteHashedPAT2TokenIDIndex(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close()
|
||||
defer store.Close(context.Background())
|
||||
store.HashedPAT2TokenID["someHashedToken"] = "someTokenId"
|
||||
|
||||
err := store.DeleteHashedPAT2TokenIDIndex("someHashedToken")
|
||||
@ -478,13 +479,13 @@ func TestFileStore_GetTokenIDByHashedToken_Failure(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
store, err := NewFileStore(storeDir, nil)
|
||||
store, err := NewFileStore(context.Background(), storeDir, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
wrongToken := sha256.Sum256([]byte("someNotValidTokenThatFails1234"))
|
||||
_, err = store.GetTokenIDByHashedToken(string(wrongToken[:]))
|
||||
_, err = store.GetTokenIDByHashedToken(context.Background(), string(wrongToken[:]))
|
||||
|
||||
assert.Error(t, err, "GetTokenIDByHashedToken should throw error if token invalid")
|
||||
}
|
||||
@ -503,13 +504,13 @@ func TestFileStore_GetUserByTokenID(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
store, err := NewFileStore(storeDir, nil)
|
||||
store, err := NewFileStore(context.Background(), storeDir, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tokenID := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].ID
|
||||
user, err := store.GetUserByTokenID(tokenID)
|
||||
user, err := store.GetUserByTokenID(context.Background(), tokenID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -531,13 +532,13 @@ func TestFileStore_GetUserByTokenID_Failure(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
store, err := NewFileStore(storeDir, nil)
|
||||
store, err := NewFileStore(context.Background(), storeDir, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
wrongTokenID := "someNonExistingTokenID"
|
||||
_, err = store.GetUserByTokenID(wrongTokenID)
|
||||
_, err = store.GetUserByTokenID(context.Background(), wrongTokenID)
|
||||
|
||||
assert.Error(t, err, "GetUserByTokenID should throw error if tokenID invalid")
|
||||
}
|
||||
@ -550,7 +551,7 @@ func TestFileStore_SavePeerStatus(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
store, err := NewFileStore(storeDir, nil)
|
||||
store, err := NewFileStore(context.Background(), storeDir, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -576,7 +577,7 @@ func TestFileStore_SavePeerStatus(t *testing.T) {
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
|
||||
err = store.SaveAccount(account)
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -602,11 +603,11 @@ func TestFileStore_SavePeerLocation(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
store, err := NewFileStore(storeDir, nil)
|
||||
store, err := NewFileStore(context.Background(), storeDir, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
account, err := store.GetAccount("bf1c8084-ba50-4ce7-9439-34653001fc3b")
|
||||
account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
|
||||
require.NoError(t, err)
|
||||
|
||||
peer := &nbpeer.Peer{
|
||||
@ -625,7 +626,7 @@ func TestFileStore_SavePeerLocation(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
|
||||
account.Peers[peer.ID] = peer
|
||||
err = store.SaveAccount(account)
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer.Location.ConnectionIP = net.ParseIP("35.1.1.1")
|
||||
@ -636,7 +637,7 @@ func TestFileStore_SavePeerLocation(t *testing.T) {
|
||||
err = store.SavePeerLocation(account.Id, account.Peers[peer.ID])
|
||||
assert.NoError(t, err)
|
||||
|
||||
account, err = store.GetAccount(account.Id)
|
||||
account, err = store.GetAccount(context.Background(), account.Id)
|
||||
require.NoError(t, err)
|
||||
|
||||
actual := account.Peers[peer.ID].Location
|
||||
@ -645,7 +646,7 @@ func TestFileStore_SavePeerLocation(t *testing.T) {
|
||||
|
||||
func newStore(t *testing.T) *FileStore {
|
||||
t.Helper()
|
||||
store, err := NewFileStore(t.TempDir(), nil)
|
||||
store, err := NewFileStore(context.Background(), t.TempDir(), nil)
|
||||
if err != nil {
|
||||
t.Errorf("failed creating a new store")
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package geolocation
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
@ -52,7 +53,7 @@ type Country struct {
|
||||
CountryName string
|
||||
}
|
||||
|
||||
func NewGeolocation(dataDir string) (*Geolocation, error) {
|
||||
func NewGeolocation(ctx context.Context, dataDir string) (*Geolocation, error) {
|
||||
if err := loadGeolocationDatabases(dataDir); err != nil {
|
||||
return nil, fmt.Errorf("failed to load MaxMind databases: %v", err)
|
||||
}
|
||||
@ -68,7 +69,7 @@ func NewGeolocation(dataDir string) (*Geolocation, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
locationDB, err := NewSqliteStore(dataDir)
|
||||
locationDB, err := NewSqliteStore(ctx, dataDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -83,7 +84,7 @@ func NewGeolocation(dataDir string) (*Geolocation, error) {
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
go geo.reloader()
|
||||
go geo.reloader(ctx)
|
||||
|
||||
return geo, nil
|
||||
}
|
||||
@ -165,19 +166,19 @@ func (gl *Geolocation) Stop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (gl *Geolocation) reloader() {
|
||||
func (gl *Geolocation) reloader(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case <-gl.stopCh:
|
||||
return
|
||||
case <-time.After(gl.reloadCheckInterval):
|
||||
if err := gl.locationDB.reload(); err != nil {
|
||||
log.Errorf("geonames db reload failed: %s", err)
|
||||
if err := gl.locationDB.reload(ctx); err != nil {
|
||||
log.WithContext(ctx).Errorf("geonames db reload failed: %s", err)
|
||||
}
|
||||
|
||||
newSha256sum1, err := calculateFileSHA256(gl.mmdbPath)
|
||||
if err != nil {
|
||||
log.Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err)
|
||||
log.WithContext(ctx).Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err)
|
||||
continue
|
||||
}
|
||||
if !bytes.Equal(gl.sha256sum, newSha256sum1) {
|
||||
@ -186,30 +187,30 @@ func (gl *Geolocation) reloader() {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
newSha256sum2, err := calculateFileSHA256(gl.mmdbPath)
|
||||
if err != nil {
|
||||
log.Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err)
|
||||
log.WithContext(ctx).Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err)
|
||||
continue
|
||||
}
|
||||
if !bytes.Equal(newSha256sum1, newSha256sum2) {
|
||||
log.Errorf("sha256 sum changed during reloading of '%s'", gl.mmdbPath)
|
||||
log.WithContext(ctx).Errorf("sha256 sum changed during reloading of '%s'", gl.mmdbPath)
|
||||
continue
|
||||
}
|
||||
err = gl.reload(newSha256sum2)
|
||||
err = gl.reload(ctx, newSha256sum2)
|
||||
if err != nil {
|
||||
log.Errorf("mmdb reload failed: %s", err)
|
||||
log.WithContext(ctx).Errorf("mmdb reload failed: %s", err)
|
||||
}
|
||||
} else {
|
||||
log.Tracef("No changes in '%s', no need to reload. Next check is in %.0f seconds.",
|
||||
log.WithContext(ctx).Tracef("No changes in '%s', no need to reload. Next check is in %.0f seconds.",
|
||||
gl.mmdbPath, gl.reloadCheckInterval.Seconds())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (gl *Geolocation) reload(newSha256sum []byte) error {
|
||||
func (gl *Geolocation) reload(ctx context.Context, newSha256sum []byte) error {
|
||||
gl.mux.Lock()
|
||||
defer gl.mux.Unlock()
|
||||
|
||||
log.Infof("Reloading '%s'", gl.mmdbPath)
|
||||
log.WithContext(ctx).Infof("Reloading '%s'", gl.mmdbPath)
|
||||
|
||||
err := gl.db.Close()
|
||||
if err != nil {
|
||||
@ -224,7 +225,7 @@ func (gl *Geolocation) reload(newSha256sum []byte) error {
|
||||
gl.db = db
|
||||
gl.sha256sum = newSha256sum
|
||||
|
||||
log.Infof("Successfully reloaded '%s'", gl.mmdbPath)
|
||||
log.WithContext(ctx).Infof("Successfully reloaded '%s'", gl.mmdbPath)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package geolocation
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
@ -50,10 +51,10 @@ type SqliteStore struct {
|
||||
sha256sum []byte
|
||||
}
|
||||
|
||||
func NewSqliteStore(dataDir string) (*SqliteStore, error) {
|
||||
func NewSqliteStore(ctx context.Context, dataDir string) (*SqliteStore, error) {
|
||||
file := filepath.Join(dataDir, GeoSqliteDBFile)
|
||||
|
||||
db, err := connectDB(file)
|
||||
db, err := connectDB(ctx, file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -115,13 +116,13 @@ func (s *SqliteStore) GetCitiesByCountry(countryISOCode string) ([]City, error)
|
||||
}
|
||||
|
||||
// reload attempts to reload the SqliteStore's database if the database file has changed.
|
||||
func (s *SqliteStore) reload() error {
|
||||
func (s *SqliteStore) reload(ctx context.Context) error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
newSha256sum1, err := calculateFileSHA256(s.filePath)
|
||||
if err != nil {
|
||||
log.Errorf("failed to calculate sha256 sum for '%s': %s", s.filePath, err)
|
||||
log.WithContext(ctx).Errorf("failed to calculate sha256 sum for '%s': %s", s.filePath, err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(s.sha256sum, newSha256sum1) {
|
||||
@ -136,11 +137,11 @@ func (s *SqliteStore) reload() error {
|
||||
return fmt.Errorf("sha256 sum changed during reloading of '%s'", s.filePath)
|
||||
}
|
||||
|
||||
log.Infof("Reloading '%s'", s.filePath)
|
||||
log.WithContext(ctx).Infof("Reloading '%s'", s.filePath)
|
||||
_ = s.close()
|
||||
s.closed = true
|
||||
|
||||
newDb, err := connectDB(s.filePath)
|
||||
newDb, err := connectDB(ctx, s.filePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -148,9 +149,9 @@ func (s *SqliteStore) reload() error {
|
||||
s.closed = false
|
||||
s.db = newDb
|
||||
|
||||
log.Infof("Successfully reloaded '%s'", s.filePath)
|
||||
log.WithContext(ctx).Infof("Successfully reloaded '%s'", s.filePath)
|
||||
} else {
|
||||
log.Tracef("No changes in '%s', no need to reload", s.filePath)
|
||||
log.WithContext(ctx).Tracef("No changes in '%s', no need to reload", s.filePath)
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -168,10 +169,10 @@ func (s *SqliteStore) close() error {
|
||||
}
|
||||
|
||||
// connectDB connects to an SQLite database and prepares it by setting up an in-memory database.
|
||||
func connectDB(filePath string) (*gorm.DB, error) {
|
||||
func connectDB(ctx context.Context, filePath string) (*gorm.DB, error) {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
log.Debugf("took %v to setup geoname db", time.Since(start))
|
||||
log.WithContext(ctx).Debugf("took %v to setup geoname db", time.Since(start))
|
||||
}()
|
||||
|
||||
_, err := fileExists(filePath)
|
||||
|
@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/rs/xid"
|
||||
@ -21,11 +22,11 @@ func (e *GroupLinkError) Error() string {
|
||||
}
|
||||
|
||||
// GetGroup object of the peers
|
||||
func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -48,11 +49,11 @@ func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*n
|
||||
}
|
||||
|
||||
// GetAllGroups returns all groups in an account
|
||||
func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) ([]*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID string, userID string) ([]*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -75,11 +76,11 @@ func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) (
|
||||
}
|
||||
|
||||
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
|
||||
func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -108,11 +109,11 @@ func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*n
|
||||
}
|
||||
|
||||
// SaveGroup object of the peers
|
||||
func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *nbgroup.Group) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *nbgroup.Group) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -150,11 +151,11 @@ func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *n
|
||||
account.Groups[newGroup.ID] = newGroup
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(account); err != nil {
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.updateAccountPeers(account)
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
// the following snippet tracks the activity and stores the group events in the event store.
|
||||
// It has to happen after all the operations have been successfully performed.
|
||||
@ -165,16 +166,16 @@ func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *n
|
||||
removedPeers = difference(oldGroup.Peers, newGroup.Peers)
|
||||
} else {
|
||||
addedPeers = append(addedPeers, newGroup.Peers...)
|
||||
am.StoreEvent(userID, newGroup.ID, accountID, activity.GroupCreated, newGroup.EventMeta())
|
||||
am.StoreEvent(ctx, userID, newGroup.ID, accountID, activity.GroupCreated, newGroup.EventMeta())
|
||||
}
|
||||
|
||||
for _, p := range addedPeers {
|
||||
peer := account.Peers[p]
|
||||
if peer == nil {
|
||||
log.Errorf("peer %s not found under account %s while saving group", p, accountID)
|
||||
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
|
||||
continue
|
||||
}
|
||||
am.StoreEvent(userID, peer.ID, accountID, activity.GroupAddedToPeer,
|
||||
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer,
|
||||
map[string]any{
|
||||
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peer.IP.String(),
|
||||
"peer_fqdn": peer.FQDN(am.GetDNSDomain()),
|
||||
@ -184,10 +185,10 @@ func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *n
|
||||
for _, p := range removedPeers {
|
||||
peer := account.Peers[p]
|
||||
if peer == nil {
|
||||
log.Errorf("peer %s not found under account %s while saving group", p, accountID)
|
||||
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
|
||||
continue
|
||||
}
|
||||
am.StoreEvent(userID, peer.ID, accountID, activity.GroupRemovedFromPeer,
|
||||
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer,
|
||||
map[string]any{
|
||||
"group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peer.IP.String(),
|
||||
"peer_fqdn": peer.FQDN(am.GetDNSDomain()),
|
||||
@ -213,11 +214,11 @@ func difference(a, b []string) []string {
|
||||
}
|
||||
|
||||
// DeleteGroup object of the peers
|
||||
func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountId)
|
||||
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountId)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountId)
|
||||
account, err := am.Store.GetAccount(ctx, accountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -315,23 +316,23 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string)
|
||||
delete(account.Groups, groupID)
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(account); err != nil {
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(userId, groupID, accountId, activity.GroupDeleted, g.EventMeta())
|
||||
am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, g.EventMeta())
|
||||
|
||||
am.updateAccountPeers(account)
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListGroups objects of the peers
|
||||
func (am *DefaultAccountManager) ListGroups(accountID string) ([]*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -345,11 +346,11 @@ func (am *DefaultAccountManager) ListGroups(accountID string) ([]*nbgroup.Group,
|
||||
}
|
||||
|
||||
// GroupAddPeer appends peer to the group
|
||||
func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerID string) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -371,21 +372,21 @@ func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerID string)
|
||||
}
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(account); err != nil {
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.updateAccountPeers(account)
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GroupDeletePeer removes peer from the group
|
||||
func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerID string) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -399,13 +400,13 @@ func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerID stri
|
||||
for i, itemID := range group.Peers {
|
||||
if itemID == peerID {
|
||||
group.Peers = append(group.Peers[:i], group.Peers[i+1:]...)
|
||||
if err := am.Store.SaveAccount(account); err != nil {
|
||||
if err := am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
am.updateAccountPeers(account)
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
@ -26,7 +27,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
|
||||
}
|
||||
for _, group := range account.Groups {
|
||||
group.Issued = nbgroup.GroupIssuedIntegration
|
||||
err = am.SaveGroup(account.Id, groupAdminUserID, group)
|
||||
err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group)
|
||||
if err != nil {
|
||||
t.Errorf("should allow to create %s groups", nbgroup.GroupIssuedIntegration)
|
||||
}
|
||||
@ -34,7 +35,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
|
||||
|
||||
for _, group := range account.Groups {
|
||||
group.Issued = nbgroup.GroupIssuedJWT
|
||||
err = am.SaveGroup(account.Id, groupAdminUserID, group)
|
||||
err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group)
|
||||
if err != nil {
|
||||
t.Errorf("should allow to create %s groups", nbgroup.GroupIssuedJWT)
|
||||
}
|
||||
@ -42,7 +43,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
|
||||
for _, group := range account.Groups {
|
||||
group.Issued = nbgroup.GroupIssuedAPI
|
||||
group.ID = ""
|
||||
err = am.SaveGroup(account.Id, groupAdminUserID, group)
|
||||
err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group)
|
||||
if err == nil {
|
||||
t.Errorf("should not create api group with the same name, %s", group.Name)
|
||||
}
|
||||
@ -104,7 +105,7 @@ func TestDefaultAccountManager_DeleteGroup(t *testing.T) {
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
err = am.DeleteGroup(account.Id, groupAdminUserID, testCase.groupID)
|
||||
err = am.DeleteGroup(context.Background(), account.Id, groupAdminUserID, testCase.groupID)
|
||||
if err == nil {
|
||||
t.Errorf("delete %s group successfully", testCase.groupID)
|
||||
return
|
||||
@ -225,7 +226,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) {
|
||||
Id: "example user",
|
||||
AutoGroups: []string{groupForUsers.ID},
|
||||
}
|
||||
account := newAccountWithId(accountID, groupAdminUserID, domain)
|
||||
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain)
|
||||
account.Routes[routeResource.ID] = routeResource
|
||||
account.Routes[routePeerGroupResource.ID] = routePeerGroupResource
|
||||
account.NameServerGroups[nameServerGroup.ID] = nameServerGroup
|
||||
@ -233,18 +234,18 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) {
|
||||
account.SetupKeys[setupKey.Id] = setupKey
|
||||
account.Users[user.Id] = user
|
||||
|
||||
err := am.Store.SaveAccount(account)
|
||||
err := am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_ = am.SaveGroup(accountID, groupAdminUserID, groupForRoute)
|
||||
_ = am.SaveGroup(accountID, groupAdminUserID, groupForRoute2)
|
||||
_ = am.SaveGroup(accountID, groupAdminUserID, groupForNameServerGroups)
|
||||
_ = am.SaveGroup(accountID, groupAdminUserID, groupForPolicies)
|
||||
_ = am.SaveGroup(accountID, groupAdminUserID, groupForSetupKeys)
|
||||
_ = am.SaveGroup(accountID, groupAdminUserID, groupForUsers)
|
||||
_ = am.SaveGroup(accountID, groupAdminUserID, groupForIntegration)
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute)
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute2)
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForNameServerGroups)
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForPolicies)
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForSetupKeys)
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForUsers)
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration)
|
||||
|
||||
return am.Store.GetAccount(account.Id)
|
||||
return am.Store.GetAccount(context.Background(), account.Id)
|
||||
}
|
||||
|
@ -11,12 +11,14 @@ import (
|
||||
pb "github.com/golang/protobuf/proto" // nolint
|
||||
"github.com/golang/protobuf/ptypes/timestamp"
|
||||
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
@ -40,7 +42,7 @@ type GRPCServer struct {
|
||||
}
|
||||
|
||||
// NewServer creates a new Management server
|
||||
func NewServer(config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager, turnCredentialsManager TURNCredentialsManager, appMetrics telemetry.AppMetrics, ephemeralManager *EphemeralManager) (*GRPCServer, error) {
|
||||
func NewServer(ctx context.Context, config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager, turnCredentialsManager TURNCredentialsManager, appMetrics telemetry.AppMetrics, ephemeralManager *EphemeralManager) (*GRPCServer, error) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -50,6 +52,7 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
|
||||
|
||||
if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) {
|
||||
jwtValidator, err = jwtclaims.NewJWTValidator(
|
||||
ctx,
|
||||
config.HttpConfig.AuthIssuer,
|
||||
config.GetAuthAudiences(),
|
||||
config.HttpConfig.AuthKeysLocation,
|
||||
@ -59,7 +62,7 @@ func NewServer(config *Config, accountManager AccountManager, peersUpdateManager
|
||||
return nil, status.Errorf(codes.Internal, "unable to create new jwt middleware, err: %v", err)
|
||||
}
|
||||
} else {
|
||||
log.Debug("unable to use http config to create new jwt middleware")
|
||||
log.WithContext(ctx).Debug("unable to use http config to create new jwt middleware")
|
||||
}
|
||||
|
||||
if appMetrics != nil {
|
||||
@ -126,47 +129,61 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountSyncRequest()
|
||||
}
|
||||
realIP := getRealIP(srv.Context())
|
||||
log.Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, realIP.String())
|
||||
|
||||
ctx := srv.Context()
|
||||
|
||||
realIP := getRealIP(ctx)
|
||||
|
||||
syncReq := &proto.SyncRequest{}
|
||||
peerKey, err := s.parseRequest(req, syncReq)
|
||||
peerKey, err := s.parseRequest(ctx, req, syncReq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
|
||||
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
|
||||
if err != nil {
|
||||
// this case should not happen and already indicates an issue but we don't want the system to fail due to being unable to log in detail
|
||||
accountID = "UNKNOWN"
|
||||
}
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
|
||||
|
||||
log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, realIP.String())
|
||||
|
||||
if syncReq.GetMeta() == nil {
|
||||
log.Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
|
||||
log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
|
||||
}
|
||||
|
||||
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(peerKey.String(), extractPeerMeta(syncReq.GetMeta()), realIP)
|
||||
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP)
|
||||
if err != nil {
|
||||
return mapError(err)
|
||||
return mapError(ctx, err)
|
||||
}
|
||||
|
||||
err = s.sendInitialSync(peerKey, peer, netMap, postureChecks, srv)
|
||||
err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv)
|
||||
if err != nil {
|
||||
log.Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
|
||||
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
|
||||
return err
|
||||
}
|
||||
|
||||
updates := s.peersUpdateManager.CreateChannel(peer.ID)
|
||||
updates := s.peersUpdateManager.CreateChannel(ctx, peer.ID)
|
||||
|
||||
s.ephemeralManager.OnPeerConnected(peer)
|
||||
s.ephemeralManager.OnPeerConnected(ctx, peer)
|
||||
|
||||
if s.config.TURNConfig.TimeBasedCredentials {
|
||||
s.turnCredentialsManager.SetupRefresh(peer.ID)
|
||||
s.turnCredentialsManager.SetupRefresh(ctx, peer.ID)
|
||||
}
|
||||
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart))
|
||||
}
|
||||
|
||||
return s.handleUpdates(peerKey, peer, updates, srv)
|
||||
return s.handleUpdates(ctx, peerKey, peer, updates, srv)
|
||||
}
|
||||
|
||||
// handleUpdates sends updates to the connected peer until the updates channel is closed.
|
||||
func (s *GRPCServer) handleUpdates(peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
func (s *GRPCServer) handleUpdates(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
for {
|
||||
select {
|
||||
// condition when there are some updates
|
||||
@ -176,21 +193,21 @@ func (s *GRPCServer) handleUpdates(peerKey wgtypes.Key, peer *nbpeer.Peer, updat
|
||||
}
|
||||
|
||||
if !open {
|
||||
log.Debugf("updates channel for peer %s was closed", peerKey.String())
|
||||
s.cancelPeerRoutines(peer)
|
||||
log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String())
|
||||
s.cancelPeerRoutines(ctx, peer)
|
||||
return nil
|
||||
}
|
||||
log.Debugf("received an update for peer %s", peerKey.String())
|
||||
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
|
||||
|
||||
if err := s.sendUpdate(peerKey, peer, update, srv); err != nil {
|
||||
if err := s.sendUpdate(ctx, peerKey, peer, update, srv); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// condition when client <-> server connection has been terminated
|
||||
case <-srv.Context().Done():
|
||||
// happens when connection drops, e.g. client disconnects
|
||||
log.Debugf("stream of peer %s has been closed", peerKey.String())
|
||||
s.cancelPeerRoutines(peer)
|
||||
log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String())
|
||||
s.cancelPeerRoutines(ctx, peer)
|
||||
return srv.Context().Err()
|
||||
}
|
||||
}
|
||||
@ -198,10 +215,10 @@ func (s *GRPCServer) handleUpdates(peerKey wgtypes.Key, peer *nbpeer.Peer, updat
|
||||
|
||||
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
|
||||
// then sends the encrypted message to the connected peer via the sync server.
|
||||
func (s *GRPCServer) sendUpdate(peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
func (s *GRPCServer) sendUpdate(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update)
|
||||
if err != nil {
|
||||
s.cancelPeerRoutines(peer)
|
||||
s.cancelPeerRoutines(ctx, peer)
|
||||
return status.Errorf(codes.Internal, "failed processing update message")
|
||||
}
|
||||
err = srv.SendMsg(&proto.EncryptedMessage{
|
||||
@ -209,37 +226,37 @@ func (s *GRPCServer) sendUpdate(peerKey wgtypes.Key, peer *nbpeer.Peer, update *
|
||||
Body: encryptedResp,
|
||||
})
|
||||
if err != nil {
|
||||
s.cancelPeerRoutines(peer)
|
||||
s.cancelPeerRoutines(ctx, peer)
|
||||
return status.Errorf(codes.Internal, "failed sending update message")
|
||||
}
|
||||
log.Debugf("sent an update to peer %s", peerKey.String())
|
||||
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *GRPCServer) cancelPeerRoutines(peer *nbpeer.Peer) {
|
||||
s.peersUpdateManager.CloseChannel(peer.ID)
|
||||
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, peer *nbpeer.Peer) {
|
||||
s.peersUpdateManager.CloseChannel(ctx, peer.ID)
|
||||
s.turnCredentialsManager.CancelRefresh(peer.ID)
|
||||
_ = s.accountManager.CancelPeerRoutines(peer)
|
||||
s.ephemeralManager.OnPeerDisconnected(peer)
|
||||
_ = s.accountManager.CancelPeerRoutines(ctx, peer)
|
||||
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
|
||||
}
|
||||
|
||||
func (s *GRPCServer) validateToken(jwtToken string) (string, error) {
|
||||
func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) {
|
||||
if s.jwtValidator == nil {
|
||||
return "", status.Error(codes.Internal, "no jwt validator set")
|
||||
}
|
||||
|
||||
token, err := s.jwtValidator.ValidateAndParse(jwtToken)
|
||||
token, err := s.jwtValidator.ValidateAndParse(ctx, jwtToken)
|
||||
if err != nil {
|
||||
return "", status.Errorf(codes.InvalidArgument, "invalid jwt token, err: %v", err)
|
||||
}
|
||||
claims := s.jwtClaimsExtractor.FromToken(token)
|
||||
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
|
||||
_, _, err = s.accountManager.GetAccountFromToken(claims)
|
||||
_, _, err = s.accountManager.GetAccountFromToken(ctx, claims)
|
||||
if err != nil {
|
||||
return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err)
|
||||
}
|
||||
|
||||
if err := s.accountManager.CheckUserAccessByJWTGroups(claims); err != nil {
|
||||
if err := s.accountManager.CheckUserAccessByJWTGroups(ctx, claims); err != nil {
|
||||
return "", status.Errorf(codes.PermissionDenied, err.Error())
|
||||
}
|
||||
|
||||
@ -247,7 +264,7 @@ func (s *GRPCServer) validateToken(jwtToken string) (string, error) {
|
||||
}
|
||||
|
||||
// maps internal internalStatus.Error to gRPC status.Error
|
||||
func mapError(err error) error {
|
||||
func mapError(ctx context.Context, err error) error {
|
||||
if e, ok := internalStatus.FromError(err); ok {
|
||||
switch e.Type() {
|
||||
case internalStatus.PermissionDenied:
|
||||
@ -263,11 +280,11 @@ func mapError(err error) error {
|
||||
default:
|
||||
}
|
||||
}
|
||||
log.Errorf("got an unhandled error: %s", err)
|
||||
log.WithContext(ctx).Errorf("got an unhandled error: %s", err)
|
||||
return status.Errorf(codes.Internal, "failed handling request")
|
||||
}
|
||||
|
||||
func extractPeerMeta(meta *proto.PeerSystemMeta) nbpeer.PeerSystemMeta {
|
||||
func extractPeerMeta(ctx context.Context, meta *proto.PeerSystemMeta) nbpeer.PeerSystemMeta {
|
||||
if meta == nil {
|
||||
return nbpeer.PeerSystemMeta{}
|
||||
}
|
||||
@ -281,7 +298,7 @@ func extractPeerMeta(meta *proto.PeerSystemMeta) nbpeer.PeerSystemMeta {
|
||||
for _, addr := range meta.GetNetworkAddresses() {
|
||||
netAddr, err := netip.ParsePrefix(addr.GetNetIP())
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse netip address, %s: %v", addr.GetNetIP(), err)
|
||||
log.WithContext(ctx).Warnf("failed to parse netip address, %s: %v", addr.GetNetIP(), err)
|
||||
continue
|
||||
}
|
||||
networkAddresses = append(networkAddresses, nbpeer.NetworkAddress{
|
||||
@ -321,10 +338,10 @@ func extractPeerMeta(meta *proto.PeerSystemMeta) nbpeer.PeerSystemMeta {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GRPCServer) parseRequest(req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) {
|
||||
func (s *GRPCServer) parseRequest(ctx context.Context, req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) {
|
||||
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
||||
if err != nil {
|
||||
log.Warnf("error while parsing peer's WireGuard public key %s.", req.WgPubKey)
|
||||
log.WithContext(ctx).Warnf("error while parsing peer's WireGuard public key %s.", req.WgPubKey)
|
||||
return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", req.WgPubKey)
|
||||
}
|
||||
|
||||
@ -351,22 +368,32 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
|
||||
s.appMetrics.GRPCMetrics().CountLoginRequest()
|
||||
}
|
||||
realIP := getRealIP(ctx)
|
||||
log.Debugf("Login request from peer [%s] [%s]", req.WgPubKey, realIP.String())
|
||||
log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, realIP.String())
|
||||
|
||||
loginReq := &proto.LoginRequest{}
|
||||
peerKey, err := s.parseRequest(req, loginReq)
|
||||
peerKey, err := s.parseRequest(ctx, req, loginReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
|
||||
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
|
||||
if err != nil {
|
||||
// this case should not happen and already indicates an issue but we don't want the system to fail due to being unable to log in detail
|
||||
accountID = "UNKNOWN"
|
||||
}
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
|
||||
|
||||
if loginReq.GetMeta() == nil {
|
||||
msg := status.Errorf(codes.FailedPrecondition,
|
||||
"peer system meta has to be provided to log in. Peer %s, remote addr %s", peerKey.String(), realIP)
|
||||
log.Warn(msg)
|
||||
log.WithContext(ctx).Warn(msg)
|
||||
return nil, msg
|
||||
}
|
||||
|
||||
userID, err := s.processJwtToken(loginReq, peerKey)
|
||||
userID, err := s.processJwtToken(ctx, loginReq, peerKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -376,33 +403,33 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
|
||||
sshKey = loginReq.GetPeerKeys().GetSshPubKey()
|
||||
}
|
||||
|
||||
peer, netMap, postureChecks, err := s.accountManager.LoginPeer(PeerLogin{
|
||||
peer, netMap, postureChecks, err := s.accountManager.LoginPeer(ctx, PeerLogin{
|
||||
WireGuardPubKey: peerKey.String(),
|
||||
SSHKey: string(sshKey),
|
||||
Meta: extractPeerMeta(loginReq.GetMeta()),
|
||||
Meta: extractPeerMeta(ctx, loginReq.GetMeta()),
|
||||
UserID: userID,
|
||||
SetupKey: loginReq.GetSetupKey(),
|
||||
ConnectionIP: realIP,
|
||||
})
|
||||
if err != nil {
|
||||
log.Warnf("failed logging in peer %s: %s", peerKey, err)
|
||||
return nil, mapError(err)
|
||||
log.WithContext(ctx).Warnf("failed logging in peer %s: %s", peerKey, err)
|
||||
return nil, mapError(ctx, err)
|
||||
}
|
||||
|
||||
// if the login request contains setup key then it is a registration request
|
||||
if loginReq.GetSetupKey() != "" {
|
||||
s.ephemeralManager.OnPeerDisconnected(peer)
|
||||
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
|
||||
}
|
||||
|
||||
// if peer has reached this point then it has logged in
|
||||
loginResp := &proto.LoginResponse{
|
||||
WiretrusteeConfig: toWiretrusteeConfig(s.config, nil),
|
||||
PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain()),
|
||||
Checks: toProtocolChecks(postureChecks),
|
||||
Checks: toProtocolChecks(ctx, postureChecks),
|
||||
}
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp)
|
||||
if err != nil {
|
||||
log.Warnf("failed encrypting peer %s message", peer.ID)
|
||||
log.WithContext(ctx).Warnf("failed encrypting peer %s message", peer.ID)
|
||||
return nil, status.Errorf(codes.Internal, "failed logging in peer")
|
||||
}
|
||||
|
||||
@ -417,16 +444,16 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
|
||||
//
|
||||
// The user ID can be empty if the token is not provided, which is acceptable if the peer is already
|
||||
// registered or if it uses a setup key to register.
|
||||
func (s *GRPCServer) processJwtToken(loginReq *proto.LoginRequest, peerKey wgtypes.Key) (string, error) {
|
||||
func (s *GRPCServer) processJwtToken(ctx context.Context, loginReq *proto.LoginRequest, peerKey wgtypes.Key) (string, error) {
|
||||
userID := ""
|
||||
if loginReq.GetJwtToken() != "" {
|
||||
var err error
|
||||
for i := 0; i < 3; i++ {
|
||||
userID, err = s.validateToken(loginReq.GetJwtToken())
|
||||
userID, err = s.validateToken(ctx, loginReq.GetJwtToken())
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
log.Warnf("failed validating JWT token sent from peer %s with error %v. "+
|
||||
log.WithContext(ctx).Warnf("failed validating JWT token sent from peer %s with error %v. "+
|
||||
"Trying again as it may be due to the IdP cache issue", peerKey.String(), err)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
@ -520,7 +547,7 @@ func toRemotePeerConfig(peers []*nbpeer.Peer, dnsName string) []*proto.RemotePee
|
||||
return remotePeers
|
||||
}
|
||||
|
||||
func toSyncResponse(config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string, checks []*posture.Checks) *proto.SyncResponse {
|
||||
func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string, checks []*posture.Checks) *proto.SyncResponse {
|
||||
wtConfig := toWiretrusteeConfig(config, turnCredentials)
|
||||
|
||||
pConfig := toPeerConfig(peer, networkMap.Network, dnsName)
|
||||
@ -551,7 +578,7 @@ func toSyncResponse(config *Config, peer *nbpeer.Peer, turnCredentials *TURNCred
|
||||
FirewallRules: firewallRules,
|
||||
FirewallRulesIsEmpty: len(firewallRules) == 0,
|
||||
},
|
||||
Checks: toProtocolChecks(checks),
|
||||
Checks: toProtocolChecks(ctx, checks),
|
||||
}
|
||||
}
|
||||
|
||||
@ -561,7 +588,7 @@ func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Em
|
||||
}
|
||||
|
||||
// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization
|
||||
func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error {
|
||||
func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error {
|
||||
// make secret time based TURN credentials optional
|
||||
var turnCredentials *TURNCredentials
|
||||
if s.config.TURNConfig.TimeBasedCredentials {
|
||||
@ -570,7 +597,7 @@ func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *nbpeer.Peer, net
|
||||
} else {
|
||||
turnCredentials = nil
|
||||
}
|
||||
plainResp := toSyncResponse(s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain(), postureChecks)
|
||||
plainResp := toSyncResponse(ctx, s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain(), postureChecks)
|
||||
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
|
||||
if err != nil {
|
||||
@ -583,7 +610,7 @@ func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *nbpeer.Peer, net
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("failed sending SyncResponse %v", err)
|
||||
log.WithContext(ctx).Errorf("failed sending SyncResponse %v", err)
|
||||
return status.Errorf(codes.Internal, "error handling request")
|
||||
}
|
||||
|
||||
@ -597,14 +624,14 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.
|
||||
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
||||
if err != nil {
|
||||
errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetDeviceAuthorizationFlow request.", req.WgPubKey)
|
||||
log.Warn(errMSG)
|
||||
log.WithContext(ctx).Warn(errMSG)
|
||||
return nil, status.Error(codes.InvalidArgument, errMSG)
|
||||
}
|
||||
|
||||
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.DeviceAuthorizationFlowRequest{})
|
||||
if err != nil {
|
||||
errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey)
|
||||
log.Warn(errMSG)
|
||||
log.WithContext(ctx).Warn(errMSG)
|
||||
return nil, status.Error(codes.InvalidArgument, errMSG)
|
||||
}
|
||||
|
||||
@ -645,18 +672,18 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.
|
||||
// GetPKCEAuthorizationFlow returns a pkce authorization flow information
|
||||
// This is used for initiating an Oauth 2 pkce authorization grant flow
|
||||
// which will be used by our clients to Login
|
||||
func (s *GRPCServer) GetPKCEAuthorizationFlow(_ context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||
func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||
peerKey, err := wgtypes.ParseKey(req.GetWgPubKey())
|
||||
if err != nil {
|
||||
errMSG := fmt.Sprintf("error while parsing peer's Wireguard public key %s on GetPKCEAuthorizationFlow request.", req.WgPubKey)
|
||||
log.Warn(errMSG)
|
||||
log.WithContext(ctx).Warn(errMSG)
|
||||
return nil, status.Error(codes.InvalidArgument, errMSG)
|
||||
}
|
||||
|
||||
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.PKCEAuthorizationFlowRequest{})
|
||||
if err != nil {
|
||||
errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey)
|
||||
log.Warn(errMSG)
|
||||
log.WithContext(ctx).Warn(errMSG)
|
||||
return nil, status.Error(codes.InvalidArgument, errMSG)
|
||||
}
|
||||
|
||||
@ -692,10 +719,10 @@ func (s *GRPCServer) GetPKCEAuthorizationFlow(_ context.Context, req *proto.Encr
|
||||
// peer's under the same account of any updates.
|
||||
func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
|
||||
realIP := getRealIP(ctx)
|
||||
log.Debugf("Sync meta request from peer [%s] [%s]", req.WgPubKey, realIP.String())
|
||||
log.WithContext(ctx).Debugf("Sync meta request from peer [%s] [%s]", req.WgPubKey, realIP.String())
|
||||
|
||||
syncMetaReq := &proto.SyncMetaRequest{}
|
||||
peerKey, err := s.parseRequest(req, syncMetaReq)
|
||||
peerKey, err := s.parseRequest(ctx, req, syncMetaReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -703,20 +730,20 @@ func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage)
|
||||
if syncMetaReq.GetMeta() == nil {
|
||||
msg := status.Errorf(codes.FailedPrecondition,
|
||||
"peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
|
||||
log.Warn(msg)
|
||||
log.WithContext(ctx).Warn(msg)
|
||||
return nil, msg
|
||||
}
|
||||
|
||||
err = s.accountManager.SyncPeerMeta(peerKey.String(), extractPeerMeta(syncMetaReq.GetMeta()))
|
||||
err = s.accountManager.SyncPeerMeta(ctx, peerKey.String(), extractPeerMeta(ctx, syncMetaReq.GetMeta()))
|
||||
if err != nil {
|
||||
return nil, mapError(err)
|
||||
return nil, mapError(ctx, err)
|
||||
}
|
||||
|
||||
return &proto.Empty{}, nil
|
||||
}
|
||||
|
||||
// toProtocolChecks converts posture checks to protocol checks.
|
||||
func toProtocolChecks(postureChecks []*posture.Checks) []*proto.Checks {
|
||||
func toProtocolChecks(ctx context.Context, postureChecks []*posture.Checks) []*proto.Checks {
|
||||
protoChecks := make([]*proto.Checks, 0, len(postureChecks))
|
||||
for _, postureCheck := range postureChecks {
|
||||
protoChecks = append(protoChecks, toProtocolCheck(postureCheck))
|
||||
|
@ -35,34 +35,34 @@ func NewAccountsHandler(accountManager server.AccountManager, authCfg AuthCfg) *
|
||||
// GetAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account.
|
||||
func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
if !(user.HasAdminPower() || user.IsServiceUser) {
|
||||
util.WriteError(status.Errorf(status.PermissionDenied, "the user has no permission to access account data"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "the user has no permission to access account data"), w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toAccountResponse(account)
|
||||
util.WriteJSONObject(w, []*api.Account{resp})
|
||||
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
|
||||
}
|
||||
|
||||
// UpdateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings)
|
||||
func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
_, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
_, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
accountID := vars["accountId"]
|
||||
if len(accountID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid accountID ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid accountID ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -96,15 +96,15 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
|
||||
settings.JWTAllowGroups = *req.Settings.JwtAllowGroups
|
||||
}
|
||||
|
||||
updatedAccount, err := h.accountManager.UpdateAccountSettings(accountID, user.Id, settings)
|
||||
updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, user.Id, settings)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toAccountResponse(updatedAccount)
|
||||
|
||||
util.WriteJSONObject(w, &resp)
|
||||
util.WriteJSONObject(r.Context(), w, &resp)
|
||||
}
|
||||
|
||||
// DeleteAccount is a HTTP DELETE handler to delete an account
|
||||
@ -118,17 +118,17 @@ func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request)
|
||||
vars := mux.Vars(r)
|
||||
targetAccountID := vars["accountId"]
|
||||
if len(targetAccountID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid account ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid account ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
err := h.accountManager.DeleteAccount(targetAccountID, claims.UserId)
|
||||
err := h.accountManager.DeleteAccount(r.Context(), targetAccountID, claims.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, emptyObject{})
|
||||
util.WriteJSONObject(r.Context(), w, emptyObject{})
|
||||
}
|
||||
|
||||
func toAccountResponse(account *server.Account) *api.Account {
|
||||
|
@ -2,6 +2,7 @@ package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
@ -22,10 +23,10 @@ import (
|
||||
func initAccountsTestData(account *server.Account, admin *server.User) *AccountsHandler {
|
||||
return &AccountsHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
GetAccountFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return account, admin, nil
|
||||
},
|
||||
UpdateAccountSettingsFunc: func(accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
|
||||
UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
|
||||
halfYearLimit := 180 * 24 * time.Hour
|
||||
if newSettings.PeerLoginExpiration > halfYearLimit {
|
||||
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
|
||||
|
@ -32,16 +32,16 @@ func NewDNSSettingsHandler(accountManager server.AccountManager, authCfg AuthCfg
|
||||
// GetDNSSettings returns the DNS settings for the account
|
||||
func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
log.WithContext(r.Context()).Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
dnsSettings, err := h.accountManager.GetDNSSettings(account.Id, user.Id)
|
||||
dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), account.Id, user.Id)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -49,15 +49,15 @@ func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Reque
|
||||
DisabledManagementGroups: dnsSettings.DisabledManagementGroups,
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, apiDNSSettings)
|
||||
util.WriteJSONObject(r.Context(), w, apiDNSSettings)
|
||||
}
|
||||
|
||||
// UpdateDNSSettings handles update to DNS settings of an account
|
||||
func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -72,9 +72,9 @@ func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Re
|
||||
DisabledManagementGroups: req.DisabledManagementGroups,
|
||||
}
|
||||
|
||||
err = h.accountManager.SaveDNSSettings(account.Id, user.Id, updateDNSSettings)
|
||||
err = h.accountManager.SaveDNSSettings(r.Context(), account.Id, user.Id, updateDNSSettings)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -82,5 +82,5 @@ func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Re
|
||||
DisabledManagementGroups: updateDNSSettings.DisabledManagementGroups,
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, &resp)
|
||||
util.WriteJSONObject(r.Context(), w, &resp)
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
@ -42,16 +43,16 @@ var testingDNSSettingsAccount = &server.Account{
|
||||
func initDNSSettingsTestData() *DNSSettingsHandler {
|
||||
return &DNSSettingsHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetDNSSettingsFunc: func(accountID string, userID string) (*server.DNSSettings, error) {
|
||||
GetDNSSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.DNSSettings, error) {
|
||||
return &testingDNSSettingsAccount.DNSSettings, nil
|
||||
},
|
||||
SaveDNSSettingsFunc: func(accountID string, userID string, dnsSettingsToSave *server.DNSSettings) error {
|
||||
SaveDNSSettingsFunc: func(ctx context.Context, accountID string, userID string, dnsSettingsToSave *server.DNSSettings) error {
|
||||
if dnsSettingsToSave != nil {
|
||||
return nil
|
||||
}
|
||||
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
|
||||
},
|
||||
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
GetAccountFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return testingDNSSettingsAccount, testingDNSSettingsAccount.Users[testDNSSettingsUserID], nil
|
||||
},
|
||||
},
|
||||
|
@ -1,6 +1,7 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
@ -33,16 +34,16 @@ func NewEventsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ev
|
||||
// GetAllEvents list of the given account
|
||||
func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
log.WithContext(r.Context()).Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
accountEvents, err := h.accountManager.GetEvents(account.Id, user.Id)
|
||||
accountEvents, err := h.accountManager.GetEvents(r.Context(), account.Id, user.Id)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
events := make([]*api.Event, len(accountEvents))
|
||||
@ -50,20 +51,20 @@ func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
|
||||
events[i] = toEventResponse(e)
|
||||
}
|
||||
|
||||
err = h.fillEventsWithUserInfo(events, account.Id, user.Id)
|
||||
err = h.fillEventsWithUserInfo(r.Context(), events, account.Id, user.Id)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, events)
|
||||
util.WriteJSONObject(r.Context(), w, events)
|
||||
}
|
||||
|
||||
func (h *EventsHandler) fillEventsWithUserInfo(events []*api.Event, accountId, userId string) error {
|
||||
func (h *EventsHandler) fillEventsWithUserInfo(ctx context.Context, events []*api.Event, accountId, userId string) error {
|
||||
// build email, name maps based on users
|
||||
userInfos, err := h.accountManager.GetUsersFromAccount(accountId, userId)
|
||||
userInfos, err := h.accountManager.GetUsersFromAccount(ctx, accountId, userId)
|
||||
if err != nil {
|
||||
log.Errorf("failed to get users from account: %s", err)
|
||||
log.WithContext(ctx).Errorf("failed to get users from account: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
@ -80,7 +81,7 @@ func (h *EventsHandler) fillEventsWithUserInfo(events []*api.Event, accountId, u
|
||||
if event.InitiatorEmail == "" {
|
||||
event.InitiatorEmail, ok = emails[event.InitiatorId]
|
||||
if !ok {
|
||||
log.Warnf("failed to resolve email for initiator: %s", event.InitiatorId)
|
||||
log.WithContext(ctx).Warnf("failed to resolve email for initiator: %s", event.InitiatorId)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
@ -22,13 +23,13 @@ import (
|
||||
func initEventsTestData(account string, user *server.User, events ...*activity.Event) *EventsHandler {
|
||||
return &EventsHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetEventsFunc: func(accountID, userID string) ([]*activity.Event, error) {
|
||||
GetEventsFunc: func(_ context.Context, accountID, userID string) ([]*activity.Event, error) {
|
||||
if accountID == account {
|
||||
return events, nil
|
||||
}
|
||||
return []*activity.Event{}, nil
|
||||
},
|
||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return &server.Account{
|
||||
Id: claims.AccountId,
|
||||
Domain: "hotmail.com",
|
||||
@ -37,7 +38,7 @@ func initEventsTestData(account string, user *server.User, events ...*activity.E
|
||||
},
|
||||
}, user, nil
|
||||
},
|
||||
GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) {
|
||||
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) {
|
||||
return make([]*server.UserInfo, 0), nil
|
||||
},
|
||||
},
|
||||
|
@ -1,6 +1,7 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
@ -35,13 +36,13 @@ func initGeolocationTestData(t *testing.T) *GeolocationsHandler {
|
||||
err = util.CopyFileContents(geonamesDBPath, path.Join(tempDir, geolocation.GeoSqliteDBFile))
|
||||
assert.NoError(t, err)
|
||||
|
||||
geo, err := geolocation.NewGeolocation(tempDir)
|
||||
geo, err := geolocation.NewGeolocation(context.Background(), tempDir)
|
||||
assert.NoError(t, err)
|
||||
t.Cleanup(func() { _ = geo.Stop() })
|
||||
|
||||
return &GeolocationsHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
user := server.NewAdminUser("test_user")
|
||||
return &server.Account{
|
||||
Id: claims.AccountId,
|
||||
|
@ -40,19 +40,19 @@ func NewGeolocationsHandlerHandler(accountManager server.AccountManager, geoloca
|
||||
// GetAllCountries retrieves a list of all countries
|
||||
func (l *GeolocationsHandler) GetAllCountries(w http.ResponseWriter, r *http.Request) {
|
||||
if err := l.authenticateUser(r); err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
if l.geolocationManager == nil {
|
||||
// TODO: update error message to include geo db self hosted doc link when ready
|
||||
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
|
||||
return
|
||||
}
|
||||
|
||||
allCountries, err := l.geolocationManager.GetAllCountries()
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -60,32 +60,32 @@ func (l *GeolocationsHandler) GetAllCountries(w http.ResponseWriter, r *http.Req
|
||||
for _, country := range allCountries {
|
||||
countries = append(countries, toCountryResponse(country))
|
||||
}
|
||||
util.WriteJSONObject(w, countries)
|
||||
util.WriteJSONObject(r.Context(), w, countries)
|
||||
}
|
||||
|
||||
// GetCitiesByCountry retrieves a list of cities based on the given country code
|
||||
func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http.Request) {
|
||||
if err := l.authenticateUser(r); err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
countryCode := vars["country"]
|
||||
if !countryCodeRegex.MatchString(countryCode) {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid country code"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid country code"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if l.geolocationManager == nil {
|
||||
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized. "+
|
||||
util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized. "+
|
||||
"Check the self-hosted Geo database documentation at https://docs.netbird.io/selfhosted/geo-support"), w)
|
||||
return
|
||||
}
|
||||
|
||||
allCities, err := l.geolocationManager.GetCitiesByCountry(countryCode)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -93,12 +93,12 @@ func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http.
|
||||
for _, city := range allCities {
|
||||
cities = append(cities, toCityResponse(city))
|
||||
}
|
||||
util.WriteJSONObject(w, cities)
|
||||
util.WriteJSONObject(r.Context(), w, cities)
|
||||
}
|
||||
|
||||
func (l *GeolocationsHandler) authenticateUser(r *http.Request) error {
|
||||
claims := l.claimsExtractor.FromRequestContext(r)
|
||||
_, user, err := l.accountManager.GetAccountFromToken(claims)
|
||||
_, user, err := l.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -35,16 +35,16 @@ func NewGroupsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Gr
|
||||
// GetAllGroups list for the account
|
||||
func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
log.WithContext(r.Context()).Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
groups, err := h.accountManager.GetAllGroups(account.Id, user.Id)
|
||||
groups, err := h.accountManager.GetAllGroups(r.Context(), account.Id, user.Id)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -53,42 +53,42 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
|
||||
groupsResponse = append(groupsResponse, toGroupResponse(account, group))
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, groupsResponse)
|
||||
util.WriteJSONObject(r.Context(), w, groupsResponse)
|
||||
}
|
||||
|
||||
// UpdateGroup handles update to a group identified by a given ID
|
||||
func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
groupID, ok := vars["groupId"]
|
||||
if !ok {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "group ID field is missing"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "group ID field is missing"), w)
|
||||
return
|
||||
}
|
||||
if len(groupID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "group ID can't be empty"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "group ID can't be empty"), w)
|
||||
return
|
||||
}
|
||||
|
||||
eg, ok := account.Groups[groupID]
|
||||
if !ok {
|
||||
util.WriteError(status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w)
|
||||
return
|
||||
}
|
||||
|
||||
allGroup, err := account.GetGroupAll()
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
if allGroup.ID == groupID {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -100,7 +100,7 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -118,21 +118,21 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
IntegrationReference: eg.IntegrationReference,
|
||||
}
|
||||
|
||||
if err := h.accountManager.SaveGroup(account.Id, user.Id, &group); err != nil {
|
||||
log.Errorf("failed updating group %s under account %s %v", groupID, account.Id, err)
|
||||
util.WriteError(err, w)
|
||||
if err := h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group); err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, account.Id, err)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, toGroupResponse(account, &group))
|
||||
util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group))
|
||||
}
|
||||
|
||||
// CreateGroup handles group creation request
|
||||
func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -144,7 +144,7 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "group name shouldn't be empty"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -160,62 +160,62 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
Issued: nbgroup.GroupIssuedAPI,
|
||||
}
|
||||
|
||||
err = h.accountManager.SaveGroup(account.Id, user.Id, &group)
|
||||
err = h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, toGroupResponse(account, &group))
|
||||
util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group))
|
||||
}
|
||||
|
||||
// DeleteGroup handles group deletion request
|
||||
func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
aID := account.Id
|
||||
|
||||
groupID := mux.Vars(r)["groupId"]
|
||||
if len(groupID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
allGroup, err := account.GetGroupAll()
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
if allGroup.ID == groupID {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed"), w)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeleteGroup(aID, user.Id, groupID)
|
||||
err = h.accountManager.DeleteGroup(r.Context(), aID, user.Id, groupID)
|
||||
if err != nil {
|
||||
_, ok := err.(*server.GroupLinkError)
|
||||
if ok {
|
||||
util.WriteErrorResponse(err.Error(), http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, emptyObject{})
|
||||
util.WriteJSONObject(r.Context(), w, emptyObject{})
|
||||
}
|
||||
|
||||
// GetGroup returns a group
|
||||
func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -223,19 +223,19 @@ func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) {
|
||||
case http.MethodGet:
|
||||
groupID := mux.Vars(r)["groupId"]
|
||||
if len(groupID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid group ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
group, err := h.accountManager.GetGroup(account.Id, groupID, user.Id)
|
||||
group, err := h.accountManager.GetGroup(r.Context(), account.Id, groupID, user.Id)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, toGroupResponse(account, group))
|
||||
util.WriteJSONObject(r.Context(), w, toGroupResponse(account, group))
|
||||
default:
|
||||
util.WriteError(status.Errorf(status.NotFound, "HTTP method not found"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "HTTP method not found"), w)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@ -32,13 +33,13 @@ var TestPeers = map[string]*nbpeer.Peer{
|
||||
func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler {
|
||||
return &GroupsHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
SaveGroupFunc: func(accountID, userID string, group *nbgroup.Group) error {
|
||||
SaveGroupFunc: func(_ context.Context, accountID, userID string, group *nbgroup.Group) error {
|
||||
if !strings.HasPrefix(group.ID, "id-") {
|
||||
group.ID = "id-was-set"
|
||||
}
|
||||
return nil
|
||||
},
|
||||
GetGroupFunc: func(_, groupID, _ string) (*nbgroup.Group, error) {
|
||||
GetGroupFunc: func(_ context.Context, _, groupID, _ string) (*nbgroup.Group, error) {
|
||||
if groupID != "idofthegroup" {
|
||||
return nil, status.Errorf(status.NotFound, "not found")
|
||||
}
|
||||
@ -55,7 +56,7 @@ func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler {
|
||||
Issued: nbgroup.GroupIssuedAPI,
|
||||
}, nil
|
||||
},
|
||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return &server.Account{
|
||||
Id: claims.AccountId,
|
||||
Domain: "hotmail.com",
|
||||
@ -70,7 +71,7 @@ func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler {
|
||||
},
|
||||
}, user, nil
|
||||
},
|
||||
DeleteGroupFunc: func(accountID, userId, groupID string) error {
|
||||
DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error {
|
||||
if groupID == "linked-grp" {
|
||||
return &server.GroupLinkError{
|
||||
Resource: "something",
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
"github.com/rs/cors"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
s "github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||
@ -57,6 +58,11 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
|
||||
|
||||
corsMiddleware := cors.AllowAll()
|
||||
|
||||
claimsExtractor = jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
)
|
||||
|
||||
acMiddleware := middleware.NewAccessControl(
|
||||
authCfg.Audience,
|
||||
authCfg.UserIDClaim,
|
||||
|
@ -1,6 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"regexp"
|
||||
|
||||
@ -15,7 +16,7 @@ import (
|
||||
)
|
||||
|
||||
// GetUser function defines a function to fetch user from Account by jwtclaims.AuthorizationClaims
|
||||
type GetUser func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
||||
type GetUser func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
||||
|
||||
// AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
|
||||
type AccessControl struct {
|
||||
@ -46,15 +47,15 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
|
||||
|
||||
claims := a.claimsExtract.FromRequestContext(r)
|
||||
|
||||
user, err := a.getUser(claims)
|
||||
user, err := a.getUser(r.Context(), claims)
|
||||
if err != nil {
|
||||
log.Errorf("failed to get user from claims: %s", err)
|
||||
util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w)
|
||||
log.WithContext(r.Context()).Errorf("failed to get user from claims: %s", err)
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "invalid JWT"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if user.IsBlocked() {
|
||||
util.WriteError(status.Errorf(status.PermissionDenied, "the user has no access to the API or is blocked"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "the user has no access to the API or is blocked"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -63,12 +64,12 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
|
||||
case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut:
|
||||
|
||||
if tokenPathRegexp.MatchString(r.URL.Path) {
|
||||
log.Debugf("valid Path")
|
||||
log.WithContext(r.Context()).Debugf("valid Path")
|
||||
h.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteError(status.Errorf(status.PermissionDenied, "only users with admin power can perform this operation"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "only users with admin power can perform this operation"), w)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -12,6 +12,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
@ -19,16 +20,16 @@ import (
|
||||
)
|
||||
|
||||
// GetAccountFromPATFunc function
|
||||
type GetAccountFromPATFunc func(token string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
|
||||
type GetAccountFromPATFunc func(ctx context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
|
||||
|
||||
// ValidateAndParseTokenFunc function
|
||||
type ValidateAndParseTokenFunc func(token string) (*jwt.Token, error)
|
||||
type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error)
|
||||
|
||||
// MarkPATUsedFunc function
|
||||
type MarkPATUsedFunc func(token string) error
|
||||
type MarkPATUsedFunc func(ctx context.Context, token string) error
|
||||
|
||||
// CheckUserAccessByJWTGroupsFunc function
|
||||
type CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error
|
||||
type CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
|
||||
|
||||
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
|
||||
type AuthMiddleware struct {
|
||||
@ -85,23 +86,27 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
case "bearer":
|
||||
err := m.checkJWTFromRequest(w, r, auth)
|
||||
if err != nil {
|
||||
log.Errorf("Error when validating JWT claims: %s", err.Error())
|
||||
util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||
log.WithContext(r.Context()).Errorf("Error when validating JWT claims: %s", err.Error())
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||
return
|
||||
}
|
||||
h.ServeHTTP(w, r)
|
||||
case "token":
|
||||
err := m.checkPATFromRequest(w, r, auth)
|
||||
if err != nil {
|
||||
log.Debugf("Error when validating PAT claims: %s", err.Error())
|
||||
util.WriteError(status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||
log.WithContext(r.Context()).Debugf("Error when validating PAT claims: %s", err.Error())
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||
return
|
||||
}
|
||||
h.ServeHTTP(w, r)
|
||||
default:
|
||||
util.WriteError(status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
|
||||
return
|
||||
}
|
||||
claims := m.claimsExtractor.FromRequestContext(r)
|
||||
//nolint
|
||||
ctx := context.WithValue(r.Context(), nbContext.UserIDKey, claims.UserId)
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, claims.AccountId)
|
||||
h.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
@ -114,7 +119,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
|
||||
return fmt.Errorf("Error extracting token: %w", err)
|
||||
}
|
||||
|
||||
validatedToken, err := m.validateAndParseToken(token)
|
||||
validatedToken, err := m.validateAndParseToken(r.Context(), token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -123,7 +128,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := m.verifyUserAccess(validatedToken); err != nil {
|
||||
if err := m.verifyUserAccess(r.Context(), validatedToken); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -138,9 +143,9 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
|
||||
// verifyUserAccess checks if a user, based on a validated JWT token,
|
||||
// is allowed access, particularly in cases where the admin enabled JWT
|
||||
// group propagation and designated certain groups with access permissions.
|
||||
func (m *AuthMiddleware) verifyUserAccess(validatedToken *jwt.Token) error {
|
||||
func (m *AuthMiddleware) verifyUserAccess(ctx context.Context, validatedToken *jwt.Token) error {
|
||||
authClaims := m.claimsExtractor.FromToken(validatedToken)
|
||||
return m.checkUserAccessByJWTGroups(authClaims)
|
||||
return m.checkUserAccessByJWTGroups(ctx, authClaims)
|
||||
}
|
||||
|
||||
// CheckPATFromRequest checks if the PAT is valid
|
||||
@ -152,7 +157,7 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ
|
||||
return fmt.Errorf("Error extracting token: %w", err)
|
||||
}
|
||||
|
||||
account, user, pat, err := m.getAccountFromPAT(token)
|
||||
account, user, pat, err := m.getAccountFromPAT(r.Context(), token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid Token: %w", err)
|
||||
}
|
||||
@ -160,7 +165,7 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ
|
||||
return fmt.Errorf("token expired")
|
||||
}
|
||||
|
||||
err = m.markPATUsed(pat.ID)
|
||||
err = m.markPATUsed(r.Context(), pat.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@ -15,15 +16,16 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
audience = "audience"
|
||||
userIDClaim = "userIDClaim"
|
||||
accountID = "accountID"
|
||||
domain = "domain"
|
||||
userID = "userID"
|
||||
tokenID = "tokenID"
|
||||
PAT = "nbp_PAT"
|
||||
JWT = "JWT"
|
||||
wrongToken = "wrongToken"
|
||||
audience = "audience"
|
||||
userIDClaim = "userIDClaim"
|
||||
accountID = "accountID"
|
||||
domain = "domain"
|
||||
domainCategory = "domainCategory"
|
||||
userID = "userID"
|
||||
tokenID = "tokenID"
|
||||
PAT = "nbp_PAT"
|
||||
JWT = "JWT"
|
||||
wrongToken = "wrongToken"
|
||||
)
|
||||
|
||||
var testAccount = &server.Account{
|
||||
@ -47,14 +49,14 @@ var testAccount = &server.Account{
|
||||
},
|
||||
}
|
||||
|
||||
func mockGetAccountFromPAT(token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
|
||||
func mockGetAccountFromPAT(_ context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
|
||||
if token == PAT {
|
||||
return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil
|
||||
}
|
||||
return nil, nil, nil, fmt.Errorf("PAT invalid")
|
||||
}
|
||||
|
||||
func mockValidateAndParseToken(token string) (*jwt.Token, error) {
|
||||
func mockValidateAndParseToken(_ context.Context, token string) (*jwt.Token, error) {
|
||||
if token == JWT {
|
||||
return &jwt.Token{
|
||||
Claims: jwt.MapClaims{
|
||||
@ -67,14 +69,14 @@ func mockValidateAndParseToken(token string) (*jwt.Token, error) {
|
||||
return nil, fmt.Errorf("JWT invalid")
|
||||
}
|
||||
|
||||
func mockMarkPATUsed(token string) error {
|
||||
func mockMarkPATUsed(_ context.Context, token string) error {
|
||||
if token == tokenID {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("Should never get reached")
|
||||
}
|
||||
|
||||
func mockCheckUserAccessByJWTGroups(claims jwtclaims.AuthorizationClaims) error {
|
||||
func mockCheckUserAccessByJWTGroups(_ context.Context, claims jwtclaims.AuthorizationClaims) error {
|
||||
if testAccount.Id != claims.AccountId {
|
||||
return fmt.Errorf("account with id %s does not exist", claims.AccountId)
|
||||
}
|
||||
|
@ -56,7 +56,7 @@ func ShouldBypass(requestPath string, h http.Handler, w http.ResponseWriter, r *
|
||||
for bypassPath := range bypassPaths {
|
||||
matched, err := path.Match(bypassPath, requestPath)
|
||||
if err != nil {
|
||||
log.Errorf("Error matching path %s with %s from %s: %v", bypassPath, requestPath, GetList(), err)
|
||||
log.WithContext(r.Context()).Errorf("Error matching path %s with %s from %s: %v", bypassPath, requestPath, GetList(), err)
|
||||
continue
|
||||
}
|
||||
if matched {
|
||||
|
@ -36,16 +36,16 @@ func NewNameserversHandler(accountManager server.AccountManager, authCfg AuthCfg
|
||||
// GetAllNameservers returns the list of nameserver groups for the account
|
||||
func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
log.WithContext(r.Context()).Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
nsGroups, err := h.accountManager.ListNameServerGroups(account.Id, user.Id)
|
||||
nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), account.Id, user.Id)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -54,15 +54,15 @@ func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Re
|
||||
apiNameservers = append(apiNameservers, toNameserverGroupResponse(r))
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, apiNameservers)
|
||||
util.WriteJSONObject(r.Context(), w, apiNameservers)
|
||||
}
|
||||
|
||||
// CreateNameserverGroup handles nameserver group creation request
|
||||
func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -75,33 +75,33 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt
|
||||
|
||||
nsList, err := toServerNSList(req.Nameservers)
|
||||
if err != nil {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid NS servers format"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid NS servers format"), w)
|
||||
return
|
||||
}
|
||||
|
||||
nsGroup, err := h.accountManager.CreateNameServerGroup(account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, user.Id, req.SearchDomainsEnabled)
|
||||
nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, user.Id, req.SearchDomainsEnabled)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toNameserverGroupResponse(nsGroup)
|
||||
|
||||
util.WriteJSONObject(w, &resp)
|
||||
util.WriteJSONObject(r.Context(), w, &resp)
|
||||
}
|
||||
|
||||
// UpdateNameserverGroup handles update to a nameserver group identified by a given ID
|
||||
func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
nsGroupID := mux.Vars(r)["nsgroupId"]
|
||||
if len(nsGroupID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -114,7 +114,7 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt
|
||||
|
||||
nsList, err := toServerNSList(req.Nameservers)
|
||||
if err != nil {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid NS servers format"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid NS servers format"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -130,66 +130,66 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt
|
||||
SearchDomainsEnabled: req.SearchDomainsEnabled,
|
||||
}
|
||||
|
||||
err = h.accountManager.SaveNameServerGroup(account.Id, user.Id, updatedNSGroup)
|
||||
err = h.accountManager.SaveNameServerGroup(r.Context(), account.Id, user.Id, updatedNSGroup)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toNameserverGroupResponse(updatedNSGroup)
|
||||
|
||||
util.WriteJSONObject(w, &resp)
|
||||
util.WriteJSONObject(r.Context(), w, &resp)
|
||||
}
|
||||
|
||||
// DeleteNameserverGroup handles nameserver group deletion request
|
||||
func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
nsGroupID := mux.Vars(r)["nsgroupId"]
|
||||
if len(nsGroupID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeleteNameServerGroup(account.Id, nsGroupID, user.Id)
|
||||
err = h.accountManager.DeleteNameServerGroup(r.Context(), account.Id, nsGroupID, user.Id)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, emptyObject{})
|
||||
util.WriteJSONObject(r.Context(), w, emptyObject{})
|
||||
}
|
||||
|
||||
// GetNameserverGroup handles a nameserver group Get request identified by ID
|
||||
func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
log.WithContext(r.Context()).Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
nsGroupID := mux.Vars(r)["nsgroupId"]
|
||||
if len(nsGroupID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
nsGroup, err := h.accountManager.GetNameServerGroup(account.Id, user.Id, nsGroupID)
|
||||
nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), account.Id, user.Id, nsGroupID)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toNameserverGroupResponse(nsGroup)
|
||||
|
||||
util.WriteJSONObject(w, &resp)
|
||||
util.WriteJSONObject(r.Context(), w, &resp)
|
||||
}
|
||||
|
||||
func toServerNSList(apiNSList []api.Nameserver) ([]nbdns.NameServer, error) {
|
||||
|
@ -2,6 +2,7 @@ package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
@ -61,13 +62,13 @@ var baseExistingNSGroup = &nbdns.NameServerGroup{
|
||||
func initNameserversTestData() *NameserversHandler {
|
||||
return &NameserversHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetNameServerGroupFunc: func(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
||||
GetNameServerGroupFunc: func(_ context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
||||
if nsGroupID == existingNSGroupID {
|
||||
return baseExistingNSGroup.Copy(), nil
|
||||
}
|
||||
return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID)
|
||||
},
|
||||
CreateNameServerGroupFunc: func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, _ string, searchDomains bool) (*nbdns.NameServerGroup, error) {
|
||||
CreateNameServerGroupFunc: func(_ context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, _ string, searchDomains bool) (*nbdns.NameServerGroup, error) {
|
||||
return &nbdns.NameServerGroup{
|
||||
ID: existingNSGroupID,
|
||||
Name: name,
|
||||
@ -80,16 +81,16 @@ func initNameserversTestData() *NameserversHandler {
|
||||
SearchDomainsEnabled: searchDomains,
|
||||
}, nil
|
||||
},
|
||||
DeleteNameServerGroupFunc: func(accountID, nsGroupID, _ string) error {
|
||||
DeleteNameServerGroupFunc: func(_ context.Context, accountID, nsGroupID, _ string) error {
|
||||
return nil
|
||||
},
|
||||
SaveNameServerGroupFunc: func(accountID, _ string, nsGroupToSave *nbdns.NameServerGroup) error {
|
||||
SaveNameServerGroupFunc: func(_ context.Context, accountID, _ string, nsGroupToSave *nbdns.NameServerGroup) error {
|
||||
if nsGroupToSave.ID == existingNSGroupID {
|
||||
return nil
|
||||
}
|
||||
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID)
|
||||
},
|
||||
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return testingNSAccount, testingAccount.Users["test_user"], nil
|
||||
},
|
||||
},
|
||||
|
@ -34,22 +34,22 @@ func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATH
|
||||
// GetAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user
|
||||
func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
userID := vars["userId"]
|
||||
if len(userID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
pats, err := h.accountManager.GetAllPATs(account.Id, user.Id, userID)
|
||||
pats, err := h.accountManager.GetAllPATs(r.Context(), account.Id, user.Id, userID)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -58,53 +58,53 @@ func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) {
|
||||
patResponse = append(patResponse, toPATResponse(pat))
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, patResponse)
|
||||
util.WriteJSONObject(r.Context(), w, patResponse)
|
||||
}
|
||||
|
||||
// GetToken is HTTP GET handler that returns a personal access token for the given user
|
||||
func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
tokenID := vars["tokenId"]
|
||||
if len(tokenID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid token ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid token ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
pat, err := h.accountManager.GetPAT(account.Id, user.Id, targetUserID, tokenID)
|
||||
pat, err := h.accountManager.GetPAT(r.Context(), account.Id, user.Id, targetUserID, tokenID)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, toPATResponse(pat))
|
||||
util.WriteJSONObject(r.Context(), w, toPATResponse(pat))
|
||||
}
|
||||
|
||||
// CreateToken is HTTP POST handler that creates a personal access token for the given user
|
||||
func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -115,44 +115,44 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
pat, err := h.accountManager.CreatePAT(account.Id, user.Id, targetUserID, req.Name, req.ExpiresIn)
|
||||
pat, err := h.accountManager.CreatePAT(r.Context(), account.Id, user.Id, targetUserID, req.Name, req.ExpiresIn)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, toPATGeneratedResponse(pat))
|
||||
util.WriteJSONObject(r.Context(), w, toPATGeneratedResponse(pat))
|
||||
}
|
||||
|
||||
// DeleteToken is HTTP DELETE handler that deletes a personal access token for the given user
|
||||
func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
tokenID := vars["tokenId"]
|
||||
if len(tokenID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid token ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid token ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeletePAT(account.Id, user.Id, targetUserID, tokenID)
|
||||
err = h.accountManager.DeletePAT(r.Context(), account.Id, user.Id, targetUserID, tokenID)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, emptyObject{})
|
||||
util.WriteJSONObject(r.Context(), w, emptyObject{})
|
||||
}
|
||||
|
||||
func toPATResponse(pat *server.PersonalAccessToken) *api.PersonalAccessToken {
|
||||
|
@ -2,6 +2,7 @@ package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
@ -63,7 +64,7 @@ var testAccount = &server.Account{
|
||||
func initPATTestData() *PATHandler {
|
||||
return &PATHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
CreatePATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
|
||||
CreatePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
|
||||
if accountID != existingAccountID {
|
||||
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
||||
}
|
||||
@ -76,10 +77,10 @@ func initPATTestData() *PATHandler {
|
||||
}, nil
|
||||
},
|
||||
|
||||
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return testAccount, testAccount.Users[existingUserID], nil
|
||||
},
|
||||
DeletePATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
|
||||
DeletePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
|
||||
if accountID != existingAccountID {
|
||||
return status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
||||
}
|
||||
@ -91,7 +92,7 @@ func initPATTestData() *PATHandler {
|
||||
}
|
||||
return nil
|
||||
},
|
||||
GetPATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
|
||||
GetPATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
|
||||
if accountID != existingAccountID {
|
||||
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
||||
}
|
||||
@ -103,7 +104,7 @@ func initPATTestData() *PATHandler {
|
||||
}
|
||||
return testAccount.Users[existingUserID].PATs[existingTokenID], nil
|
||||
},
|
||||
GetAllPATsFunc: func(accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
|
||||
GetAllPATsFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
|
||||
if accountID != existingAccountID {
|
||||
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@ -47,16 +48,16 @@ func (h *PeersHandler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
return peerToReturn, nil
|
||||
}
|
||||
|
||||
func (h *PeersHandler) getPeer(account *server.Account, peerID, userID string, w http.ResponseWriter) {
|
||||
peer, err := h.accountManager.GetPeer(account.Id, peerID, userID)
|
||||
func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, peerID, userID string, w http.ResponseWriter) {
|
||||
peer, err := h.accountManager.GetPeer(ctx, account.Id, peerID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
|
||||
peerToReturn, err := h.checkPeerStatus(peer)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
dnsDomain := h.accountManager.GetDNSDomain()
|
||||
@ -65,19 +66,19 @@ func (h *PeersHandler) getPeer(account *server.Account, peerID, userID string, w
|
||||
|
||||
validPeers, err := h.accountManager.GetValidatedPeers(account)
|
||||
if err != nil {
|
||||
log.Errorf("failed to list appreoved peers: %v", err)
|
||||
util.WriteError(fmt.Errorf("internal error"), w)
|
||||
log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err)
|
||||
util.WriteError(ctx, fmt.Errorf("internal error"), w)
|
||||
return
|
||||
}
|
||||
|
||||
netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain(), validPeers)
|
||||
netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validPeers)
|
||||
accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
|
||||
|
||||
_, valid := validPeers[peer.ID]
|
||||
util.WriteJSONObject(w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, accessiblePeers, valid))
|
||||
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, accessiblePeers, valid))
|
||||
}
|
||||
|
||||
func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) {
|
||||
func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) {
|
||||
req := &api.PeerRequest{}
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
@ -99,9 +100,9 @@ func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, pe
|
||||
}
|
||||
}
|
||||
|
||||
peer, err := h.accountManager.UpdatePeer(account.Id, user.Id, update)
|
||||
peer, err := h.accountManager.UpdatePeer(ctx, account.Id, user.Id, update)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
dnsDomain := h.accountManager.GetDNSDomain()
|
||||
@ -110,75 +111,75 @@ func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, pe
|
||||
|
||||
validPeers, err := h.accountManager.GetValidatedPeers(account)
|
||||
if err != nil {
|
||||
log.Errorf("failed to list appreoved peers: %v", err)
|
||||
util.WriteError(fmt.Errorf("internal error"), w)
|
||||
log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err)
|
||||
util.WriteError(ctx, fmt.Errorf("internal error"), w)
|
||||
return
|
||||
}
|
||||
netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain(), validPeers)
|
||||
netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validPeers)
|
||||
accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
|
||||
|
||||
_, valid := validPeers[peer.ID]
|
||||
|
||||
util.WriteJSONObject(w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, accessiblePeers, valid))
|
||||
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, accessiblePeers, valid))
|
||||
}
|
||||
|
||||
func (h *PeersHandler) deletePeer(accountID, userID string, peerID string, w http.ResponseWriter) {
|
||||
err := h.accountManager.DeletePeer(accountID, peerID, userID)
|
||||
func (h *PeersHandler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) {
|
||||
err := h.accountManager.DeletePeer(ctx, accountID, peerID, userID)
|
||||
if err != nil {
|
||||
log.Errorf("failed to delete peer: %v", err)
|
||||
util.WriteError(err, w)
|
||||
log.WithContext(ctx).Errorf("failed to delete peer: %v", err)
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(w, emptyObject{})
|
||||
util.WriteJSONObject(ctx, w, emptyObject{})
|
||||
}
|
||||
|
||||
// HandlePeer handles all peer requests for GET, PUT and DELETE operations
|
||||
func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
vars := mux.Vars(r)
|
||||
peerID := vars["peerId"]
|
||||
if len(peerID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodDelete:
|
||||
h.deletePeer(account.Id, user.Id, peerID, w)
|
||||
h.deletePeer(r.Context(), account.Id, user.Id, peerID, w)
|
||||
return
|
||||
case http.MethodPut:
|
||||
h.updatePeer(account, user, peerID, w, r)
|
||||
h.updatePeer(r.Context(), account, user, peerID, w, r)
|
||||
return
|
||||
case http.MethodGet:
|
||||
h.getPeer(account, peerID, user.Id, w)
|
||||
h.getPeer(r.Context(), account, peerID, user.Id, w)
|
||||
return
|
||||
default:
|
||||
util.WriteError(status.Errorf(status.NotFound, "unknown METHOD"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
|
||||
}
|
||||
}
|
||||
|
||||
// GetAllPeers returns a list of all peers associated with a provided account
|
||||
func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
util.WriteError(status.Errorf(status.NotFound, "unknown METHOD"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
|
||||
return
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
peers, err := h.accountManager.GetPeers(account.Id, user.Id)
|
||||
peers, err := h.accountManager.GetPeers(r.Context(), account.Id, user.Id)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -188,34 +189,34 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
||||
for _, peer := range peers {
|
||||
peerToReturn, err := h.checkPeerStatus(peer)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
|
||||
|
||||
accessiblePeerNumbers, _ := h.accessiblePeersNumber(account, peer.ID)
|
||||
accessiblePeerNumbers, _ := h.accessiblePeersNumber(r.Context(), account, peer.ID)
|
||||
|
||||
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, accessiblePeerNumbers))
|
||||
}
|
||||
|
||||
validPeersMap, err := h.accountManager.GetValidatedPeers(account)
|
||||
if err != nil {
|
||||
log.Errorf("failed to list appreoved peers: %v", err)
|
||||
util.WriteError(fmt.Errorf("internal error"), w)
|
||||
log.WithContext(r.Context()).Errorf("failed to list appreoved peers: %v", err)
|
||||
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
|
||||
return
|
||||
}
|
||||
h.setApprovalRequiredFlag(respBody, validPeersMap)
|
||||
|
||||
util.WriteJSONObject(w, respBody)
|
||||
util.WriteJSONObject(r.Context(), w, respBody)
|
||||
}
|
||||
|
||||
func (h *PeersHandler) accessiblePeersNumber(account *server.Account, peerID string) (int, error) {
|
||||
func (h *PeersHandler) accessiblePeersNumber(ctx context.Context, account *server.Account, peerID string) (int, error) {
|
||||
validatedPeersMap, err := h.accountManager.GetValidatedPeers(account)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain(), validatedPeersMap)
|
||||
netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validatedPeersMap)
|
||||
return len(netMap.Peers) + len(netMap.OfflinePeers), nil
|
||||
}
|
||||
|
||||
|
@ -2,6 +2,7 @@ package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net"
|
||||
@ -29,7 +30,7 @@ const noUpdateChannelTestPeerID = "no-update-channel"
|
||||
func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
||||
return &PeersHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
UpdatePeerFunc: func(accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
|
||||
UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
|
||||
var p *nbpeer.Peer
|
||||
for _, peer := range peers {
|
||||
if update.ID == peer.ID {
|
||||
@ -42,7 +43,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
||||
p.Name = update.Name
|
||||
return p, nil
|
||||
},
|
||||
GetPeerFunc: func(accountID, peerID, userID string) (*nbpeer.Peer, error) {
|
||||
GetPeerFunc: func(_ context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
|
||||
var p *nbpeer.Peer
|
||||
for _, peer := range peers {
|
||||
if peerID == peer.ID {
|
||||
@ -52,13 +53,13 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
||||
}
|
||||
return p, nil
|
||||
},
|
||||
GetPeersFunc: func(accountID, userID string) ([]*nbpeer.Peer, error) {
|
||||
GetPeersFunc: func(_ context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
|
||||
return peers, nil
|
||||
},
|
||||
GetDNSDomainFunc: func() string {
|
||||
return "netbird.selfhosted"
|
||||
},
|
||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
user := server.NewAdminUser("test_user")
|
||||
return &server.Account{
|
||||
Id: claims.AccountId,
|
||||
|
@ -35,15 +35,15 @@ func NewPoliciesHandler(accountManager server.AccountManager, authCfg AuthCfg) *
|
||||
// GetAllPolicies list for the account
|
||||
func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountPolicies, err := h.accountManager.ListPolicies(account.Id, user.Id)
|
||||
accountPolicies, err := h.accountManager.ListPolicies(r.Context(), account.Id, user.Id)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -51,28 +51,28 @@ func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) {
|
||||
for _, policy := range accountPolicies {
|
||||
resp := toPolicyResponse(account, policy)
|
||||
if len(resp.Rules) == 0 {
|
||||
util.WriteError(status.Errorf(status.Internal, "no rules in the policy"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
|
||||
return
|
||||
}
|
||||
policies = append(policies, resp)
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, policies)
|
||||
util.WriteJSONObject(r.Context(), w, policies)
|
||||
}
|
||||
|
||||
// UpdatePolicy handles update to a policy identified by a given ID
|
||||
func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
policyID := vars["policyId"]
|
||||
if len(policyID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -84,7 +84,7 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
if policyIdx < 0 {
|
||||
util.WriteError(status.Errorf(status.NotFound, "couldn't find policy id %s", policyID), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find policy id %s", policyID), w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -94,9 +94,9 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
// CreatePolicy handles policy creation request
|
||||
func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -118,12 +118,12 @@ func (h *Policies) savePolicy(
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "policy name shouldn't be empty"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "policy name shouldn't be empty"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Rules) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "policy rules shouldn't be empty"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "policy rules shouldn't be empty"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -137,31 +137,31 @@ func (h *Policies) savePolicy(
|
||||
Enabled: req.Enabled,
|
||||
Description: req.Description,
|
||||
}
|
||||
for _, r := range req.Rules {
|
||||
for _, rule := range req.Rules {
|
||||
pr := server.PolicyRule{
|
||||
ID: policyID, //TODO: when policy can contain multiple rules, need refactor
|
||||
Name: r.Name,
|
||||
Destinations: groupMinimumsToStrings(account, r.Destinations),
|
||||
Sources: groupMinimumsToStrings(account, r.Sources),
|
||||
Bidirectional: r.Bidirectional,
|
||||
ID: policyID, // TODO: when policy can contain multiple rules, need refactor
|
||||
Name: rule.Name,
|
||||
Destinations: groupMinimumsToStrings(account, rule.Destinations),
|
||||
Sources: groupMinimumsToStrings(account, rule.Sources),
|
||||
Bidirectional: rule.Bidirectional,
|
||||
}
|
||||
|
||||
pr.Enabled = r.Enabled
|
||||
if r.Description != nil {
|
||||
pr.Description = *r.Description
|
||||
pr.Enabled = rule.Enabled
|
||||
if rule.Description != nil {
|
||||
pr.Description = *rule.Description
|
||||
}
|
||||
|
||||
switch r.Action {
|
||||
switch rule.Action {
|
||||
case api.PolicyRuleUpdateActionAccept:
|
||||
pr.Action = server.PolicyTrafficActionAccept
|
||||
case api.PolicyRuleUpdateActionDrop:
|
||||
pr.Action = server.PolicyTrafficActionDrop
|
||||
default:
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "unknown action type"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown action type"), w)
|
||||
return
|
||||
}
|
||||
|
||||
switch r.Protocol {
|
||||
switch rule.Protocol {
|
||||
case api.PolicyRuleUpdateProtocolAll:
|
||||
pr.Protocol = server.PolicyRuleProtocolALL
|
||||
case api.PolicyRuleUpdateProtocolTcp:
|
||||
@ -171,14 +171,14 @@ func (h *Policies) savePolicy(
|
||||
case api.PolicyRuleUpdateProtocolIcmp:
|
||||
pr.Protocol = server.PolicyRuleProtocolICMP
|
||||
default:
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "unknown protocol type: %v", r.Protocol), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown protocol type: %v", rule.Protocol), w)
|
||||
return
|
||||
}
|
||||
|
||||
if r.Ports != nil && len(*r.Ports) != 0 {
|
||||
for _, v := range *r.Ports {
|
||||
if rule.Ports != nil && len(*rule.Ports) != 0 {
|
||||
for _, v := range *rule.Ports {
|
||||
if port, err := strconv.Atoi(v); err != nil || port < 1 || port > 65535 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "valid port value is in 1..65535 range"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "valid port value is in 1..65535 range"), w)
|
||||
return
|
||||
}
|
||||
pr.Ports = append(pr.Ports, v)
|
||||
@ -189,16 +189,16 @@ func (h *Policies) savePolicy(
|
||||
switch pr.Protocol {
|
||||
case server.PolicyRuleProtocolALL, server.PolicyRuleProtocolICMP:
|
||||
if len(pr.Ports) != 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol ports is not allowed"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol ports is not allowed"), w)
|
||||
return
|
||||
}
|
||||
if !pr.Bidirectional {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w)
|
||||
return
|
||||
}
|
||||
case server.PolicyRuleProtocolTCP, server.PolicyRuleProtocolUDP:
|
||||
if !pr.Bidirectional && len(pr.Ports) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w)
|
||||
return
|
||||
}
|
||||
}
|
||||
@ -210,26 +210,26 @@ func (h *Policies) savePolicy(
|
||||
policy.SourcePostureChecks = sourcePostureChecksToStrings(account, *req.SourcePostureChecks)
|
||||
}
|
||||
|
||||
if err := h.accountManager.SavePolicy(account.Id, user.Id, &policy); err != nil {
|
||||
util.WriteError(err, w)
|
||||
if err := h.accountManager.SavePolicy(r.Context(), account.Id, user.Id, &policy); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toPolicyResponse(account, &policy)
|
||||
if len(resp.Rules) == 0 {
|
||||
util.WriteError(status.Errorf(status.Internal, "no rules in the policy"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, resp)
|
||||
util.WriteJSONObject(r.Context(), w, resp)
|
||||
}
|
||||
|
||||
// DeletePolicy handles policy deletion request
|
||||
func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
aID := account.Id
|
||||
@ -237,24 +237,24 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
vars := mux.Vars(r)
|
||||
policyID := vars["policyId"]
|
||||
if len(policyID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err = h.accountManager.DeletePolicy(aID, policyID, user.Id); err != nil {
|
||||
util.WriteError(err, w)
|
||||
if err = h.accountManager.DeletePolicy(r.Context(), aID, policyID, user.Id); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, emptyObject{})
|
||||
util.WriteJSONObject(r.Context(), w, emptyObject{})
|
||||
}
|
||||
|
||||
// GetPolicy handles a group Get request identified by ID
|
||||
func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -263,25 +263,25 @@ func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
vars := mux.Vars(r)
|
||||
policyID := vars["policyId"]
|
||||
if len(policyID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
policy, err := h.accountManager.GetPolicy(account.Id, policyID, user.Id)
|
||||
policy, err := h.accountManager.GetPolicy(r.Context(), account.Id, policyID, user.Id)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toPolicyResponse(account, policy)
|
||||
if len(resp.Rules) == 0 {
|
||||
util.WriteError(status.Errorf(status.Internal, "no rules in the policy"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, resp)
|
||||
util.WriteJSONObject(r.Context(), w, resp)
|
||||
default:
|
||||
util.WriteError(status.Errorf(status.NotFound, "method not found"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "method not found"), w)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2,6 +2,7 @@ package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
@ -30,21 +31,21 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
|
||||
}
|
||||
return &Policies{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetPolicyFunc: func(_, policyID, _ string) (*server.Policy, error) {
|
||||
GetPolicyFunc: func(_ context.Context, _, policyID, _ string) (*server.Policy, error) {
|
||||
policy, ok := testPolicies[policyID]
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.NotFound, "policy not found")
|
||||
}
|
||||
return policy, nil
|
||||
},
|
||||
SavePolicyFunc: func(_, _ string, policy *server.Policy) error {
|
||||
SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) error {
|
||||
if !strings.HasPrefix(policy.ID, "id-") {
|
||||
policy.ID = "id-was-set"
|
||||
policy.Rules[0].ID = "id-was-set"
|
||||
}
|
||||
return nil
|
||||
},
|
||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
user := server.NewAdminUser("test_user")
|
||||
return &server.Account{
|
||||
Id: claims.AccountId,
|
||||
|
@ -37,15 +37,15 @@ func NewPostureChecksHandler(accountManager server.AccountManager, geolocationMa
|
||||
// GetAllPostureChecks list for the account
|
||||
func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *http.Request) {
|
||||
claims := p.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := p.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountPostureChecks, err := p.accountManager.ListPostureChecks(account.Id, user.Id)
|
||||
accountPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), account.Id, user.Id)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -54,22 +54,22 @@ func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *htt
|
||||
postureChecks = append(postureChecks, postureCheck.ToAPIResponse())
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, postureChecks)
|
||||
util.WriteJSONObject(r.Context(), w, postureChecks)
|
||||
}
|
||||
|
||||
// UpdatePostureCheck handles update to a posture check identified by a given ID
|
||||
func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||
claims := p.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := p.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
postureChecksID := vars["postureCheckId"]
|
||||
if len(postureChecksID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -81,7 +81,7 @@ func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http
|
||||
}
|
||||
}
|
||||
if postureChecksIdx < 0 {
|
||||
util.WriteError(status.Errorf(status.NotFound, "couldn't find posture checks id %s", postureChecksID), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find posture checks id %s", postureChecksID), w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -91,9 +91,9 @@ func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http
|
||||
// CreatePostureCheck handles posture check creation request
|
||||
func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||
claims := p.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := p.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -103,50 +103,50 @@ func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http
|
||||
// GetPostureCheck handles a posture check Get request identified by ID
|
||||
func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||
claims := p.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := p.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
postureChecksID := vars["postureCheckId"]
|
||||
if len(postureChecksID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
postureChecks, err := p.accountManager.GetPostureChecks(account.Id, postureChecksID, user.Id)
|
||||
postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), account.Id, postureChecksID, user.Id)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, postureChecks.ToAPIResponse())
|
||||
util.WriteJSONObject(r.Context(), w, postureChecks.ToAPIResponse())
|
||||
}
|
||||
|
||||
// DeletePostureCheck handles posture check deletion request
|
||||
func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||
claims := p.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := p.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
postureChecksID := vars["postureCheckId"]
|
||||
if len(postureChecksID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err = p.accountManager.DeletePostureChecks(account.Id, postureChecksID, user.Id); err != nil {
|
||||
util.WriteError(err, w)
|
||||
if err = p.accountManager.DeletePostureChecks(r.Context(), account.Id, postureChecksID, user.Id); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, emptyObject{})
|
||||
util.WriteJSONObject(r.Context(), w, emptyObject{})
|
||||
}
|
||||
|
||||
// savePostureChecks handles posture checks create and update
|
||||
@ -169,7 +169,7 @@ func (p *PostureChecksHandler) savePostureChecks(
|
||||
|
||||
if geoLocationCheck := req.Checks.GeoLocationCheck; geoLocationCheck != nil {
|
||||
if p.geolocationManager == nil {
|
||||
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized. "+
|
||||
util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized. "+
|
||||
"Check the self-hosted Geo database documentation at https://docs.netbird.io/selfhosted/geo-support"), w)
|
||||
return
|
||||
}
|
||||
@ -177,14 +177,14 @@ func (p *PostureChecksHandler) savePostureChecks(
|
||||
|
||||
postureChecks, err := posture.NewChecksFromAPIPostureCheckUpdate(req, postureChecksID)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := p.accountManager.SavePostureChecks(account.Id, user.Id, postureChecks); err != nil {
|
||||
util.WriteError(err, w)
|
||||
if err := p.accountManager.SavePostureChecks(r.Context(), account.Id, user.Id, postureChecks); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, postureChecks.ToAPIResponse())
|
||||
util.WriteJSONObject(r.Context(), w, postureChecks.ToAPIResponse())
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
@ -33,14 +34,14 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
|
||||
|
||||
return &PostureChecksHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetPostureChecksFunc: func(accountID, postureChecksID, userID string) (*posture.Checks, error) {
|
||||
GetPostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
|
||||
p, ok := testPostureChecks[postureChecksID]
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.NotFound, "posture checks not found")
|
||||
}
|
||||
return p, nil
|
||||
},
|
||||
SavePostureChecksFunc: func(accountID, userID string, postureChecks *posture.Checks) error {
|
||||
SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) error {
|
||||
postureChecks.ID = "postureCheck"
|
||||
testPostureChecks[postureChecks.ID] = postureChecks
|
||||
|
||||
@ -50,7 +51,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
|
||||
|
||||
return nil
|
||||
},
|
||||
DeletePostureChecksFunc: func(accountID, postureChecksID, userID string) error {
|
||||
DeletePostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) error {
|
||||
_, ok := testPostureChecks[postureChecksID]
|
||||
if !ok {
|
||||
return status.Errorf(status.NotFound, "posture checks not found")
|
||||
@ -59,14 +60,14 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
|
||||
|
||||
return nil
|
||||
},
|
||||
ListPostureChecksFunc: func(accountID, userID string) ([]*posture.Checks, error) {
|
||||
ListPostureChecksFunc: func(_ context.Context, accountID, userID string) ([]*posture.Checks, error) {
|
||||
accountPostureChecks := make([]*posture.Checks, len(testPostureChecks))
|
||||
for _, p := range testPostureChecks {
|
||||
accountPostureChecks = append(accountPostureChecks, p)
|
||||
}
|
||||
return accountPostureChecks, nil
|
||||
},
|
||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
user := server.NewAdminUser("test_user")
|
||||
return &server.Account{
|
||||
Id: claims.AccountId,
|
||||
|
@ -43,36 +43,36 @@ func NewRoutesHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ro
|
||||
// GetAllRoutes returns the list of routes for the account
|
||||
func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
routes, err := h.accountManager.ListRoutes(account.Id, user.Id)
|
||||
routes, err := h.accountManager.ListRoutes(r.Context(), account.Id, user.Id)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
apiRoutes := make([]*api.Route, 0)
|
||||
for _, r := range routes {
|
||||
route, err := toRouteResponse(r)
|
||||
for _, route := range routes {
|
||||
route, err := toRouteResponse(route)
|
||||
if err != nil {
|
||||
util.WriteError(status.Errorf(status.Internal, failedToConvertRoute, err), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.Internal, failedToConvertRoute, err), w)
|
||||
return
|
||||
}
|
||||
apiRoutes = append(apiRoutes, route)
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, apiRoutes)
|
||||
util.WriteJSONObject(r.Context(), w, apiRoutes)
|
||||
}
|
||||
|
||||
// CreateRoute handles route creation request
|
||||
func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -84,7 +84,7 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
if err := h.validateRoute(req); err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -94,7 +94,7 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
if req.Domains != nil {
|
||||
d, err := validateDomains(*req.Domains)
|
||||
if err != nil {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w)
|
||||
return
|
||||
}
|
||||
domains = d
|
||||
@ -102,7 +102,7 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
} else if req.Network != nil {
|
||||
networkType, newPrefix, err = route.ParseNetwork(*req.Network)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
}
|
||||
@ -120,24 +120,24 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
// Do not allow non-Linux peers
|
||||
if peer := account.GetPeer(peerId); peer != nil {
|
||||
if peer.Meta.GoOS != "linux" {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes"), w)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
newRoute, err := h.accountManager.CreateRoute(account.Id, newPrefix, networkType, domains, peerId, peerGroupIds, req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id, req.KeepRoute)
|
||||
newRoute, err := h.accountManager.CreateRoute(r.Context(), account.Id, newPrefix, networkType, domains, peerId, peerGroupIds, req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id, req.KeepRoute)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
routes, err := toRouteResponse(newRoute)
|
||||
if err != nil {
|
||||
util.WriteError(status.Errorf(status.Internal, failedToConvertRoute, err), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.Internal, failedToConvertRoute, err), w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, routes)
|
||||
util.WriteJSONObject(r.Context(), w, routes)
|
||||
}
|
||||
|
||||
func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) error {
|
||||
@ -168,22 +168,22 @@ func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) erro
|
||||
// UpdateRoute handles update to a route identified by a given ID
|
||||
func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
routeID := vars["routeId"]
|
||||
if len(routeID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = h.accountManager.GetRoute(account.Id, route.ID(routeID), user.Id)
|
||||
_, err = h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -195,7 +195,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
if err := h.validateRoute(req); err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -207,7 +207,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
// do not allow non Linux peers
|
||||
if peer := account.GetPeer(peerID); peer != nil {
|
||||
if peer.Meta.GoOS != "linux" {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "non-linux peers are non supported as network routes"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "non-linux peers are non supported as network routes"), w)
|
||||
return
|
||||
}
|
||||
}
|
||||
@ -226,7 +226,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
if req.Domains != nil {
|
||||
d, err := validateDomains(*req.Domains)
|
||||
if err != nil {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w)
|
||||
return
|
||||
}
|
||||
newRoute.Domains = d
|
||||
@ -234,7 +234,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
} else if req.Network != nil {
|
||||
newRoute.NetworkType, newRoute.Network, err = route.ParseNetwork(*req.Network)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
}
|
||||
@ -247,73 +247,73 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
newRoute.PeerGroups = *req.PeerGroups
|
||||
}
|
||||
|
||||
err = h.accountManager.SaveRoute(account.Id, user.Id, newRoute)
|
||||
err = h.accountManager.SaveRoute(r.Context(), account.Id, user.Id, newRoute)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
routes, err := toRouteResponse(newRoute)
|
||||
if err != nil {
|
||||
util.WriteError(status.Errorf(status.Internal, failedToConvertRoute, err), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.Internal, failedToConvertRoute, err), w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, routes)
|
||||
util.WriteJSONObject(r.Context(), w, routes)
|
||||
}
|
||||
|
||||
// DeleteRoute handles route deletion request
|
||||
func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
routeID := mux.Vars(r)["routeId"]
|
||||
if len(routeID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeleteRoute(account.Id, route.ID(routeID), user.Id)
|
||||
err = h.accountManager.DeleteRoute(r.Context(), account.Id, route.ID(routeID), user.Id)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, emptyObject{})
|
||||
util.WriteJSONObject(r.Context(), w, emptyObject{})
|
||||
}
|
||||
|
||||
// GetRoute handles a route Get request identified by ID
|
||||
func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
routeID := mux.Vars(r)["routeId"]
|
||||
if len(routeID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid route ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
foundRoute, err := h.accountManager.GetRoute(account.Id, route.ID(routeID), user.Id)
|
||||
foundRoute, err := h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id)
|
||||
if err != nil {
|
||||
util.WriteError(status.Errorf(status.NotFound, "route not found"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "route not found"), w)
|
||||
return
|
||||
}
|
||||
|
||||
routes, err := toRouteResponse(foundRoute)
|
||||
if err != nil {
|
||||
util.WriteError(status.Errorf(status.Internal, failedToConvertRoute, err), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.Internal, failedToConvertRoute, err), w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, routes)
|
||||
util.WriteJSONObject(r.Context(), w, routes)
|
||||
}
|
||||
|
||||
func toRouteResponse(serverRoute *route.Route) (*api.Route, error) {
|
||||
|
@ -2,6 +2,7 @@ package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@ -89,7 +90,7 @@ var testingAccount = &server.Account{
|
||||
func initRoutesTestData() *RoutesHandler {
|
||||
return &RoutesHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetRouteFunc: func(_ string, routeID route.ID, _ string) (*route.Route, error) {
|
||||
GetRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) (*route.Route, error) {
|
||||
if routeID == existingRouteID {
|
||||
return baseExistingRoute, nil
|
||||
}
|
||||
@ -104,7 +105,7 @@ func initRoutesTestData() *RoutesHandler {
|
||||
}
|
||||
return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
|
||||
},
|
||||
CreateRouteFunc: func(accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, _ string, keepRoute bool) (*route.Route, error) {
|
||||
CreateRouteFunc: func(_ context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, _ string, keepRoute bool) (*route.Route, error) {
|
||||
if peerID == notFoundPeerID {
|
||||
return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
|
||||
}
|
||||
@ -126,19 +127,19 @@ func initRoutesTestData() *RoutesHandler {
|
||||
KeepRoute: keepRoute,
|
||||
}, nil
|
||||
},
|
||||
SaveRouteFunc: func(_, _ string, r *route.Route) error {
|
||||
SaveRouteFunc: func(_ context.Context, _, _ string, r *route.Route) error {
|
||||
if r.Peer == notFoundPeerID {
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", r.Peer)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
DeleteRouteFunc: func(_ string, routeID route.ID, _ string) error {
|
||||
DeleteRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) error {
|
||||
if routeID != existingRouteID {
|
||||
return status.Errorf(status.NotFound, "Peer with ID %s not found", routeID)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return testingAccount, testingAccount.Users["test_user"], nil
|
||||
},
|
||||
},
|
||||
|
@ -1,6 +1,7 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
@ -34,9 +35,9 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authCfg AuthCfg)
|
||||
// CreateSetupKey is a POST requests that creates a new SetupKey
|
||||
func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -48,13 +49,13 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "setup key name shouldn't be empty"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key name shouldn't be empty"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if !(server.SetupKeyType(req.Type) == server.SetupKeyReusable ||
|
||||
server.SetupKeyType(req.Type) == server.SetupKeyOneOff) {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "unknown setup key type %s", req.Type), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown setup key type %s", req.Type), w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -63,7 +64,7 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
|
||||
day := time.Hour * 24
|
||||
year := day * 365
|
||||
if expiresIn < day || expiresIn > year {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "expiresIn should be between 1 day and 365 days"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "expiresIn should be between 1 day and 365 days"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -75,54 +76,54 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
|
||||
if req.Ephemeral != nil {
|
||||
ephemeral = *req.Ephemeral
|
||||
}
|
||||
setupKey, err := h.accountManager.CreateSetupKey(account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn,
|
||||
setupKey, err := h.accountManager.CreateSetupKey(r.Context(), account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn,
|
||||
req.AutoGroups, req.UsageLimit, user.Id, ephemeral)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
writeSuccess(w, setupKey)
|
||||
writeSuccess(r.Context(), w, setupKey)
|
||||
}
|
||||
|
||||
// GetSetupKey is a GET request to get a SetupKey by ID
|
||||
func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
keyID := vars["keyId"]
|
||||
if len(keyID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid key ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid key ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
key, err := h.accountManager.GetSetupKey(account.Id, user.Id, keyID)
|
||||
key, err := h.accountManager.GetSetupKey(r.Context(), account.Id, user.Id, keyID)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
writeSuccess(w, key)
|
||||
writeSuccess(r.Context(), w, key)
|
||||
}
|
||||
|
||||
// UpdateSetupKey is a PUT request to update server.SetupKey
|
||||
func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
keyID := vars["keyId"]
|
||||
if len(keyID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid key ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid key ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -134,12 +135,12 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "setup key name field is invalid: %s", req.Name), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key name field is invalid: %s", req.Name), w)
|
||||
return
|
||||
}
|
||||
|
||||
if req.AutoGroups == nil {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "setup key AutoGroups field is invalid"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key AutoGroups field is invalid"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -149,26 +150,26 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
|
||||
newKey.Name = req.Name
|
||||
newKey.Id = keyID
|
||||
|
||||
newKey, err = h.accountManager.SaveSetupKey(account.Id, newKey, user.Id)
|
||||
newKey, err = h.accountManager.SaveSetupKey(r.Context(), account.Id, newKey, user.Id)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
writeSuccess(w, newKey)
|
||||
writeSuccess(r.Context(), w, newKey)
|
||||
}
|
||||
|
||||
// GetAllSetupKeys is a GET request that returns a list of SetupKey
|
||||
func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
setupKeys, err := h.accountManager.ListSetupKeys(account.Id, user.Id)
|
||||
setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), account.Id, user.Id)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -177,15 +178,15 @@ func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Reques
|
||||
apiSetupKeys = append(apiSetupKeys, toResponseBody(key))
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, apiSetupKeys)
|
||||
util.WriteJSONObject(r.Context(), w, apiSetupKeys)
|
||||
}
|
||||
|
||||
func writeSuccess(w http.ResponseWriter, key *server.SetupKey) {
|
||||
func writeSuccess(ctx context.Context, w http.ResponseWriter, key *server.SetupKey) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(200)
|
||||
err := json.NewEncoder(w).Encode(toResponseBody(key))
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@ -33,7 +34,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
|
||||
) *SetupKeysHandler {
|
||||
return &SetupKeysHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return &server.Account{
|
||||
Id: testAccountID,
|
||||
Domain: "hotmail.com",
|
||||
@ -49,7 +50,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
|
||||
},
|
||||
}, user, nil
|
||||
},
|
||||
CreateSetupKeyFunc: func(_ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string,
|
||||
CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string,
|
||||
_ int, _ string, ephemeral bool,
|
||||
) (*server.SetupKey, error) {
|
||||
if keyName == newKey.Name || typ != newKey.Type {
|
||||
@ -59,7 +60,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
|
||||
}
|
||||
return nil, fmt.Errorf("failed creating setup key")
|
||||
},
|
||||
GetSetupKeyFunc: func(accountID, userID, keyID string) (*server.SetupKey, error) {
|
||||
GetSetupKeyFunc: func(_ context.Context, accountID, userID, keyID string) (*server.SetupKey, error) {
|
||||
switch keyID {
|
||||
case defaultKey.Id:
|
||||
return defaultKey, nil
|
||||
@ -70,14 +71,14 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
|
||||
}
|
||||
},
|
||||
|
||||
SaveSetupKeyFunc: func(accountID string, key *server.SetupKey, _ string) (*server.SetupKey, error) {
|
||||
SaveSetupKeyFunc: func(_ context.Context, accountID string, key *server.SetupKey, _ string) (*server.SetupKey, error) {
|
||||
if key.Id == updatedSetupKey.Id {
|
||||
return updatedSetupKey, nil
|
||||
}
|
||||
return nil, status.Errorf(status.NotFound, "key %s not found", key.Id)
|
||||
},
|
||||
|
||||
ListSetupKeysFunc: func(accountID, userID string) ([]*server.SetupKey, error) {
|
||||
ListSetupKeysFunc: func(_ context.Context, accountID, userID string) ([]*server.SetupKey, error) {
|
||||
return []*server.SetupKey{defaultKey}, nil
|
||||
},
|
||||
},
|
||||
|
@ -41,22 +41,22 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
userID := vars["userId"]
|
||||
if len(userID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
existingUser, ok := account.Users[userID]
|
||||
if !ok {
|
||||
util.WriteError(status.Errorf(status.NotFound, "couldn't find user with ID %s", userID), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find user with ID %s", userID), w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -74,11 +74,11 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
userRole := server.StrRoleToUserRole(req.Role)
|
||||
if userRole == server.UserRoleUnknown {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user role"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user role"), w)
|
||||
return
|
||||
}
|
||||
|
||||
newUser, err := h.accountManager.SaveUser(account.Id, user.Id, &server.User{
|
||||
newUser, err := h.accountManager.SaveUser(r.Context(), account.Id, user.Id, &server.User{
|
||||
Id: userID,
|
||||
Role: userRole,
|
||||
AutoGroups: req.AutoGroups,
|
||||
@ -88,10 +88,10 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(w, toUserResponse(newUser, claims.UserId))
|
||||
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId))
|
||||
}
|
||||
|
||||
// DeleteUser is a DELETE request to delete a user
|
||||
@ -102,26 +102,26 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeleteUser(account.Id, user.Id, targetUserID)
|
||||
err = h.accountManager.DeleteUser(r.Context(), account.Id, user.Id, targetUserID)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, emptyObject{})
|
||||
util.WriteJSONObject(r.Context(), w, emptyObject{})
|
||||
}
|
||||
|
||||
// CreateUser creates a User in the system with a status "invited" (effectively this is a user invite).
|
||||
@ -132,9 +132,9 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -146,7 +146,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
if server.StrRoleToUserRole(req.Role) == server.UserRoleUnknown {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "unknown user role %s", req.Role), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown user role %s", req.Role), w)
|
||||
return
|
||||
}
|
||||
|
||||
@ -160,7 +160,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
|
||||
name = *req.Name
|
||||
}
|
||||
|
||||
newUser, err := h.accountManager.CreateUser(account.Id, user.Id, &server.UserInfo{
|
||||
newUser, err := h.accountManager.CreateUser(r.Context(), account.Id, user.Id, &server.UserInfo{
|
||||
Email: email,
|
||||
Name: name,
|
||||
Role: req.Role,
|
||||
@ -169,10 +169,10 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
|
||||
Issued: server.UserIssuedAPI,
|
||||
})
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(w, toUserResponse(newUser, claims.UserId))
|
||||
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId))
|
||||
}
|
||||
|
||||
// GetAllUsers returns a list of users of the account this user belongs to.
|
||||
@ -184,42 +184,42 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
data, err := h.accountManager.GetUsersFromAccount(account.Id, user.Id)
|
||||
data, err := h.accountManager.GetUsersFromAccount(r.Context(), account.Id, user.Id)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
serviceUser := r.URL.Query().Get("service_user")
|
||||
|
||||
users := make([]*api.User, 0)
|
||||
for _, r := range data {
|
||||
if r.NonDeletable {
|
||||
for _, d := range data {
|
||||
if d.NonDeletable {
|
||||
continue
|
||||
}
|
||||
if serviceUser == "" {
|
||||
users = append(users, toUserResponse(r, claims.UserId))
|
||||
users = append(users, toUserResponse(d, claims.UserId))
|
||||
continue
|
||||
}
|
||||
|
||||
includeServiceUser, err := strconv.ParseBool(serviceUser)
|
||||
log.Debugf("Should include service user: %v", includeServiceUser)
|
||||
log.WithContext(r.Context()).Debugf("Should include service user: %v", includeServiceUser)
|
||||
if err != nil {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid service_user query parameter"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid service_user query parameter"), w)
|
||||
return
|
||||
}
|
||||
if includeServiceUser == r.IsServiceUser {
|
||||
users = append(users, toUserResponse(r, claims.UserId))
|
||||
if includeServiceUser == d.IsServiceUser {
|
||||
users = append(users, toUserResponse(d, claims.UserId))
|
||||
}
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, users)
|
||||
util.WriteJSONObject(r.Context(), w, users)
|
||||
}
|
||||
|
||||
// InviteUser resend invitations to users who haven't activated their accounts,
|
||||
@ -231,26 +231,26 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.InviteUser(account.Id, user.Id, targetUserID)
|
||||
err = h.accountManager.InviteUser(r.Context(), account.Id, user.Id, targetUserID)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, emptyObject{})
|
||||
util.WriteJSONObject(r.Context(), w, emptyObject{})
|
||||
}
|
||||
|
||||
func toUserResponse(user *server.UserInfo, currenUserID string) *api.User {
|
||||
|
@ -2,6 +2,7 @@ package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@ -63,10 +64,10 @@ var usersTestAccount = &server.Account{
|
||||
func initUsersTestData() *UsersHandler {
|
||||
return &UsersHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return usersTestAccount, usersTestAccount.Users[claims.UserId], nil
|
||||
},
|
||||
GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) {
|
||||
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) {
|
||||
users := make([]*server.UserInfo, 0)
|
||||
for _, v := range usersTestAccount.Users {
|
||||
users = append(users, &server.UserInfo{
|
||||
@ -81,13 +82,13 @@ func initUsersTestData() *UsersHandler {
|
||||
}
|
||||
return users, nil
|
||||
},
|
||||
CreateUserFunc: func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) {
|
||||
CreateUserFunc: func(_ context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) {
|
||||
if userID != existingUserID {
|
||||
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID)
|
||||
}
|
||||
return key, nil
|
||||
},
|
||||
DeleteUserFunc: func(accountID string, initiatorUserID string, targetUserID string) error {
|
||||
DeleteUserFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string) error {
|
||||
if targetUserID == notFoundUserID {
|
||||
return status.Errorf(status.NotFound, "user with ID %s does not exists", targetUserID)
|
||||
}
|
||||
@ -96,7 +97,7 @@ func initUsersTestData() *UsersHandler {
|
||||
}
|
||||
return nil
|
||||
},
|
||||
SaveUserFunc: func(accountID, userID string, update *server.User) (*server.UserInfo, error) {
|
||||
SaveUserFunc: func(_ context.Context, accountID, userID string, update *server.User) (*server.UserInfo, error) {
|
||||
if update.Id == notFoundUserID {
|
||||
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", update.Id)
|
||||
}
|
||||
@ -111,7 +112,7 @@ func initUsersTestData() *UsersHandler {
|
||||
}
|
||||
return info, nil
|
||||
},
|
||||
InviteUserFunc: func(accountID string, initiatorUserID string, targetUserID string) error {
|
||||
InviteUserFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string) error {
|
||||
if initiatorUserID != existingUserID {
|
||||
return status.Errorf(status.NotFound, "user with ID %s does not exists", initiatorUserID)
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@ -19,12 +20,12 @@ type ErrorResponse struct {
|
||||
}
|
||||
|
||||
// WriteJSONObject simply writes object to the HTTP response in JSON format
|
||||
func WriteJSONObject(w http.ResponseWriter, obj interface{}) {
|
||||
func WriteJSONObject(ctx context.Context, w http.ResponseWriter, obj interface{}) {
|
||||
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
err := json.NewEncoder(w).Encode(obj)
|
||||
if err != nil {
|
||||
WriteError(err, w)
|
||||
WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
}
|
||||
@ -76,8 +77,8 @@ func WriteErrorResponse(errMsg string, httpStatus int, w http.ResponseWriter) {
|
||||
|
||||
// WriteError converts an error to an JSON error response.
|
||||
// If it is known internal error of type server.Error then it sets the messages from the error, a generic message otherwise
|
||||
func WriteError(err error, w http.ResponseWriter) {
|
||||
log.Errorf("got a handler error: %s", err.Error())
|
||||
func WriteError(ctx context.Context, err error, w http.ResponseWriter) {
|
||||
log.WithContext(ctx).Errorf("got a handler error: %s", err.Error())
|
||||
errStatus, ok := status.FromError(err)
|
||||
httpStatus := http.StatusInternalServerError
|
||||
msg := "internal server error"
|
||||
@ -106,7 +107,7 @@ func WriteError(err error, w http.ResponseWriter) {
|
||||
msg = strings.ToLower(err.Error())
|
||||
} else {
|
||||
unhandledMSG := fmt.Sprintf("got unhandled error code, error: %s", err.Error())
|
||||
log.Error(unhandledMSG)
|
||||
log.WithContext(ctx).Error(unhandledMSG)
|
||||
}
|
||||
|
||||
WriteErrorResponse(msg, httpStatus, w)
|
||||
|
@ -183,7 +183,7 @@ func (c *Auth0Credentials) jwtStillValid() bool {
|
||||
}
|
||||
|
||||
// requestJWTToken performs request to get jwt token
|
||||
func (c *Auth0Credentials) requestJWTToken() (*http.Response, error) {
|
||||
func (c *Auth0Credentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
|
||||
var res *http.Response
|
||||
reqURL := c.clientConfig.AuthIssuer + "/oauth/token"
|
||||
|
||||
@ -200,7 +200,7 @@ func (c *Auth0Credentials) requestJWTToken() (*http.Response, error) {
|
||||
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
log.Debug("requesting new jwt token for idp manager")
|
||||
log.WithContext(ctx).Debug("requesting new jwt token for idp manager")
|
||||
|
||||
res, err = c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@ -247,7 +247,7 @@ func (c *Auth0Credentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTTo
|
||||
}
|
||||
|
||||
// Authenticate retrieves access token to use the Auth0 Management API
|
||||
func (c *Auth0Credentials) Authenticate() (JWTToken, error) {
|
||||
func (c *Auth0Credentials) Authenticate(ctx context.Context) (JWTToken, error) {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
||||
@ -260,14 +260,14 @@ func (c *Auth0Credentials) Authenticate() (JWTToken, error) {
|
||||
return c.jwtToken, nil
|
||||
}
|
||||
|
||||
res, err := c.requestJWTToken()
|
||||
res, err := c.requestJWTToken(ctx)
|
||||
if err != nil {
|
||||
return c.jwtToken, err
|
||||
}
|
||||
defer func() {
|
||||
err = res.Body.Close()
|
||||
if err != nil {
|
||||
log.Errorf("error while closing get jwt token response body: %v", err)
|
||||
log.WithContext(ctx).Errorf("error while closing get jwt token response body: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
@ -301,8 +301,8 @@ func requestByUserIDURL(authIssuer, userID string) string {
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile. Calls Auth0 API.
|
||||
func (am *Auth0Manager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
jwtToken, err := am.credentials.Authenticate()
|
||||
func (am *Auth0Manager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -353,7 +353,7 @@ func (am *Auth0Manager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Debugf("returned user batch for accountID %s on page %d, batch length %d", accountID, page, len(batch))
|
||||
log.WithContext(ctx).Debugf("returned user batch for accountID %s on page %d, batch length %d", accountID, page, len(batch))
|
||||
|
||||
err = res.Body.Close()
|
||||
if err != nil {
|
||||
@ -365,7 +365,7 @@ func (am *Auth0Manager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
}
|
||||
|
||||
if len(batch) == 0 || len(batch) < resultsPerPage {
|
||||
log.Debugf("finished loading users for accountID %s", accountID)
|
||||
log.WithContext(ctx).Debugf("finished loading users for accountID %s", accountID)
|
||||
return list, nil
|
||||
}
|
||||
}
|
||||
@ -374,8 +374,8 @@ func (am *Auth0Manager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from auth0 via ID
|
||||
func (am *Auth0Manager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
jwtToken, err := am.credentials.Authenticate()
|
||||
func (am *Auth0Manager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -414,7 +414,7 @@ func (am *Auth0Manager) GetUserDataByID(userID string, appMetadata AppMetadata)
|
||||
defer func() {
|
||||
err = res.Body.Close()
|
||||
if err != nil {
|
||||
log.Errorf("error while closing update user app metadata response body: %v", err)
|
||||
log.WithContext(ctx).Errorf("error while closing update user app metadata response body: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
@ -426,9 +426,9 @@ func (am *Auth0Manager) GetUserDataByID(userID string, appMetadata AppMetadata)
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userId and metadata map
|
||||
func (am *Auth0Manager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error {
|
||||
func (am *Auth0Manager) UpdateUserAppMetadata(ctx context.Context, userID string, appMetadata AppMetadata) error {
|
||||
|
||||
jwtToken, err := am.credentials.Authenticate()
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -449,7 +449,7 @@ func (am *Auth0Manager) UpdateUserAppMetadata(userID string, appMetadata AppMeta
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
log.Debugf("updating IdP metadata for user %s", userID)
|
||||
log.WithContext(ctx).Debugf("updating IdP metadata for user %s", userID)
|
||||
|
||||
res, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@ -466,7 +466,7 @@ func (am *Auth0Manager) UpdateUserAppMetadata(userID string, appMetadata AppMeta
|
||||
defer func() {
|
||||
err = res.Body.Close()
|
||||
if err != nil {
|
||||
log.Errorf("error while closing update user app metadata response body: %v", err)
|
||||
log.WithContext(ctx).Errorf("error while closing update user app metadata response body: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
@ -530,9 +530,9 @@ func buildUserExportRequest() (string, error) {
|
||||
}
|
||||
|
||||
func (am *Auth0Manager) createRequest(
|
||||
method string, endpoint string, body io.Reader,
|
||||
ctx context.Context, method string, endpoint string, body io.Reader,
|
||||
) (*http.Request, error) {
|
||||
jwtToken, err := am.credentials.Authenticate()
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -548,8 +548,8 @@ func (am *Auth0Manager) createRequest(
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (am *Auth0Manager) createPostRequest(endpoint string, payloadStr string) (*http.Request, error) {
|
||||
req, err := am.createRequest("POST", endpoint, strings.NewReader(payloadStr))
|
||||
func (am *Auth0Manager) createPostRequest(ctx context.Context, endpoint string, payloadStr string) (*http.Request, error) {
|
||||
req, err := am.createRequest(ctx, "POST", endpoint, strings.NewReader(payloadStr))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -560,20 +560,20 @@ func (am *Auth0Manager) createPostRequest(endpoint string, payloadStr string) (*
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
func (am *Auth0Manager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
||||
payloadString, err := buildUserExportRequest()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
exportJobReq, err := am.createPostRequest("/api/v2/jobs/users-exports", payloadString)
|
||||
exportJobReq, err := am.createPostRequest(ctx, "/api/v2/jobs/users-exports", payloadString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
jobResp, err := am.httpClient.Do(exportJobReq)
|
||||
if err != nil {
|
||||
log.Debugf("Couldn't get job response %v", err)
|
||||
log.WithContext(ctx).Debugf("Couldn't get job response %v", err)
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
@ -583,7 +583,7 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
defer func() {
|
||||
err = jobResp.Body.Close()
|
||||
if err != nil {
|
||||
log.Errorf("error while closing update user app metadata response body: %v", err)
|
||||
log.WithContext(ctx).Errorf("error while closing update user app metadata response body: %v", err)
|
||||
}
|
||||
}()
|
||||
if jobResp.StatusCode != 200 {
|
||||
@ -597,13 +597,13 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
|
||||
body, err := io.ReadAll(jobResp.Body)
|
||||
if err != nil {
|
||||
log.Debugf("Couldn't read export job response; %v", err)
|
||||
log.WithContext(ctx).Debugf("Couldn't read export job response; %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = am.helper.Unmarshal(body, &exportJobResp)
|
||||
if err != nil {
|
||||
log.Debugf("Couldn't unmarshal export job response; %v", err)
|
||||
log.WithContext(ctx).Debugf("Couldn't unmarshal export job response; %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -614,16 +614,16 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
return nil, fmt.Errorf("couldn't get an batch id status %d, %s, response body: %v", jobResp.StatusCode, jobResp.Status, exportJobResp)
|
||||
}
|
||||
|
||||
log.Debugf("batch id status %d, %s, response body: %v", jobResp.StatusCode, jobResp.Status, exportJobResp)
|
||||
log.WithContext(ctx).Debugf("batch id status %d, %s, response body: %v", jobResp.StatusCode, jobResp.Status, exportJobResp)
|
||||
|
||||
done, downloadLink, err := am.checkExportJobStatus(exportJobResp.ID)
|
||||
done, downloadLink, err := am.checkExportJobStatus(ctx, exportJobResp.ID)
|
||||
if err != nil {
|
||||
log.Debugf("Failed at getting status checks from exportJob; %v", err)
|
||||
log.WithContext(ctx).Debugf("Failed at getting status checks from exportJob; %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if done {
|
||||
return am.downloadProfileExport(downloadLink)
|
||||
return am.downloadProfileExport(ctx, downloadLink)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed extracting user profiles from auth0")
|
||||
@ -632,13 +632,13 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
// GetUserByEmail searches users with a given email. If no users have been found, this function returns an empty list.
|
||||
// This function can return multiple users. This is due to the Auth0 internals - there could be multiple users with
|
||||
// the same email but different connections that are considered as separate accounts (e.g., Google and username/password).
|
||||
func (am *Auth0Manager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
jwtToken, err := am.credentials.Authenticate()
|
||||
func (am *Auth0Manager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reqURL := am.authIssuer + "/api/v2/users-by-email?email=" + url.QueryEscape(email)
|
||||
body, err := doGetReq(am.httpClient, reqURL, jwtToken.AccessToken)
|
||||
body, err := doGetReq(ctx, am.httpClient, reqURL, jwtToken.AccessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -651,7 +651,7 @@ func (am *Auth0Manager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
|
||||
err = am.helper.Unmarshal(body, &userResp)
|
||||
if err != nil {
|
||||
log.Debugf("Couldn't unmarshal export job response; %v", err)
|
||||
log.WithContext(ctx).Debugf("Couldn't unmarshal export job response; %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -659,13 +659,13 @@ func (am *Auth0Manager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in Auth0 Idp and sends an invite
|
||||
func (am *Auth0Manager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) {
|
||||
func (am *Auth0Manager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
|
||||
|
||||
payloadString, err := buildCreateUserRequestPayload(email, name, accountID, invitedByEmail)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := am.createPostRequest("/api/v2/users", payloadString)
|
||||
req, err := am.createPostRequest(ctx, "/api/v2/users", payloadString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -676,7 +676,7 @@ func (am *Auth0Manager) CreateUser(email, name, accountID, invitedByEmail string
|
||||
|
||||
resp, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
log.Debugf("Couldn't get job response %v", err)
|
||||
log.WithContext(ctx).Debugf("Couldn't get job response %v", err)
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
@ -686,7 +686,7 @@ func (am *Auth0Manager) CreateUser(email, name, accountID, invitedByEmail string
|
||||
defer func() {
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
log.Errorf("error while closing create user response body: %v", err)
|
||||
log.WithContext(ctx).Errorf("error while closing create user response body: %v", err)
|
||||
}
|
||||
}()
|
||||
if !(resp.StatusCode == 200 || resp.StatusCode == 201) {
|
||||
@ -700,13 +700,13 @@ func (am *Auth0Manager) CreateUser(email, name, accountID, invitedByEmail string
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
log.Debugf("Couldn't read export job response; %v", err)
|
||||
log.WithContext(ctx).Debugf("Couldn't read export job response; %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = am.helper.Unmarshal(body, &createResp)
|
||||
if err != nil {
|
||||
log.Debugf("Couldn't unmarshal export job response; %v", err)
|
||||
log.WithContext(ctx).Debugf("Couldn't unmarshal export job response; %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -714,14 +714,14 @@ func (am *Auth0Manager) CreateUser(email, name, accountID, invitedByEmail string
|
||||
return nil, fmt.Errorf("couldn't create user: response %v", resp)
|
||||
}
|
||||
|
||||
log.Debugf("created user %s in account %s", createResp.ID, accountID)
|
||||
log.WithContext(ctx).Debugf("created user %s in account %s", createResp.ID, accountID)
|
||||
|
||||
return &createResp, nil
|
||||
}
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (am *Auth0Manager) InviteUserByID(userID string) error {
|
||||
func (am *Auth0Manager) InviteUserByID(ctx context.Context, userID string) error {
|
||||
userVerificationReq := userVerificationJobRequest{
|
||||
UserID: userID,
|
||||
}
|
||||
@ -731,14 +731,14 @@ func (am *Auth0Manager) InviteUserByID(userID string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
req, err := am.createPostRequest("/api/v2/jobs/verification-email", string(payload))
|
||||
req, err := am.createPostRequest(ctx, "/api/v2/jobs/verification-email", string(payload))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
log.Debugf("Couldn't get job response %v", err)
|
||||
log.WithContext(ctx).Debugf("Couldn't get job response %v", err)
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
@ -748,7 +748,7 @@ func (am *Auth0Manager) InviteUserByID(userID string) error {
|
||||
defer func() {
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
log.Errorf("error while closing invite user response body: %v", err)
|
||||
log.WithContext(ctx).Errorf("error while closing invite user response body: %v", err)
|
||||
}
|
||||
}()
|
||||
if !(resp.StatusCode == 200 || resp.StatusCode == 201) {
|
||||
@ -762,15 +762,15 @@ func (am *Auth0Manager) InviteUserByID(userID string) error {
|
||||
}
|
||||
|
||||
// DeleteUser from Auth0
|
||||
func (am *Auth0Manager) DeleteUser(userID string) error {
|
||||
req, err := am.createRequest(http.MethodDelete, "/api/v2/users/"+url.QueryEscape(userID), nil)
|
||||
func (am *Auth0Manager) DeleteUser(ctx context.Context, userID string) error {
|
||||
req, err := am.createRequest(ctx, http.MethodDelete, "/api/v2/users/"+url.QueryEscape(userID), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
log.Debugf("execute delete request: %v", err)
|
||||
log.WithContext(ctx).Debugf("execute delete request: %v", err)
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
@ -780,7 +780,7 @@ func (am *Auth0Manager) DeleteUser(userID string) error {
|
||||
defer func() {
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
log.Errorf("close delete request body: %v", err)
|
||||
log.WithContext(ctx).Errorf("close delete request body: %v", err)
|
||||
}
|
||||
}()
|
||||
if resp.StatusCode != 204 {
|
||||
@ -795,20 +795,20 @@ func (am *Auth0Manager) DeleteUser(userID string) error {
|
||||
|
||||
// GetAllConnections returns detailed list of all connections filtered by given params.
|
||||
// Note this method is not part of the IDP Manager interface as this is Auth0 specific.
|
||||
func (am *Auth0Manager) GetAllConnections(strategy []string) ([]Connection, error) {
|
||||
func (am *Auth0Manager) GetAllConnections(ctx context.Context, strategy []string) ([]Connection, error) {
|
||||
var connections []Connection
|
||||
|
||||
q := make(url.Values)
|
||||
q.Set("strategy", strings.Join(strategy, ","))
|
||||
|
||||
req, err := am.createRequest(http.MethodGet, "/api/v2/connections?"+q.Encode(), nil)
|
||||
req, err := am.createRequest(ctx, http.MethodGet, "/api/v2/connections?"+q.Encode(), nil)
|
||||
if err != nil {
|
||||
return connections, err
|
||||
}
|
||||
|
||||
resp, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
log.Debugf("execute get connections request: %v", err)
|
||||
log.WithContext(ctx).Debugf("execute get connections request: %v", err)
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
@ -818,7 +818,7 @@ func (am *Auth0Manager) GetAllConnections(strategy []string) ([]Connection, erro
|
||||
defer func() {
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
log.Errorf("close get connections request body: %v", err)
|
||||
log.WithContext(ctx).Errorf("close get connections request body: %v", err)
|
||||
}
|
||||
}()
|
||||
if resp.StatusCode != 200 {
|
||||
@ -830,13 +830,13 @@ func (am *Auth0Manager) GetAllConnections(strategy []string) ([]Connection, erro
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
log.Debugf("Couldn't read get connections response; %v", err)
|
||||
log.WithContext(ctx).Debugf("Couldn't read get connections response; %v", err)
|
||||
return connections, err
|
||||
}
|
||||
|
||||
err = am.helper.Unmarshal(body, &connections)
|
||||
if err != nil {
|
||||
log.Debugf("Couldn't unmarshal get connection response; %v", err)
|
||||
log.WithContext(ctx).Debugf("Couldn't unmarshal get connection response; %v", err)
|
||||
return connections, err
|
||||
}
|
||||
|
||||
@ -845,23 +845,23 @@ func (am *Auth0Manager) GetAllConnections(strategy []string) ([]Connection, erro
|
||||
|
||||
// checkExportJobStatus checks the status of the job created at CreateExportUsersJob.
|
||||
// If the status is "completed", then return the downloadLink
|
||||
func (am *Auth0Manager) checkExportJobStatus(jobID string) (bool, string, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second)
|
||||
func (am *Auth0Manager) checkExportJobStatus(ctx context.Context, jobID string) (bool, string, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, 90*time.Second)
|
||||
defer cancel()
|
||||
retry := time.NewTicker(10 * time.Second)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Debugf("Export job status stopped...\n")
|
||||
log.WithContext(ctx).Debugf("Export job status stopped...\n")
|
||||
return false, "", ctx.Err()
|
||||
case <-retry.C:
|
||||
jwtToken, err := am.credentials.Authenticate()
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
|
||||
statusURL := am.authIssuer + "/api/v2/jobs/" + jobID
|
||||
body, err := doGetReq(am.httpClient, statusURL, jwtToken.AccessToken)
|
||||
body, err := doGetReq(ctx, am.httpClient, statusURL, jwtToken.AccessToken)
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
@ -872,7 +872,7 @@ func (am *Auth0Manager) checkExportJobStatus(jobID string) (bool, string, error)
|
||||
return false, "", err
|
||||
}
|
||||
|
||||
log.Debugf("current export job status is %v", status.Status)
|
||||
log.WithContext(ctx).Debugf("current export job status is %v", status.Status)
|
||||
|
||||
if status.Status != "completed" {
|
||||
continue
|
||||
@ -884,8 +884,8 @@ func (am *Auth0Manager) checkExportJobStatus(jobID string) (bool, string, error)
|
||||
}
|
||||
|
||||
// downloadProfileExport downloads user profiles from auth0 batch job
|
||||
func (am *Auth0Manager) downloadProfileExport(location string) (map[string][]*UserData, error) {
|
||||
body, err := doGetReq(am.httpClient, location, "")
|
||||
func (am *Auth0Manager) downloadProfileExport(ctx context.Context, location string) (map[string][]*UserData, error) {
|
||||
body, err := doGetReq(ctx, am.httpClient, location, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -927,7 +927,7 @@ func (am *Auth0Manager) downloadProfileExport(location string) (map[string][]*Us
|
||||
}
|
||||
|
||||
// Boilerplate implementation for Get Requests.
|
||||
func doGetReq(client ManagerHTTPClient, url, accessToken string) ([]byte, error) {
|
||||
func doGetReq(ctx context.Context, client ManagerHTTPClient, url, accessToken string) ([]byte, error) {
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -945,7 +945,7 @@ func doGetReq(client ManagerHTTPClient, url, accessToken string) ([]byte, error)
|
||||
defer func() {
|
||||
err = res.Body.Close()
|
||||
if err != nil {
|
||||
log.Errorf("error while closing body for url %s: %v", url, err)
|
||||
log.WithContext(ctx).Errorf("error while closing body for url %s: %v", url, err)
|
||||
}
|
||||
}()
|
||||
body, err := io.ReadAll(res.Body)
|
||||
|
@ -1,6 +1,7 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@ -60,7 +61,7 @@ type mockAuth0Credentials struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (mc *mockAuth0Credentials) Authenticate() (JWTToken, error) {
|
||||
func (mc *mockAuth0Credentials) Authenticate(_ context.Context) (JWTToken, error) {
|
||||
return mc.jwtToken, mc.err
|
||||
}
|
||||
|
||||
@ -126,7 +127,7 @@ func TestAuth0_RequestJWTToken(t *testing.T) {
|
||||
helper: testCase.helper,
|
||||
}
|
||||
|
||||
res, err := creds.requestJWTToken()
|
||||
res, err := creds.requestJWTToken(context.Background())
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
@ -295,7 +296,7 @@ func TestAuth0_Authenticate(t *testing.T) {
|
||||
|
||||
creds.jwtToken.expiresInTime = testCase.inputExpireToken
|
||||
|
||||
_, err := creds.Authenticate()
|
||||
_, err := creds.Authenticate(context.Background())
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
@ -417,7 +418,7 @@ func TestAuth0_UpdateUserAppMetadata(t *testing.T) {
|
||||
helper: testCase.helper,
|
||||
}
|
||||
|
||||
err := manager.UpdateUserAppMetadata("1", testCase.appMetadata)
|
||||
err := manager.UpdateUserAppMetadata(context.Background(), "1", testCase.appMetadata)
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||
|
||||
assert.Equal(t, testCase.expectedReqBody, jwtReqClient.reqBody, "request body should match")
|
||||
|
@ -116,7 +116,7 @@ func (ac *AuthentikCredentials) jwtStillValid() bool {
|
||||
}
|
||||
|
||||
// requestJWTToken performs request to get jwt token.
|
||||
func (ac *AuthentikCredentials) requestJWTToken() (*http.Response, error) {
|
||||
func (ac *AuthentikCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
|
||||
data := url.Values{}
|
||||
data.Set("client_id", ac.clientConfig.ClientID)
|
||||
data.Set("username", ac.clientConfig.Username)
|
||||
@ -131,7 +131,7 @@ func (ac *AuthentikCredentials) requestJWTToken() (*http.Response, error) {
|
||||
}
|
||||
req.Header.Add("content-type", "application/x-www-form-urlencoded")
|
||||
|
||||
log.Debug("requesting new jwt token for authentik idp manager")
|
||||
log.WithContext(ctx).Debug("requesting new jwt token for authentik idp manager")
|
||||
|
||||
resp, err := ac.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@ -183,7 +183,7 @@ func (ac *AuthentikCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (
|
||||
}
|
||||
|
||||
// Authenticate retrieves access token to use the authentik management API.
|
||||
func (ac *AuthentikCredentials) Authenticate() (JWTToken, error) {
|
||||
func (ac *AuthentikCredentials) Authenticate(ctx context.Context) (JWTToken, error) {
|
||||
ac.mux.Lock()
|
||||
defer ac.mux.Unlock()
|
||||
|
||||
@ -197,7 +197,7 @@ func (ac *AuthentikCredentials) Authenticate() (JWTToken, error) {
|
||||
return ac.jwtToken, nil
|
||||
}
|
||||
|
||||
resp, err := ac.requestJWTToken()
|
||||
resp, err := ac.requestJWTToken(ctx)
|
||||
if err != nil {
|
||||
return ac.jwtToken, err
|
||||
}
|
||||
@ -214,13 +214,13 @@ func (ac *AuthentikCredentials) Authenticate() (JWTToken, error) {
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
func (am *AuthentikManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
|
||||
func (am *AuthentikManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from authentik via ID.
|
||||
func (am *AuthentikManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
ctx, err := am.authenticationContext()
|
||||
func (am *AuthentikManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
ctx, err := am.authenticationContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -254,8 +254,8 @@ func (am *AuthentikManager) GetUserDataByID(userID string, appMetadata AppMetada
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile.
|
||||
func (am *AuthentikManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
users, err := am.getAllUsers()
|
||||
func (am *AuthentikManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
|
||||
users, err := am.getAllUsers(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -274,8 +274,8 @@ func (am *AuthentikManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (am *AuthentikManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
users, err := am.getAllUsers()
|
||||
func (am *AuthentikManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
||||
users, err := am.getAllUsers(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -291,12 +291,12 @@ func (am *AuthentikManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
}
|
||||
|
||||
// getAllUsers returns all users in a Authentik account.
|
||||
func (am *AuthentikManager) getAllUsers() ([]*UserData, error) {
|
||||
func (am *AuthentikManager) getAllUsers(ctx context.Context) ([]*UserData, error) {
|
||||
users := make([]*UserData, 0)
|
||||
|
||||
page := int32(1)
|
||||
for {
|
||||
ctx, err := am.authenticationContext()
|
||||
ctx, err := am.authenticationContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -329,14 +329,14 @@ func (am *AuthentikManager) getAllUsers() ([]*UserData, error) {
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in authentik Idp and sends an invitation.
|
||||
func (am *AuthentikManager) CreateUser(_, _, _, _ string) (*UserData, error) {
|
||||
func (am *AuthentikManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
|
||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
// If no users have been found, this function returns an empty list.
|
||||
func (am *AuthentikManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
ctx, err := am.authenticationContext()
|
||||
func (am *AuthentikManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
|
||||
ctx, err := am.authenticationContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -368,13 +368,13 @@ func (am *AuthentikManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (am *AuthentikManager) InviteUserByID(_ string) error {
|
||||
func (am *AuthentikManager) InviteUserByID(_ context.Context, _ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser from Authentik
|
||||
func (am *AuthentikManager) DeleteUser(userID string) error {
|
||||
ctx, err := am.authenticationContext()
|
||||
func (am *AuthentikManager) DeleteUser(ctx context.Context, userID string) error {
|
||||
ctx, err := am.authenticationContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -404,8 +404,8 @@ func (am *AuthentikManager) DeleteUser(userID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *AuthentikManager) authenticationContext() (context.Context, error) {
|
||||
jwtToken, err := am.credentials.Authenticate()
|
||||
func (am *AuthentikManager) authenticationContext(ctx context.Context) (context.Context, error) {
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
@ -138,7 +139,7 @@ func TestAuthentikRequestJWTToken(t *testing.T) {
|
||||
helper: testCase.helper,
|
||||
}
|
||||
|
||||
resp, err := creds.requestJWTToken()
|
||||
resp, err := creds.requestJWTToken(context.Background())
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
@ -304,7 +305,7 @@ func TestAuthentikAuthenticate(t *testing.T) {
|
||||
}
|
||||
creds.jwtToken.expiresInTime = testCase.inputExpireToken
|
||||
|
||||
_, err := creds.Authenticate()
|
||||
_, err := creds.Authenticate(context.Background())
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
|
@ -1,6 +1,7 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@ -110,7 +111,7 @@ func (ac *AzureCredentials) jwtStillValid() bool {
|
||||
}
|
||||
|
||||
// requestJWTToken performs request to get jwt token.
|
||||
func (ac *AzureCredentials) requestJWTToken() (*http.Response, error) {
|
||||
func (ac *AzureCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
|
||||
data := url.Values{}
|
||||
data.Set("client_id", ac.clientConfig.ClientID)
|
||||
data.Set("client_secret", ac.clientConfig.ClientSecret)
|
||||
@ -132,7 +133,7 @@ func (ac *AzureCredentials) requestJWTToken() (*http.Response, error) {
|
||||
}
|
||||
req.Header.Add("content-type", "application/x-www-form-urlencoded")
|
||||
|
||||
log.Debug("requesting new jwt token for azure idp manager")
|
||||
log.WithContext(ctx).Debug("requesting new jwt token for azure idp manager")
|
||||
|
||||
resp, err := ac.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@ -184,7 +185,7 @@ func (ac *AzureCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTT
|
||||
}
|
||||
|
||||
// Authenticate retrieves access token to use the azure Management API.
|
||||
func (ac *AzureCredentials) Authenticate() (JWTToken, error) {
|
||||
func (ac *AzureCredentials) Authenticate(ctx context.Context) (JWTToken, error) {
|
||||
ac.mux.Lock()
|
||||
defer ac.mux.Unlock()
|
||||
|
||||
@ -198,7 +199,7 @@ func (ac *AzureCredentials) Authenticate() (JWTToken, error) {
|
||||
return ac.jwtToken, nil
|
||||
}
|
||||
|
||||
resp, err := ac.requestJWTToken()
|
||||
resp, err := ac.requestJWTToken(ctx)
|
||||
if err != nil {
|
||||
return ac.jwtToken, err
|
||||
}
|
||||
@ -215,16 +216,16 @@ func (ac *AzureCredentials) Authenticate() (JWTToken, error) {
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in azure AD Idp.
|
||||
func (am *AzureManager) CreateUser(_, _, _, _ string) (*UserData, error) {
|
||||
func (am *AzureManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
|
||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from keycloak via ID.
|
||||
func (am *AzureManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
func (am *AzureManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
q := url.Values{}
|
||||
q.Add("$select", profileFields)
|
||||
|
||||
body, err := am.get("users/"+userID, q)
|
||||
body, err := am.get(ctx, "users/"+userID, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -247,11 +248,11 @@ func (am *AzureManager) GetUserDataByID(userID string, appMetadata AppMetadata)
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
// If no users have been found, this function returns an empty list.
|
||||
func (am *AzureManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
func (am *AzureManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
|
||||
q := url.Values{}
|
||||
q.Add("$select", profileFields)
|
||||
|
||||
body, err := am.get("users/"+email, q)
|
||||
body, err := am.get(ctx, "users/"+email, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -273,8 +274,8 @@ func (am *AzureManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile.
|
||||
func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
users, err := am.getAllUsers()
|
||||
func (am *AzureManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
|
||||
users, err := am.getAllUsers(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -293,8 +294,8 @@ func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (am *AzureManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
users, err := am.getAllUsers()
|
||||
func (am *AzureManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
||||
users, err := am.getAllUsers(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -310,19 +311,19 @@ func (am *AzureManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID.
|
||||
func (am *AzureManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
|
||||
func (am *AzureManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (am *AzureManager) InviteUserByID(_ string) error {
|
||||
func (am *AzureManager) InviteUserByID(_ context.Context, _ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser from Azure.
|
||||
func (am *AzureManager) DeleteUser(userID string) error {
|
||||
jwtToken, err := am.credentials.Authenticate()
|
||||
func (am *AzureManager) DeleteUser(ctx context.Context, userID string) error {
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -335,7 +336,7 @@ func (am *AzureManager) DeleteUser(userID string) error {
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
log.Debugf("delete idp user %s", userID)
|
||||
log.WithContext(ctx).Debugf("delete idp user %s", userID)
|
||||
|
||||
resp, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@ -358,7 +359,7 @@ func (am *AzureManager) DeleteUser(userID string) error {
|
||||
}
|
||||
|
||||
// getAllUsers returns all users in an Azure AD account.
|
||||
func (am *AzureManager) getAllUsers() ([]*UserData, error) {
|
||||
func (am *AzureManager) getAllUsers(ctx context.Context) ([]*UserData, error) {
|
||||
users := make([]*UserData, 0)
|
||||
|
||||
q := url.Values{}
|
||||
@ -366,7 +367,7 @@ func (am *AzureManager) getAllUsers() ([]*UserData, error) {
|
||||
q.Add("$top", "500")
|
||||
|
||||
for nextLink := "users"; nextLink != ""; {
|
||||
body, err := am.get(nextLink, q)
|
||||
body, err := am.get(ctx, nextLink, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -391,8 +392,8 @@ func (am *AzureManager) getAllUsers() ([]*UserData, error) {
|
||||
}
|
||||
|
||||
// get perform Get requests.
|
||||
func (am *AzureManager) get(resource string, q url.Values) ([]byte, error) {
|
||||
jwtToken, err := am.credentials.Authenticate()
|
||||
func (am *AzureManager) get(ctx context.Context, resource string, q url.Values) ([]byte, error) {
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
@ -101,7 +102,7 @@ func TestAzureAuthenticate(t *testing.T) {
|
||||
}
|
||||
creds.jwtToken.expiresInTime = testCase.inputExpireToken
|
||||
|
||||
_, err := creds.Authenticate()
|
||||
_, err := creds.Authenticate(context.Background())
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
|
@ -39,12 +39,12 @@ type GoogleWorkspaceCredentials struct {
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
func (gc *GoogleWorkspaceCredentials) Authenticate() (JWTToken, error) {
|
||||
func (gc *GoogleWorkspaceCredentials) Authenticate(_ context.Context) (JWTToken, error) {
|
||||
return JWTToken{}, nil
|
||||
}
|
||||
|
||||
// NewGoogleWorkspaceManager creates a new instance of the GoogleWorkspaceManager.
|
||||
func NewGoogleWorkspaceManager(config GoogleWorkspaceClientConfig, appMetrics telemetry.AppMetrics) (*GoogleWorkspaceManager, error) {
|
||||
func NewGoogleWorkspaceManager(ctx context.Context, config GoogleWorkspaceClientConfig, appMetrics telemetry.AppMetrics) (*GoogleWorkspaceManager, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
@ -66,7 +66,7 @@ func NewGoogleWorkspaceManager(config GoogleWorkspaceClientConfig, appMetrics te
|
||||
}
|
||||
|
||||
// Create a new Admin SDK Directory service client
|
||||
adminCredentials, err := getGoogleCredentials(config.ServiceAccountKey)
|
||||
adminCredentials, err := getGoogleCredentials(ctx, config.ServiceAccountKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -90,12 +90,12 @@ func NewGoogleWorkspaceManager(config GoogleWorkspaceClientConfig, appMetrics te
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
func (gm *GoogleWorkspaceManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
|
||||
func (gm *GoogleWorkspaceManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from Google Workspace via ID.
|
||||
func (gm *GoogleWorkspaceManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
func (gm *GoogleWorkspaceManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
user, err := gm.usersService.Get(userID).Do()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -112,7 +112,7 @@ func (gm *GoogleWorkspaceManager) GetUserDataByID(userID string, appMetadata App
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile.
|
||||
func (gm *GoogleWorkspaceManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
func (gm *GoogleWorkspaceManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) {
|
||||
users, err := gm.getAllUsers()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -132,7 +132,7 @@ func (gm *GoogleWorkspaceManager) GetAccount(accountID string) ([]*UserData, err
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (gm *GoogleWorkspaceManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
func (gm *GoogleWorkspaceManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) {
|
||||
users, err := gm.getAllUsers()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -177,13 +177,13 @@ func (gm *GoogleWorkspaceManager) getAllUsers() ([]*UserData, error) {
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in Google Workspace and sends an invitation.
|
||||
func (gm *GoogleWorkspaceManager) CreateUser(_, _, _, _ string) (*UserData, error) {
|
||||
func (gm *GoogleWorkspaceManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
|
||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
// If no users have been found, this function returns an empty list.
|
||||
func (gm *GoogleWorkspaceManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
func (gm *GoogleWorkspaceManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) {
|
||||
user, err := gm.usersService.Get(email).Do()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -201,12 +201,12 @@ func (gm *GoogleWorkspaceManager) GetUserByEmail(email string) ([]*UserData, err
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (gm *GoogleWorkspaceManager) InviteUserByID(_ string) error {
|
||||
func (gm *GoogleWorkspaceManager) InviteUserByID(_ context.Context, _ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser from GoogleWorkspace.
|
||||
func (gm *GoogleWorkspaceManager) DeleteUser(userID string) error {
|
||||
func (gm *GoogleWorkspaceManager) DeleteUser(_ context.Context, userID string) error {
|
||||
if err := gm.usersService.Delete(userID).Do(); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -222,8 +222,8 @@ func (gm *GoogleWorkspaceManager) DeleteUser(userID string) error {
|
||||
// It decodes the base64-encoded serviceAccountKey and attempts to obtain credentials using it.
|
||||
// If that fails, it falls back to using the default Google credentials path.
|
||||
// It returns the retrieved credentials or an error if unsuccessful.
|
||||
func getGoogleCredentials(serviceAccountKey string) (*google.Credentials, error) {
|
||||
log.Debug("retrieving google credentials from the base64 encoded service account key")
|
||||
func getGoogleCredentials(ctx context.Context, serviceAccountKey string) (*google.Credentials, error) {
|
||||
log.WithContext(ctx).Debug("retrieving google credentials from the base64 encoded service account key")
|
||||
decodeKey, err := base64.StdEncoding.DecodeString(serviceAccountKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode service account key: %w", err)
|
||||
@ -239,8 +239,8 @@ func getGoogleCredentials(serviceAccountKey string) (*google.Credentials, error)
|
||||
return creds, nil
|
||||
}
|
||||
|
||||
log.Debugf("failed to retrieve Google credentials from ServiceAccountKey: %v", err)
|
||||
log.Debug("falling back to default google credentials location")
|
||||
log.WithContext(ctx).Debugf("failed to retrieve Google credentials from ServiceAccountKey: %v", err)
|
||||
log.WithContext(ctx).Debug("falling back to default google credentials location")
|
||||
|
||||
creds, err = google.FindDefaultCredentials(
|
||||
context.Background(),
|
||||
|
@ -1,6 +1,7 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
@ -16,14 +17,14 @@ const (
|
||||
|
||||
// Manager idp manager interface
|
||||
type Manager interface {
|
||||
UpdateUserAppMetadata(userId string, appMetadata AppMetadata) error
|
||||
GetUserDataByID(userId string, appMetadata AppMetadata) (*UserData, error)
|
||||
GetAccount(accountId string) ([]*UserData, error)
|
||||
GetAllAccounts() (map[string][]*UserData, error)
|
||||
CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error)
|
||||
GetUserByEmail(email string) ([]*UserData, error)
|
||||
InviteUserByID(userID string) error
|
||||
DeleteUser(userID string) error
|
||||
UpdateUserAppMetadata(ctx context.Context, userId string, appMetadata AppMetadata) error
|
||||
GetUserDataByID(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error)
|
||||
GetAccount(ctx context.Context, accountId string) ([]*UserData, error)
|
||||
GetAllAccounts(ctx context.Context) (map[string][]*UserData, error)
|
||||
CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error)
|
||||
GetUserByEmail(ctx context.Context, email string) ([]*UserData, error)
|
||||
InviteUserByID(ctx context.Context, userID string) error
|
||||
DeleteUser(ctx context.Context, userID string) error
|
||||
}
|
||||
|
||||
// ClientConfig defines common client configuration for all IdP manager
|
||||
@ -51,7 +52,7 @@ type Config struct {
|
||||
|
||||
// ManagerCredentials interface that authenticates using the credential of each type of idp
|
||||
type ManagerCredentials interface {
|
||||
Authenticate() (JWTToken, error)
|
||||
Authenticate(ctx context.Context) (JWTToken, error)
|
||||
}
|
||||
|
||||
// ManagerHTTPClient http client interface for API calls
|
||||
@ -91,7 +92,7 @@ type JWTToken struct {
|
||||
}
|
||||
|
||||
// NewManager returns a new idp manager based on the configuration that it receives
|
||||
func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error) {
|
||||
func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetrics) (Manager, error) {
|
||||
if config.ClientConfig != nil {
|
||||
config.ClientConfig.Issuer = strings.TrimSuffix(config.ClientConfig.Issuer, "/")
|
||||
}
|
||||
@ -175,7 +176,7 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error)
|
||||
ServiceAccountKey: config.ExtraConfig["ServiceAccountKey"],
|
||||
CustomerID: config.ExtraConfig["CustomerId"],
|
||||
}
|
||||
return NewGoogleWorkspaceManager(googleClientConfig, appMetrics)
|
||||
return NewGoogleWorkspaceManager(ctx, googleClientConfig, appMetrics)
|
||||
case "jumpcloud":
|
||||
jumpcloudConfig := JumpCloudClientConfig{
|
||||
APIToken: config.ExtraConfig["ApiToken"],
|
||||
|
@ -74,7 +74,7 @@ func NewJumpCloudManager(config JumpCloudClientConfig, appMetrics telemetry.AppM
|
||||
}
|
||||
|
||||
// Authenticate retrieves access token to use the JumpCloud user API.
|
||||
func (jc *JumpCloudCredentials) Authenticate() (JWTToken, error) {
|
||||
func (jc *JumpCloudCredentials) Authenticate(_ context.Context) (JWTToken, error) {
|
||||
return JWTToken{}, nil
|
||||
}
|
||||
|
||||
@ -85,12 +85,12 @@ func (jm *JumpCloudManager) authenticationContext() context.Context {
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
func (jm *JumpCloudManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
|
||||
func (jm *JumpCloudManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from JumpCloud via ID.
|
||||
func (jm *JumpCloudManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
func (jm *JumpCloudManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
authCtx := jm.authenticationContext()
|
||||
user, resp, err := jm.client.SystemusersApi.SystemusersGet(authCtx, userID, contentType, accept, nil)
|
||||
if err != nil {
|
||||
@ -116,7 +116,7 @@ func (jm *JumpCloudManager) GetUserDataByID(userID string, appMetadata AppMetada
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile.
|
||||
func (jm *JumpCloudManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
func (jm *JumpCloudManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) {
|
||||
authCtx := jm.authenticationContext()
|
||||
userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil)
|
||||
if err != nil {
|
||||
@ -148,7 +148,7 @@ func (jm *JumpCloudManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (jm *JumpCloudManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
func (jm *JumpCloudManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) {
|
||||
authCtx := jm.authenticationContext()
|
||||
userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil)
|
||||
if err != nil {
|
||||
@ -177,13 +177,13 @@ func (jm *JumpCloudManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in JumpCloud Idp and sends an invitation.
|
||||
func (jm *JumpCloudManager) CreateUser(_, _, _, _ string) (*UserData, error) {
|
||||
func (jm *JumpCloudManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
|
||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
// If no users have been found, this function returns an empty list.
|
||||
func (jm *JumpCloudManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
func (jm *JumpCloudManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) {
|
||||
searchFilter := map[string]interface{}{
|
||||
"searchFilter": map[string]interface{}{
|
||||
"filter": []string{email},
|
||||
@ -219,12 +219,12 @@ func (jm *JumpCloudManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (jm *JumpCloudManager) InviteUserByID(_ string) error {
|
||||
func (jm *JumpCloudManager) InviteUserByID(_ context.Context, _ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser from jumpCloud directory
|
||||
func (jm *JumpCloudManager) DeleteUser(userID string) error {
|
||||
func (jm *JumpCloudManager) DeleteUser(_ context.Context, userID string) error {
|
||||
authCtx := jm.authenticationContext()
|
||||
_, resp, err := jm.client.SystemusersApi.SystemusersDelete(authCtx, userID, contentType, accept, nil)
|
||||
if err != nil {
|
||||
|
@ -1,6 +1,7 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@ -109,7 +110,7 @@ func (kc *KeycloakCredentials) jwtStillValid() bool {
|
||||
}
|
||||
|
||||
// requestJWTToken performs request to get jwt token.
|
||||
func (kc *KeycloakCredentials) requestJWTToken() (*http.Response, error) {
|
||||
func (kc *KeycloakCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
|
||||
data := url.Values{}
|
||||
data.Set("client_id", kc.clientConfig.ClientID)
|
||||
data.Set("client_secret", kc.clientConfig.ClientSecret)
|
||||
@ -122,7 +123,7 @@ func (kc *KeycloakCredentials) requestJWTToken() (*http.Response, error) {
|
||||
}
|
||||
req.Header.Add("content-type", "application/x-www-form-urlencoded")
|
||||
|
||||
log.Debug("requesting new jwt token for keycloak idp manager")
|
||||
log.WithContext(ctx).Debug("requesting new jwt token for keycloak idp manager")
|
||||
|
||||
resp, err := kc.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@ -174,7 +175,7 @@ func (kc *KeycloakCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (J
|
||||
}
|
||||
|
||||
// Authenticate retrieves access token to use the keycloak Management API.
|
||||
func (kc *KeycloakCredentials) Authenticate() (JWTToken, error) {
|
||||
func (kc *KeycloakCredentials) Authenticate(ctx context.Context) (JWTToken, error) {
|
||||
kc.mux.Lock()
|
||||
defer kc.mux.Unlock()
|
||||
|
||||
@ -188,7 +189,7 @@ func (kc *KeycloakCredentials) Authenticate() (JWTToken, error) {
|
||||
return kc.jwtToken, nil
|
||||
}
|
||||
|
||||
resp, err := kc.requestJWTToken()
|
||||
resp, err := kc.requestJWTToken(ctx)
|
||||
if err != nil {
|
||||
return kc.jwtToken, err
|
||||
}
|
||||
@ -205,18 +206,18 @@ func (kc *KeycloakCredentials) Authenticate() (JWTToken, error) {
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in keycloak Idp and sends an invite.
|
||||
func (km *KeycloakManager) CreateUser(_, _, _, _ string) (*UserData, error) {
|
||||
func (km *KeycloakManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
|
||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
// If no users have been found, this function returns an empty list.
|
||||
func (km *KeycloakManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
func (km *KeycloakManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
|
||||
q := url.Values{}
|
||||
q.Add("email", email)
|
||||
q.Add("exact", "true")
|
||||
|
||||
body, err := km.get("users", q)
|
||||
body, err := km.get(ctx, "users", q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -240,8 +241,8 @@ func (km *KeycloakManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from keycloak via ID.
|
||||
func (km *KeycloakManager) GetUserDataByID(userID string, _ AppMetadata) (*UserData, error) {
|
||||
body, err := km.get("users/"+userID, nil)
|
||||
func (km *KeycloakManager) GetUserDataByID(ctx context.Context, userID string, _ AppMetadata) (*UserData, error) {
|
||||
body, err := km.get(ctx, "users/"+userID, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -260,8 +261,8 @@ func (km *KeycloakManager) GetUserDataByID(userID string, _ AppMetadata) (*UserD
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given account profile.
|
||||
func (km *KeycloakManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
profiles, err := km.fetchAllUserProfiles()
|
||||
func (km *KeycloakManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
|
||||
profiles, err := km.fetchAllUserProfiles(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -283,8 +284,8 @@ func (km *KeycloakManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (km *KeycloakManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
profiles, err := km.fetchAllUserProfiles()
|
||||
func (km *KeycloakManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
||||
profiles, err := km.fetchAllUserProfiles(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -303,19 +304,19 @@ func (km *KeycloakManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
func (km *KeycloakManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
|
||||
func (km *KeycloakManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (km *KeycloakManager) InviteUserByID(_ string) error {
|
||||
func (km *KeycloakManager) InviteUserByID(_ context.Context, _ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser from Keycloak by user ID.
|
||||
func (km *KeycloakManager) DeleteUser(userID string) error {
|
||||
jwtToken, err := km.credentials.Authenticate()
|
||||
func (km *KeycloakManager) DeleteUser(ctx context.Context, userID string) error {
|
||||
jwtToken, err := km.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -353,8 +354,8 @@ func (km *KeycloakManager) DeleteUser(userID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (km *KeycloakManager) fetchAllUserProfiles() ([]keycloakProfile, error) {
|
||||
totalUsers, err := km.totalUsersCount()
|
||||
func (km *KeycloakManager) fetchAllUserProfiles(ctx context.Context) ([]keycloakProfile, error) {
|
||||
totalUsers, err := km.totalUsersCount(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -362,7 +363,7 @@ func (km *KeycloakManager) fetchAllUserProfiles() ([]keycloakProfile, error) {
|
||||
q := url.Values{}
|
||||
q.Add("max", fmt.Sprint(*totalUsers))
|
||||
|
||||
body, err := km.get("users", q)
|
||||
body, err := km.get(ctx, "users", q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -377,8 +378,8 @@ func (km *KeycloakManager) fetchAllUserProfiles() ([]keycloakProfile, error) {
|
||||
}
|
||||
|
||||
// get perform Get requests.
|
||||
func (km *KeycloakManager) get(resource string, q url.Values) ([]byte, error) {
|
||||
jwtToken, err := km.credentials.Authenticate()
|
||||
func (km *KeycloakManager) get(ctx context.Context, resource string, q url.Values) ([]byte, error) {
|
||||
jwtToken, err := km.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -414,8 +415,8 @@ func (km *KeycloakManager) get(resource string, q url.Values) ([]byte, error) {
|
||||
|
||||
// totalUsersCount returns the total count of all user created.
|
||||
// Used when fetching all registered accounts with pagination.
|
||||
func (km *KeycloakManager) totalUsersCount() (*int, error) {
|
||||
body, err := km.get("users/count", nil)
|
||||
func (km *KeycloakManager) totalUsersCount(ctx context.Context) (*int, error) {
|
||||
body, err := km.get(ctx, "users/count", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
@ -128,7 +129,7 @@ func TestKeycloakRequestJWTToken(t *testing.T) {
|
||||
helper: testCase.helper,
|
||||
}
|
||||
|
||||
resp, err := creds.requestJWTToken()
|
||||
resp, err := creds.requestJWTToken(context.Background())
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
@ -294,7 +295,7 @@ func TestKeycloakAuthenticate(t *testing.T) {
|
||||
}
|
||||
creds.jwtToken.expiresInTime = testCase.inputExpireToken
|
||||
|
||||
_, err := creds.Authenticate()
|
||||
_, err := creds.Authenticate(context.Background())
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
|
@ -1,77 +1,79 @@
|
||||
package idp
|
||||
|
||||
import "context"
|
||||
|
||||
// MockIDP is a mock implementation of the IDP interface
|
||||
type MockIDP struct {
|
||||
UpdateUserAppMetadataFunc func(userId string, appMetadata AppMetadata) error
|
||||
GetUserDataByIDFunc func(userId string, appMetadata AppMetadata) (*UserData, error)
|
||||
GetAccountFunc func(accountId string) ([]*UserData, error)
|
||||
GetAllAccountsFunc func() (map[string][]*UserData, error)
|
||||
CreateUserFunc func(email, name, accountID, invitedByEmail string) (*UserData, error)
|
||||
GetUserByEmailFunc func(email string) ([]*UserData, error)
|
||||
InviteUserByIDFunc func(userID string) error
|
||||
DeleteUserFunc func(userID string) error
|
||||
UpdateUserAppMetadataFunc func(ctx context.Context, userId string, appMetadata AppMetadata) error
|
||||
GetUserDataByIDFunc func(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error)
|
||||
GetAccountFunc func(ctx context.Context, accountId string) ([]*UserData, error)
|
||||
GetAllAccountsFunc func(ctx context.Context) (map[string][]*UserData, error)
|
||||
CreateUserFunc func(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error)
|
||||
GetUserByEmailFunc func(ctx context.Context, email string) ([]*UserData, error)
|
||||
InviteUserByIDFunc func(ctx context.Context, userID string) error
|
||||
DeleteUserFunc func(ctx context.Context, userID string) error
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata is a mock implementation of the IDP interface UpdateUserAppMetadata method
|
||||
func (m *MockIDP) UpdateUserAppMetadata(userId string, appMetadata AppMetadata) error {
|
||||
func (m *MockIDP) UpdateUserAppMetadata(ctx context.Context, userId string, appMetadata AppMetadata) error {
|
||||
if m.UpdateUserAppMetadataFunc != nil {
|
||||
return m.UpdateUserAppMetadataFunc(userId, appMetadata)
|
||||
return m.UpdateUserAppMetadataFunc(ctx, userId, appMetadata)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserDataByID is a mock implementation of the IDP interface GetUserDataByID method
|
||||
func (m *MockIDP) GetUserDataByID(userId string, appMetadata AppMetadata) (*UserData, error) {
|
||||
func (m *MockIDP) GetUserDataByID(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error) {
|
||||
if m.GetUserDataByIDFunc != nil {
|
||||
return m.GetUserDataByIDFunc(userId, appMetadata)
|
||||
return m.GetUserDataByIDFunc(ctx, userId, appMetadata)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// GetAccount is a mock implementation of the IDP interface GetAccount method
|
||||
func (m *MockIDP) GetAccount(accountId string) ([]*UserData, error) {
|
||||
func (m *MockIDP) GetAccount(ctx context.Context, accountId string) ([]*UserData, error) {
|
||||
if m.GetAccountFunc != nil {
|
||||
return m.GetAccountFunc(accountId)
|
||||
return m.GetAccountFunc(ctx, accountId)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// GetAllAccounts is a mock implementation of the IDP interface GetAllAccounts method
|
||||
func (m *MockIDP) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
func (m *MockIDP) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
||||
if m.GetAllAccountsFunc != nil {
|
||||
return m.GetAllAccountsFunc()
|
||||
return m.GetAllAccountsFunc(ctx)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// CreateUser is a mock implementation of the IDP interface CreateUser method
|
||||
func (m *MockIDP) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) {
|
||||
func (m *MockIDP) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
|
||||
if m.CreateUserFunc != nil {
|
||||
return m.CreateUserFunc(email, name, accountID, invitedByEmail)
|
||||
return m.CreateUserFunc(ctx, email, name, accountID, invitedByEmail)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// GetUserByEmail is a mock implementation of the IDP interface GetUserByEmail method
|
||||
func (m *MockIDP) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
func (m *MockIDP) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
|
||||
if m.GetUserByEmailFunc != nil {
|
||||
return m.GetUserByEmailFunc(email)
|
||||
return m.GetUserByEmailFunc(ctx, email)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// InviteUserByID is a mock implementation of the IDP interface InviteUserByID method
|
||||
func (m *MockIDP) InviteUserByID(userID string) error {
|
||||
func (m *MockIDP) InviteUserByID(ctx context.Context, userID string) error {
|
||||
if m.InviteUserByIDFunc != nil {
|
||||
return m.InviteUserByIDFunc(userID)
|
||||
return m.InviteUserByIDFunc(ctx, userID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteUser is a mock implementation of the IDP interface DeleteUser method
|
||||
func (m *MockIDP) DeleteUser(userID string) error {
|
||||
func (m *MockIDP) DeleteUser(ctx context.Context, userID string) error {
|
||||
if m.DeleteUserFunc != nil {
|
||||
return m.DeleteUserFunc(userID)
|
||||
return m.DeleteUserFunc(ctx, userID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -94,17 +94,17 @@ func NewOktaManager(config OktaClientConfig, appMetrics telemetry.AppMetrics) (*
|
||||
}
|
||||
|
||||
// Authenticate retrieves access token to use the okta user API.
|
||||
func (oc *OktaCredentials) Authenticate() (JWTToken, error) {
|
||||
func (oc *OktaCredentials) Authenticate(_ context.Context) (JWTToken, error) {
|
||||
return JWTToken{}, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in okta Idp and sends an invitation.
|
||||
func (om *OktaManager) CreateUser(_, _, _, _ string) (*UserData, error) {
|
||||
func (om *OktaManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
|
||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from keycloak via ID.
|
||||
func (om *OktaManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
func (om *OktaManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
user, resp, err := om.client.User.GetUser(context.Background(), userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -132,7 +132,7 @@ func (om *OktaManager) GetUserDataByID(userID string, appMetadata AppMetadata) (
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
// If no users have been found, this function returns an empty list.
|
||||
func (om *OktaManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
func (om *OktaManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) {
|
||||
user, resp, err := om.client.User.GetUser(context.Background(), url.QueryEscape(email))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -160,7 +160,7 @@ func (om *OktaManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile.
|
||||
func (om *OktaManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
func (om *OktaManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) {
|
||||
users, err := om.getAllUsers()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -180,7 +180,7 @@ func (om *OktaManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (om *OktaManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
func (om *OktaManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) {
|
||||
users, err := om.getAllUsers()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -242,18 +242,18 @@ func (om *OktaManager) getAllUsers() ([]*UserData, error) {
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
func (om *OktaManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error {
|
||||
func (om *OktaManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (om *OktaManager) InviteUserByID(_ string) error {
|
||||
func (om *OktaManager) InviteUserByID(_ context.Context, _ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser from Okta
|
||||
func (om *OktaManager) DeleteUser(userID string) error {
|
||||
func (om *OktaManager) DeleteUser(_ context.Context, userID string) error {
|
||||
resp, err := om.client.User.DeactivateOrDeleteUser(context.Background(), userID, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -1,6 +1,7 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@ -149,7 +150,7 @@ func (zc *ZitadelCredentials) jwtStillValid() bool {
|
||||
}
|
||||
|
||||
// requestJWTToken performs request to get jwt token.
|
||||
func (zc *ZitadelCredentials) requestJWTToken() (*http.Response, error) {
|
||||
func (zc *ZitadelCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
|
||||
data := url.Values{}
|
||||
data.Set("client_id", zc.clientConfig.ClientID)
|
||||
data.Set("client_secret", zc.clientConfig.ClientSecret)
|
||||
@ -163,7 +164,7 @@ func (zc *ZitadelCredentials) requestJWTToken() (*http.Response, error) {
|
||||
}
|
||||
req.Header.Add("content-type", "application/x-www-form-urlencoded")
|
||||
|
||||
log.Debug("requesting new jwt token for zitadel idp manager")
|
||||
log.WithContext(ctx).Debug("requesting new jwt token for zitadel idp manager")
|
||||
|
||||
resp, err := zc.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@ -215,7 +216,7 @@ func (zc *ZitadelCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JW
|
||||
}
|
||||
|
||||
// Authenticate retrieves access token to use the Zitadel Management API.
|
||||
func (zc *ZitadelCredentials) Authenticate() (JWTToken, error) {
|
||||
func (zc *ZitadelCredentials) Authenticate(ctx context.Context) (JWTToken, error) {
|
||||
zc.mux.Lock()
|
||||
defer zc.mux.Unlock()
|
||||
|
||||
@ -229,7 +230,7 @@ func (zc *ZitadelCredentials) Authenticate() (JWTToken, error) {
|
||||
return zc.jwtToken, nil
|
||||
}
|
||||
|
||||
resp, err := zc.requestJWTToken()
|
||||
resp, err := zc.requestJWTToken(ctx)
|
||||
if err != nil {
|
||||
return zc.jwtToken, err
|
||||
}
|
||||
@ -246,7 +247,7 @@ func (zc *ZitadelCredentials) Authenticate() (JWTToken, error) {
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in zitadel Idp and sends an invite via Zitadel.
|
||||
func (zm *ZitadelManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) {
|
||||
func (zm *ZitadelManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
|
||||
firstLast := strings.SplitN(name, " ", 2)
|
||||
|
||||
var addUser = map[string]any{
|
||||
@ -269,7 +270,7 @@ func (zm *ZitadelManager) CreateUser(email, name, accountID, invitedByEmail stri
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body, err := zm.post("users/human/_import", string(payload))
|
||||
body, err := zm.post(ctx, "users/human/_import", string(payload))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -300,7 +301,7 @@ func (zm *ZitadelManager) CreateUser(email, name, accountID, invitedByEmail stri
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
// If no users have been found, this function returns an empty list.
|
||||
func (zm *ZitadelManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
func (zm *ZitadelManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
|
||||
searchByEmail := zitadelAttributes{
|
||||
"queries": {
|
||||
{
|
||||
@ -316,7 +317,7 @@ func (zm *ZitadelManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body, err := zm.post("users/_search", string(payload))
|
||||
body, err := zm.post(ctx, "users/_search", string(payload))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -340,8 +341,8 @@ func (zm *ZitadelManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from zitadel via ID.
|
||||
func (zm *ZitadelManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
body, err := zm.get("users/"+userID, nil)
|
||||
func (zm *ZitadelManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
body, err := zm.get(ctx, "users/"+userID, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -363,8 +364,8 @@ func (zm *ZitadelManager) GetUserDataByID(userID string, appMetadata AppMetadata
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile.
|
||||
func (zm *ZitadelManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
body, err := zm.post("users/_search", "")
|
||||
func (zm *ZitadelManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
|
||||
body, err := zm.post(ctx, "users/_search", "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -392,8 +393,8 @@ func (zm *ZitadelManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (zm *ZitadelManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
body, err := zm.post("users/_search", "")
|
||||
func (zm *ZitadelManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
||||
body, err := zm.post(ctx, "users/_search", "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -419,7 +420,7 @@ func (zm *ZitadelManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
// Metadata values are base64 encoded.
|
||||
func (zm *ZitadelManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
|
||||
func (zm *ZitadelManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -429,7 +430,7 @@ type inviteUserRequest struct {
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (zm *ZitadelManager) InviteUserByID(userID string) error {
|
||||
func (zm *ZitadelManager) InviteUserByID(ctx context.Context, userID string) error {
|
||||
inviteUser := inviteUserRequest{
|
||||
Email: userID,
|
||||
}
|
||||
@ -440,14 +441,14 @@ func (zm *ZitadelManager) InviteUserByID(userID string) error {
|
||||
}
|
||||
|
||||
// don't care about the body in the response
|
||||
_, err = zm.post(fmt.Sprintf("users/%s/_resend_initialization", userID), string(payload))
|
||||
_, err = zm.post(ctx, fmt.Sprintf("users/%s/_resend_initialization", userID), string(payload))
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteUser from Zitadel
|
||||
func (zm *ZitadelManager) DeleteUser(userID string) error {
|
||||
func (zm *ZitadelManager) DeleteUser(ctx context.Context, userID string) error {
|
||||
resource := fmt.Sprintf("users/%s", userID)
|
||||
if err := zm.delete(resource); err != nil {
|
||||
if err := zm.delete(ctx, resource); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -459,8 +460,8 @@ func (zm *ZitadelManager) DeleteUser(userID string) error {
|
||||
}
|
||||
|
||||
// post perform Post requests.
|
||||
func (zm *ZitadelManager) post(resource string, body string) ([]byte, error) {
|
||||
jwtToken, err := zm.credentials.Authenticate()
|
||||
func (zm *ZitadelManager) post(ctx context.Context, resource string, body string) ([]byte, error) {
|
||||
jwtToken, err := zm.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -495,8 +496,8 @@ func (zm *ZitadelManager) post(resource string, body string) ([]byte, error) {
|
||||
}
|
||||
|
||||
// delete perform Delete requests.
|
||||
func (zm *ZitadelManager) delete(resource string) error {
|
||||
jwtToken, err := zm.credentials.Authenticate()
|
||||
func (zm *ZitadelManager) delete(ctx context.Context, resource string) error {
|
||||
jwtToken, err := zm.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -531,8 +532,8 @@ func (zm *ZitadelManager) delete(resource string) error {
|
||||
}
|
||||
|
||||
// get perform Get requests.
|
||||
func (zm *ZitadelManager) get(resource string, q url.Values) ([]byte, error) {
|
||||
jwtToken, err := zm.credentials.Authenticate()
|
||||
func (zm *ZitadelManager) get(ctx context.Context, resource string, q url.Values) ([]byte, error) {
|
||||
jwtToken, err := zm.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
@ -108,7 +109,7 @@ func TestZitadelRequestJWTToken(t *testing.T) {
|
||||
helper: testCase.helper,
|
||||
}
|
||||
|
||||
resp, err := creds.requestJWTToken()
|
||||
resp, err := creds.requestJWTToken(context.Background())
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
@ -274,7 +275,7 @@ func TestZitadelAuthenticate(t *testing.T) {
|
||||
}
|
||||
creds.jwtToken.expiresInTime = testCase.inputExpireToken
|
||||
|
||||
_, err := creds.Authenticate()
|
||||
_, err := creds.Authenticate(context.Background())
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
|
@ -1,9 +1,10 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/google/martian/v3/log"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
)
|
||||
@ -19,22 +20,22 @@ import (
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if any occurred during the process, otherwise returns nil
|
||||
func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error {
|
||||
ok, err := am.GroupValidation(accountID, groups)
|
||||
func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error {
|
||||
ok, err := am.GroupValidation(ctx, accountID, groups)
|
||||
if err != nil {
|
||||
log.Debugf("error validating groups: %s", err.Error())
|
||||
log.WithContext(ctx).Debugf("error validating groups: %s", err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
if !ok {
|
||||
log.Debugf("invalid groups")
|
||||
log.WithContext(ctx).Debugf("invalid groups")
|
||||
return errors.New("invalid groups")
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
a, err := am.Store.GetAccountByUser(userID)
|
||||
a, err := am.Store.GetAccountByUser(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -48,14 +49,14 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(accountID strin
|
||||
a.Settings.Extra = extra
|
||||
}
|
||||
extra.IntegratedValidatorGroups = groups
|
||||
return am.Store.SaveAccount(a)
|
||||
return am.Store.SaveAccount(ctx, a)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) GroupValidation(accountId string, groups []string) (bool, error) {
|
||||
func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) {
|
||||
if len(groups) == 0 {
|
||||
return true, nil
|
||||
}
|
||||
accountsGroups, err := am.ListGroups(accountId)
|
||||
accountsGroups, err := am.ListGroups(ctx, accountId)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
@ -1,6 +1,8 @@
|
||||
package integrated_validator
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
@ -8,12 +10,12 @@ import (
|
||||
|
||||
// IntegratedValidator interface exists to avoid the circle dependencies
|
||||
type IntegratedValidator interface {
|
||||
ValidateExtraSettings(newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error
|
||||
ValidatePeer(update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error)
|
||||
PreparePeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer
|
||||
IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error)
|
||||
ValidateExtraSettings(ctx context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error
|
||||
ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error)
|
||||
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer
|
||||
IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error)
|
||||
GetValidatedPeers(accountID string, groups map[string]*nbgroup.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error)
|
||||
PeerDeleted(accountID, peerID string) error
|
||||
PeerDeleted(ctx context.Context, accountID, peerID string) error
|
||||
SetPeerInvalidationListener(fn func(accountID string))
|
||||
Stop()
|
||||
Stop(ctx context.Context)
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package jwtclaims
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
@ -69,8 +70,8 @@ type JWTValidator struct {
|
||||
}
|
||||
|
||||
// NewJWTValidator constructor
|
||||
func NewJWTValidator(issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (*JWTValidator, error) {
|
||||
keys, err := getPemKeys(keysLocation)
|
||||
func NewJWTValidator(ctx context.Context, issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (*JWTValidator, error) {
|
||||
keys, err := getPemKeys(ctx, keysLocation)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -102,19 +103,19 @@ func NewJWTValidator(issuer string, audienceList []string, keysLocation string,
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
refreshedKeys, err := getPemKeys(keysLocation)
|
||||
refreshedKeys, err := getPemKeys(ctx, keysLocation)
|
||||
if err != nil {
|
||||
log.Debugf("cannot get JSONWebKey: %v, falling back to old keys", err)
|
||||
log.WithContext(ctx).Debugf("cannot get JSONWebKey: %v, falling back to old keys", err)
|
||||
refreshedKeys = keys
|
||||
}
|
||||
|
||||
log.Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC())
|
||||
log.WithContext(ctx).Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC())
|
||||
|
||||
keys = refreshedKeys
|
||||
}
|
||||
}
|
||||
|
||||
cert, err := getPemCert(token, keys)
|
||||
cert, err := getPemCert(ctx, token, keys)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -136,19 +137,19 @@ func NewJWTValidator(issuer string, audienceList []string, keysLocation string,
|
||||
}
|
||||
|
||||
// ValidateAndParse validates the token and returns the parsed token
|
||||
func (m *JWTValidator) ValidateAndParse(token string) (*jwt.Token, error) {
|
||||
func (m *JWTValidator) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) {
|
||||
// If the token is empty...
|
||||
if token == "" {
|
||||
// Check if it was required
|
||||
if m.options.CredentialsOptional {
|
||||
log.Debugf("no credentials found (CredentialsOptional=true)")
|
||||
log.WithContext(ctx).Debugf("no credentials found (CredentialsOptional=true)")
|
||||
// No error, just no token (and that is ok given that CredentialsOptional is true)
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
|
||||
// If we get here, the required token is missing
|
||||
errorMsg := "required authorization token not found"
|
||||
log.Debugf(" Error: No credentials found (CredentialsOptional=false)")
|
||||
log.WithContext(ctx).Debugf(" Error: No credentials found (CredentialsOptional=false)")
|
||||
return nil, fmt.Errorf(errorMsg)
|
||||
}
|
||||
|
||||
@ -157,7 +158,7 @@ func (m *JWTValidator) ValidateAndParse(token string) (*jwt.Token, error) {
|
||||
|
||||
// Check if there was an error in parsing...
|
||||
if err != nil {
|
||||
log.Errorf("error parsing token: %v", err)
|
||||
log.WithContext(ctx).Errorf("error parsing token: %v", err)
|
||||
return nil, fmt.Errorf("Error parsing token: %w", err)
|
||||
}
|
||||
|
||||
@ -165,14 +166,14 @@ func (m *JWTValidator) ValidateAndParse(token string) (*jwt.Token, error) {
|
||||
errorMsg := fmt.Sprintf("Expected %s signing method but token specified %s",
|
||||
m.options.SigningMethod.Alg(),
|
||||
parsedToken.Header["alg"])
|
||||
log.Debugf("error validating token algorithm: %s", errorMsg)
|
||||
log.WithContext(ctx).Debugf("error validating token algorithm: %s", errorMsg)
|
||||
return nil, fmt.Errorf("error validating token algorithm: %s", errorMsg)
|
||||
}
|
||||
|
||||
// Check if the parsed token is valid...
|
||||
if !parsedToken.Valid {
|
||||
errorMsg := "token is invalid"
|
||||
log.Debugf(errorMsg)
|
||||
log.WithContext(ctx).Debugf(errorMsg)
|
||||
return nil, errors.New(errorMsg)
|
||||
}
|
||||
|
||||
@ -184,7 +185,7 @@ func (jwks *Jwks) stillValid() bool {
|
||||
return !jwks.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.expiresInTime)
|
||||
}
|
||||
|
||||
func getPemKeys(keysLocation string) (*Jwks, error) {
|
||||
func getPemKeys(ctx context.Context, keysLocation string) (*Jwks, error) {
|
||||
resp, err := http.Get(keysLocation)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -198,13 +199,13 @@ func getPemKeys(keysLocation string) (*Jwks, error) {
|
||||
}
|
||||
|
||||
cacheControlHeader := resp.Header.Get("Cache-Control")
|
||||
expiresIn := getMaxAgeFromCacheHeader(cacheControlHeader)
|
||||
expiresIn := getMaxAgeFromCacheHeader(ctx, cacheControlHeader)
|
||||
jwks.expiresInTime = time.Now().Add(time.Duration(expiresIn) * time.Second)
|
||||
|
||||
return jwks, err
|
||||
}
|
||||
|
||||
func getPemCert(token *jwt.Token, jwks *Jwks) (string, error) {
|
||||
func getPemCert(ctx context.Context, token *jwt.Token, jwks *Jwks) (string, error) {
|
||||
// todo as we load the jkws when the server is starting, we should build a JKS map with the pem cert at the boot time
|
||||
cert := ""
|
||||
|
||||
@ -217,7 +218,7 @@ func getPemCert(token *jwt.Token, jwks *Jwks) (string, error) {
|
||||
cert = "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----"
|
||||
return cert, nil
|
||||
}
|
||||
log.Debugf("generating validation pem from JWK")
|
||||
log.WithContext(ctx).Debugf("generating validation pem from JWK")
|
||||
return generatePemFromJWK(jwks.Keys[k])
|
||||
}
|
||||
|
||||
@ -284,7 +285,7 @@ func convertExponentStringToInt(stringExponent string) (int, error) {
|
||||
}
|
||||
|
||||
// getMaxAgeFromCacheHeader extracts max-age directive from the Cache-Control header
|
||||
func getMaxAgeFromCacheHeader(cacheControl string) int {
|
||||
func getMaxAgeFromCacheHeader(ctx context.Context, cacheControl string) int {
|
||||
// Split into individual directives
|
||||
directives := strings.Split(cacheControl, ",")
|
||||
|
||||
@ -295,7 +296,7 @@ func getMaxAgeFromCacheHeader(cacheControl string) int {
|
||||
maxAgeStr := strings.TrimPrefix(directive, "max-age=")
|
||||
maxAge, err := strconv.Atoi(maxAgeStr)
|
||||
if err != nil {
|
||||
log.Debugf("error parsing max-age: %v", err)
|
||||
log.WithContext(ctx).Debugf("error parsing max-age: %v", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
|
@ -406,7 +406,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error)
|
||||
return nil, "", err
|
||||
}
|
||||
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
|
||||
store, cleanUp, err := NewTestStoreFromJson(config.Datadir)
|
||||
store, cleanUp, err := NewTestStoreFromJson(context.Background(), config.Datadir)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
@ -414,7 +414,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error)
|
||||
|
||||
peersUpdateManager := NewPeersUpdateManager(nil)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
accountManager, err := BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted",
|
||||
accountManager, err := BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted",
|
||||
eventStore, nil, false, MocIntegratedValidator{})
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
@ -422,7 +422,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error)
|
||||
turnManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||
|
||||
ephemeralMgr := NewEphemeralManager(store, accountManager)
|
||||
mgmtServer, err := NewServer(config, accountManager, peersUpdateManager, turnManager, nil, ephemeralMgr)
|
||||
mgmtServer, err := NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, ephemeralMgr)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
@ -451,11 +451,11 @@ var _ = Describe("Management service", func() {
|
||||
type MocIntegratedValidator struct {
|
||||
}
|
||||
|
||||
func (a MocIntegratedValidator) ValidateExtraSettings(newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
|
||||
func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a MocIntegratedValidator) ValidatePeer(update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) {
|
||||
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) {
|
||||
return update, nil
|
||||
}
|
||||
|
||||
@ -467,15 +467,15 @@ func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[s
|
||||
return validatedPeers, nil
|
||||
}
|
||||
|
||||
func (MocIntegratedValidator) PreparePeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer {
|
||||
func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer {
|
||||
return peer
|
||||
}
|
||||
|
||||
func (MocIntegratedValidator) IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) {
|
||||
func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) {
|
||||
return false, false, nil
|
||||
}
|
||||
|
||||
func (MocIntegratedValidator) PeerDeleted(_, _ string) error {
|
||||
func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -483,7 +483,7 @@ func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)
|
||||
|
||||
}
|
||||
|
||||
func (MocIntegratedValidator) Stop() {}
|
||||
func (MocIntegratedValidator) Stop(_ context.Context) {}
|
||||
|
||||
func loginPeerWithValidSetupKey(serverPubKey wgtypes.Key, key wgtypes.Key, client mgmtProto.ManagementServiceClient) *mgmtProto.LoginResponse {
|
||||
defer GinkgoRecover()
|
||||
@ -534,20 +534,20 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) {
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
s := grpc.NewServer()
|
||||
|
||||
store, _, err := server.NewTestStoreFromJson(config.Datadir)
|
||||
store, _, err := server.NewTestStoreFromJson(context.Background(), config.Datadir)
|
||||
if err != nil {
|
||||
log.Fatalf("failed creating a store: %s: %v", config.Datadir, err)
|
||||
}
|
||||
|
||||
peersUpdateManager := server.NewPeersUpdateManager(nil)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted",
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted",
|
||||
eventStore, nil, false, MocIntegratedValidator{})
|
||||
if err != nil {
|
||||
log.Fatalf("failed creating a manager: %v", err)
|
||||
}
|
||||
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
|
||||
go func() {
|
||||
|
@ -46,7 +46,7 @@ type properties map[string]interface{}
|
||||
|
||||
// DataSource metric data source
|
||||
type DataSource interface {
|
||||
GetAllAccounts() []*server.Account
|
||||
GetAllAccounts(ctx context.Context) []*server.Account
|
||||
GetStoreEngine() server.StoreEngine
|
||||
}
|
||||
|
||||
@ -81,29 +81,29 @@ func NewWorker(ctx context.Context, id string, dataSource DataSource, connManage
|
||||
}
|
||||
|
||||
// Run runs the metrics worker
|
||||
func (w *Worker) Run() {
|
||||
func (w *Worker) Run(ctx context.Context) {
|
||||
pushTicker := time.NewTicker(defaultPushInterval)
|
||||
for {
|
||||
select {
|
||||
case <-w.ctx.Done():
|
||||
return
|
||||
case <-pushTicker.C:
|
||||
err := w.sendMetrics()
|
||||
err := w.sendMetrics(ctx)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
log.WithContext(ctx).Error(err)
|
||||
}
|
||||
w.lastRun = time.Now()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Worker) sendMetrics() error {
|
||||
func (w *Worker) sendMetrics(ctx context.Context) error {
|
||||
apiKey, err := getAPIKey(w.ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
payload := w.generatePayload(apiKey)
|
||||
payload := w.generatePayload(ctx, apiKey)
|
||||
|
||||
payloadString, err := buildMetricsPayload(payload)
|
||||
if err != nil {
|
||||
@ -125,7 +125,7 @@ func (w *Worker) sendMetrics() error {
|
||||
defer func() {
|
||||
err = jobResp.Body.Close()
|
||||
if err != nil {
|
||||
log.Errorf("error while closing update metrics response body: %v", err)
|
||||
log.WithContext(ctx).Errorf("error while closing update metrics response body: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
@ -133,15 +133,15 @@ func (w *Worker) sendMetrics() error {
|
||||
return fmt.Errorf("unable to push anonymous metrics, got statusCode %d", jobResp.StatusCode)
|
||||
}
|
||||
|
||||
log.Infof("sent anonymous metrics, next push will happen in %s. "+
|
||||
log.WithContext(ctx).Infof("sent anonymous metrics, next push will happen in %s. "+
|
||||
"You can disable these metrics by running with flag --disable-anonymous-metrics,"+
|
||||
" see more information at https://netbird.io/docs/FAQ/metrics-collection", defaultPushInterval)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *Worker) generatePayload(apiKey string) pushPayload {
|
||||
properties := w.generateProperties()
|
||||
func (w *Worker) generatePayload(ctx context.Context, apiKey string) pushPayload {
|
||||
properties := w.generateProperties(ctx)
|
||||
|
||||
return pushPayload{
|
||||
APIKey: apiKey,
|
||||
@ -152,7 +152,7 @@ func (w *Worker) generatePayload(apiKey string) pushPayload {
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Worker) generateProperties() properties {
|
||||
func (w *Worker) generateProperties(ctx context.Context) properties {
|
||||
var (
|
||||
uptime float64
|
||||
accounts int
|
||||
@ -192,7 +192,7 @@ func (w *Worker) generateProperties() properties {
|
||||
connections := w.connManager.GetAllConnectedPeers()
|
||||
version = nbversion.NetbirdVersion()
|
||||
|
||||
for _, account := range w.dataSource.GetAllAccounts() {
|
||||
for _, account := range w.dataSource.GetAllAccounts(ctx) {
|
||||
accounts++
|
||||
|
||||
if account.Settings.PeerLoginExpirationEnabled {
|
||||
@ -342,7 +342,7 @@ func getAPIKey(ctx context.Context) (string, error) {
|
||||
defer func() {
|
||||
err = response.Body.Close()
|
||||
if err != nil {
|
||||
log.Errorf("error while closing metrics token response body: %v", err)
|
||||
log.WithContext(ctx).Errorf("error while closing metrics token response body: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
@ -21,7 +22,7 @@ func (mockDatasource) GetAllConnectedPeers() map[string]struct{} {
|
||||
}
|
||||
|
||||
// GetAllAccounts returns a list of *server.Account for use in tests with predefined information
|
||||
func (mockDatasource) GetAllAccounts() []*server.Account {
|
||||
func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account {
|
||||
return []*server.Account{
|
||||
{
|
||||
Id: "1",
|
||||
@ -188,7 +189,7 @@ func TestGenerateProperties(t *testing.T) {
|
||||
connManager: ds,
|
||||
}
|
||||
|
||||
properties := worker.generateProperties()
|
||||
properties := worker.generateProperties(context.Background())
|
||||
|
||||
if properties["accounts"] != 2 {
|
||||
t.Errorf("expected 2 accounts, got %d", properties["accounts"])
|
||||
|
@ -1,6 +1,7 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/gob"
|
||||
"encoding/json"
|
||||
@ -16,7 +17,7 @@ import (
|
||||
// MigrateFieldFromGobToJSON migrates a column from Gob encoding to JSON encoding.
|
||||
// T is the type of the model that contains the field to be migrated.
|
||||
// S is the type of the field to be migrated.
|
||||
func MigrateFieldFromGobToJSON[T any, S any](db *gorm.DB, fieldName string) error {
|
||||
func MigrateFieldFromGobToJSON[T any, S any](ctx context.Context, db *gorm.DB, fieldName string) error {
|
||||
|
||||
oldColumnName := fieldName
|
||||
newColumnName := fieldName + "_tmp"
|
||||
@ -24,7 +25,7 @@ func MigrateFieldFromGobToJSON[T any, S any](db *gorm.DB, fieldName string) erro
|
||||
var model T
|
||||
|
||||
if !db.Migrator().HasTable(&model) {
|
||||
log.Debugf("Table for %T does not exist, no migration needed", model)
|
||||
log.WithContext(ctx).Debugf("Table for %T does not exist, no migration needed", model)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -38,7 +39,7 @@ func MigrateFieldFromGobToJSON[T any, S any](db *gorm.DB, fieldName string) erro
|
||||
var sqliteItem sql.NullString
|
||||
if err := db.Model(model).Select(oldColumnName).First(&sqliteItem).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
log.Debugf("No records in table %s, no migration needed", tableName)
|
||||
log.WithContext(ctx).Debugf("No records in table %s, no migration needed", tableName)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("fetch first record: %w", err)
|
||||
@ -51,7 +52,7 @@ func MigrateFieldFromGobToJSON[T any, S any](db *gorm.DB, fieldName string) erro
|
||||
err = json.Unmarshal([]byte(item), &js)
|
||||
// if the item is JSON parsable or an empty string it can not be gob encoded
|
||||
if err == nil || !errors.As(err, &syntaxError) || item == "" {
|
||||
log.Debugf("No migration needed for %s, %s", tableName, fieldName)
|
||||
log.WithContext(ctx).Debugf("No migration needed for %s, %s", tableName, fieldName)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -99,14 +100,14 @@ func MigrateFieldFromGobToJSON[T any, S any](db *gorm.DB, fieldName string) erro
|
||||
return err
|
||||
}
|
||||
|
||||
log.Infof("Migration of %s.%s from gob to json completed", tableName, fieldName)
|
||||
log.WithContext(ctx).Infof("Migration of %s.%s from gob to json completed", tableName, fieldName)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MigrateNetIPFieldFromBlobToJSON migrates a Net IP column from Blob encoding to JSON encoding.
|
||||
// T is the type of the model that contains the field to be migrated.
|
||||
func MigrateNetIPFieldFromBlobToJSON[T any](db *gorm.DB, fieldName string, indexName string) error {
|
||||
func MigrateNetIPFieldFromBlobToJSON[T any](ctx context.Context, db *gorm.DB, fieldName string, indexName string) error {
|
||||
oldColumnName := fieldName
|
||||
newColumnName := fieldName + "_tmp"
|
||||
|
||||
@ -138,7 +139,7 @@ func MigrateNetIPFieldFromBlobToJSON[T any](db *gorm.DB, fieldName string, index
|
||||
var syntaxError *json.SyntaxError
|
||||
err = json.Unmarshal([]byte(item.String), &js)
|
||||
if err == nil || !errors.As(err, &syntaxError) {
|
||||
log.Debugf("No migration needed for %s, %s", tableName, fieldName)
|
||||
log.WithContext(ctx).Debugf("No migration needed for %s, %s", tableName, fieldName)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@ -169,7 +170,7 @@ func MigrateNetIPFieldFromBlobToJSON[T any](db *gorm.DB, fieldName string, index
|
||||
|
||||
columnIpValue := net.IP(blobValue)
|
||||
if net.ParseIP(columnIpValue.String()) == nil {
|
||||
log.Debugf("failed to parse %s as ip, fallback to ipv6 loopback", oldColumnName)
|
||||
log.WithContext(ctx).Debugf("failed to parse %s as ip, fallback to ipv6 loopback", oldColumnName)
|
||||
columnIpValue = net.IPv6loopback
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
package migration_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/gob"
|
||||
"net"
|
||||
"strings"
|
||||
@ -30,7 +31,7 @@ func setupDatabase(t *testing.T) *gorm.DB {
|
||||
|
||||
func TestMigrateFieldFromGobToJSON_EmptyDB(t *testing.T) {
|
||||
db := setupDatabase(t)
|
||||
err := migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](db, "network_net")
|
||||
err := migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](context.Background(), db, "network_net")
|
||||
require.NoError(t, err, "Migration should not fail for an empty database")
|
||||
}
|
||||
|
||||
@ -63,7 +64,7 @@ func TestMigrateFieldFromGobToJSON_WithGobData(t *testing.T) {
|
||||
err = gob.NewDecoder(strings.NewReader(gobStr)).Decode(&ipnet)
|
||||
require.NoError(t, err, "Failed to decode Gob data")
|
||||
|
||||
err = migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](db, "network_net")
|
||||
err = migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](context.Background(), db, "network_net")
|
||||
require.NoError(t, err, "Migration should not fail with Gob data")
|
||||
|
||||
var jsonStr string
|
||||
@ -83,7 +84,7 @@ func TestMigrateFieldFromGobToJSON_WithJSONData(t *testing.T) {
|
||||
err = db.Save(&server.Account{Network: &server.Network{Net: *ipnet}}).Error
|
||||
require.NoError(t, err, "Failed to insert JSON data")
|
||||
|
||||
err = migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](db, "network_net")
|
||||
err = migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](context.Background(), db, "network_net")
|
||||
require.NoError(t, err, "Migration should not fail with JSON data")
|
||||
|
||||
var jsonStr string
|
||||
@ -93,7 +94,7 @@ func TestMigrateFieldFromGobToJSON_WithJSONData(t *testing.T) {
|
||||
|
||||
func TestMigrateNetIPFieldFromBlobToJSON_EmptyDB(t *testing.T) {
|
||||
db := setupDatabase(t)
|
||||
err := migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "ip", "idx_peers_account_id_ip")
|
||||
err := migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](context.Background(), db, "ip", "idx_peers_account_id_ip")
|
||||
require.NoError(t, err, "Migration should not fail for an empty database")
|
||||
}
|
||||
|
||||
@ -130,7 +131,7 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) {
|
||||
err = db.Model(&nbpeer.Peer{}).Select("location_connection_ip").First(&blobValue).Error
|
||||
assert.NoError(t, err, "Failed to fetch blob data")
|
||||
|
||||
err = migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "location_connection_ip", "")
|
||||
err = migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](context.Background(), db, "location_connection_ip", "")
|
||||
require.NoError(t, err, "Migration should not fail with net.IP blob data")
|
||||
|
||||
var jsonStr string
|
||||
@ -152,7 +153,7 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) {
|
||||
).Error
|
||||
require.NoError(t, err, "Failed to insert JSON data")
|
||||
|
||||
err = migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](db, "location_connection_ip", "")
|
||||
err = migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](context.Background(), db, "location_connection_ip", "")
|
||||
require.NoError(t, err, "Migration should not fail with net.IP JSON data")
|
||||
|
||||
var jsonStr string
|
||||
|
@ -1,6 +1,7 @@
|
||||
package mock_server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
@ -21,94 +22,95 @@ import (
|
||||
)
|
||||
|
||||
type MockAccountManager struct {
|
||||
GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error)
|
||||
CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType,
|
||||
GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*server.Account, error)
|
||||
CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType,
|
||||
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
|
||||
GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error)
|
||||
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
|
||||
GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
||||
ListUsersFunc func(accountID string) ([]*server.User, error)
|
||||
GetPeersFunc func(accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnectedFunc func(peerKey string, connected bool, realIP net.IP) error
|
||||
SyncAndMarkPeerFunc func(peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
|
||||
DeletePeerFunc func(accountID, peerKey, userID string) error
|
||||
GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error)
|
||||
GetPeerNetworkFunc func(peerKey string) (*server.Network, error)
|
||||
AddPeerFunc func(setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
|
||||
GetGroupFunc func(accountID, groupID, userID string) (*group.Group, error)
|
||||
GetAllGroupsFunc func(accountID, userID string) ([]*group.Group, error)
|
||||
GetGroupByNameFunc func(accountID, groupName string) (*group.Group, error)
|
||||
SaveGroupFunc func(accountID, userID string, group *group.Group) error
|
||||
DeleteGroupFunc func(accountID, userId, groupID string) error
|
||||
ListGroupsFunc func(accountID string) ([]*group.Group, error)
|
||||
GroupAddPeerFunc func(accountID, groupID, peerID string) error
|
||||
GroupDeletePeerFunc func(accountID, groupID, peerID string) error
|
||||
DeleteRuleFunc func(accountID, ruleID, userID string) error
|
||||
GetPolicyFunc func(accountID, policyID, userID string) (*server.Policy, error)
|
||||
SavePolicyFunc func(accountID, userID string, policy *server.Policy) error
|
||||
DeletePolicyFunc func(accountID, policyID, userID string) error
|
||||
ListPoliciesFunc func(accountID, userID string) ([]*server.Policy, error)
|
||||
GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error)
|
||||
GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
|
||||
MarkPATUsedFunc func(pat string) error
|
||||
UpdatePeerMetaFunc func(peerID string, meta nbpeer.PeerSystemMeta) error
|
||||
UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error
|
||||
UpdatePeerFunc func(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
CreateRouteFunc func(accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
|
||||
GetRouteFunc func(accountID string, routeID route.ID, userID string) (*route.Route, error)
|
||||
SaveRouteFunc func(accountID string, userID string, route *route.Route) error
|
||||
DeleteRouteFunc func(accountID string, routeID route.ID, userID string) error
|
||||
ListRoutesFunc func(accountID, userID string) ([]*route.Route, error)
|
||||
SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
|
||||
ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error)
|
||||
SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error)
|
||||
SaveOrAddUserFunc func(accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error)
|
||||
DeleteUserFunc func(accountID string, initiatorUserID string, targetUserID string) error
|
||||
CreatePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
|
||||
DeletePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) error
|
||||
GetPATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
|
||||
GetAllPATsFunc func(accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error)
|
||||
GetNameServerGroupFunc func(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
||||
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
|
||||
SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
|
||||
DeleteNameServerGroupFunc func(accountID, nsGroupID, userID string) error
|
||||
ListNameServerGroupsFunc func(accountID string, userID string) ([]*nbdns.NameServerGroup, error)
|
||||
CreateUserFunc func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
|
||||
GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)
|
||||
CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error
|
||||
DeleteAccountFunc func(accountID, userID string) error
|
||||
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error)
|
||||
GetAccountByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (*server.Account, error)
|
||||
GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
|
||||
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
|
||||
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
|
||||
SyncAndMarkPeerFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
|
||||
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
|
||||
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*server.NetworkMap, error)
|
||||
GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*server.Network, error)
|
||||
AddPeerFunc func(ctx context.Context, setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
|
||||
GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*group.Group, error)
|
||||
GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*group.Group, error)
|
||||
GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*group.Group, error)
|
||||
SaveGroupFunc func(ctx context.Context, accountID, userID string, group *group.Group) error
|
||||
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
|
||||
ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error)
|
||||
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
||||
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
||||
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
|
||||
GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error)
|
||||
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) error
|
||||
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
|
||||
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error)
|
||||
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error)
|
||||
GetAccountFromPATFunc func(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
|
||||
MarkPATUsedFunc func(ctx context.Context, pat string) error
|
||||
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
|
||||
UpdatePeerSSHKeyFunc func(ctx context.Context, peerID string, sshKey string) error
|
||||
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
|
||||
GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
|
||||
SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error
|
||||
DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error
|
||||
ListRoutesFunc func(ctx context.Context, accountID, userID string) ([]*route.Route, error)
|
||||
SaveSetupKeyFunc func(ctx context.Context, accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
|
||||
ListSetupKeysFunc func(ctx context.Context, accountID, userID string) ([]*server.SetupKey, error)
|
||||
SaveUserFunc func(ctx context.Context, accountID, userID string, user *server.User) (*server.UserInfo, error)
|
||||
SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error)
|
||||
DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
|
||||
CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
|
||||
DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error
|
||||
GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
|
||||
GetAllPATsFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error)
|
||||
GetNameServerGroupFunc func(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
||||
CreateNameServerGroupFunc func(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error)
|
||||
SaveNameServerGroupFunc func(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
|
||||
DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error
|
||||
ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
|
||||
CreateUserFunc func(ctx context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
|
||||
GetAccountFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)
|
||||
CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
|
||||
DeleteAccountFunc func(ctx context.Context, accountID, userID string) error
|
||||
GetDNSDomainFunc func() string
|
||||
StoreEventFunc func(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
|
||||
GetEventsFunc func(accountID, userID string) ([]*activity.Event, error)
|
||||
GetDNSSettingsFunc func(accountID, userID string) (*server.DNSSettings, error)
|
||||
SaveDNSSettingsFunc func(accountID, userID string, dnsSettingsToSave *server.DNSSettings) error
|
||||
GetPeerFunc func(accountID, peerID, userID string) (*nbpeer.Peer, error)
|
||||
UpdateAccountSettingsFunc func(accountID, userID string, newSettings *server.Settings) (*server.Account, error)
|
||||
LoginPeerFunc func(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
|
||||
SyncPeerFunc func(sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
|
||||
InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error
|
||||
StoreEventFunc func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
|
||||
GetEventsFunc func(ctx context.Context, accountID, userID string) ([]*activity.Event, error)
|
||||
GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*server.DNSSettings, error)
|
||||
SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *server.DNSSettings) error
|
||||
GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
|
||||
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error)
|
||||
LoginPeerFunc func(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
|
||||
SyncPeerFunc func(ctx context.Context, sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
|
||||
InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error
|
||||
GetAllConnectedPeersFunc func() (map[string]struct{}, error)
|
||||
HasConnectedChannelFunc func(peerID string) bool
|
||||
GetExternalCacheManagerFunc func() server.ExternalCacheManager
|
||||
GetPostureChecksFunc func(accountID, postureChecksID, userID string) (*posture.Checks, error)
|
||||
SavePostureChecksFunc func(accountID, userID string, postureChecks *posture.Checks) error
|
||||
DeletePostureChecksFunc func(accountID, postureChecksID, userID string) error
|
||||
ListPostureChecksFunc func(accountID, userID string) ([]*posture.Checks, error)
|
||||
GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
|
||||
SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error
|
||||
DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error
|
||||
ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
|
||||
GetIdpManagerFunc func() idp.Manager
|
||||
UpdateIntegratedValidatorGroupsFunc func(accountID string, userID string, groups []string) error
|
||||
GroupValidationFunc func(accountId string, groups []string) (bool, error)
|
||||
SyncPeerMetaFunc func(peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||
UpdateIntegratedValidatorGroupsFunc func(ctx context.Context, accountID string, userID string, groups []string) error
|
||||
GroupValidationFunc func(ctx context.Context, accountId string, groups []string) (bool, error)
|
||||
SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||
FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error)
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) SyncAndMarkPeer(peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
|
||||
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
|
||||
if am.SyncAndMarkPeerFunc != nil {
|
||||
return am.SyncAndMarkPeerFunc(peerPubKey, meta, realIP)
|
||||
return am.SyncAndMarkPeerFunc(ctx, peerPubKey, meta, realIP)
|
||||
}
|
||||
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) CancelPeerRoutines(peer *nbpeer.Peer) error {
|
||||
func (am *MockAccountManager) CancelPeerRoutines(_ context.Context, peer *nbpeer.Peer) error {
|
||||
// TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
@ -122,43 +124,43 @@ func (am *MockAccountManager) GetValidatedPeers(account *server.Account) (map[st
|
||||
}
|
||||
|
||||
// GetGroup mock implementation of GetGroup from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetGroup(accountId, groupID, userID string) (*group.Group, error) {
|
||||
func (am *MockAccountManager) GetGroup(ctx context.Context, accountId, groupID, userID string) (*group.Group, error) {
|
||||
if am.GetGroupFunc != nil {
|
||||
return am.GetGroupFunc(accountId, groupID, userID)
|
||||
return am.GetGroupFunc(ctx, accountId, groupID, userID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetGroup is not implemented")
|
||||
}
|
||||
|
||||
// GetAllGroups mock implementation of GetAllGroups from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetAllGroups(accountID, userID string) ([]*group.Group, error) {
|
||||
func (am *MockAccountManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*group.Group, error) {
|
||||
if am.GetAllGroupsFunc != nil {
|
||||
return am.GetAllGroupsFunc(accountID, userID)
|
||||
return am.GetAllGroupsFunc(ctx, accountID, userID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetAllGroups is not implemented")
|
||||
}
|
||||
|
||||
// GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetUsersFromAccount(accountID string, userID string) ([]*server.UserInfo, error) {
|
||||
func (am *MockAccountManager) GetUsersFromAccount(ctx context.Context, accountID string, userID string) ([]*server.UserInfo, error) {
|
||||
if am.GetUsersFromAccountFunc != nil {
|
||||
return am.GetUsersFromAccountFunc(accountID, userID)
|
||||
return am.GetUsersFromAccountFunc(ctx, accountID, userID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetUsersFromAccount is not implemented")
|
||||
}
|
||||
|
||||
// DeletePeer mock implementation of DeletePeer from server.AccountManager interface
|
||||
func (am *MockAccountManager) DeletePeer(accountID, peerID, userID string) error {
|
||||
func (am *MockAccountManager) DeletePeer(ctx context.Context, accountID, peerID, userID string) error {
|
||||
if am.DeletePeerFunc != nil {
|
||||
return am.DeletePeerFunc(accountID, peerID, userID)
|
||||
return am.DeletePeerFunc(ctx, accountID, peerID, userID)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method DeletePeer is not implemented")
|
||||
}
|
||||
|
||||
// GetOrCreateAccountByUser mock implementation of GetOrCreateAccountByUser from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetOrCreateAccountByUser(
|
||||
userId, domain string,
|
||||
ctx context.Context, userId, domain string,
|
||||
) (*server.Account, error) {
|
||||
if am.GetOrCreateAccountByUserFunc != nil {
|
||||
return am.GetOrCreateAccountByUserFunc(userId, domain)
|
||||
return am.GetOrCreateAccountByUserFunc(ctx, userId, domain)
|
||||
}
|
||||
return nil, status.Errorf(
|
||||
codes.Unimplemented,
|
||||
@ -168,6 +170,7 @@ func (am *MockAccountManager) GetOrCreateAccountByUser(
|
||||
|
||||
// CreateSetupKey mock implementation of CreateSetupKey from server.AccountManager interface
|
||||
func (am *MockAccountManager) CreateSetupKey(
|
||||
ctx context.Context,
|
||||
accountID string,
|
||||
keyName string,
|
||||
keyType server.SetupKeyType,
|
||||
@ -178,17 +181,17 @@ func (am *MockAccountManager) CreateSetupKey(
|
||||
ephemeral bool,
|
||||
) (*server.SetupKey, error) {
|
||||
if am.CreateSetupKeyFunc != nil {
|
||||
return am.CreateSetupKeyFunc(accountID, keyName, keyType, expiresIn, autoGroups, usageLimit, userID, ephemeral)
|
||||
return am.CreateSetupKeyFunc(ctx, accountID, keyName, keyType, expiresIn, autoGroups, usageLimit, userID, ephemeral)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented")
|
||||
}
|
||||
|
||||
// GetAccountByUserOrAccountID mock implementation of GetAccountByUserOrAccountID from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetAccountByUserOrAccountID(
|
||||
userId, accountId, domain string,
|
||||
ctx context.Context, userId, accountId, domain string,
|
||||
) (*server.Account, error) {
|
||||
if am.GetAccountByUserOrAccountIdFunc != nil {
|
||||
return am.GetAccountByUserOrAccountIdFunc(userId, accountId, domain)
|
||||
return am.GetAccountByUserOrAccountIdFunc(ctx, userId, accountId, domain)
|
||||
}
|
||||
return nil, status.Errorf(
|
||||
codes.Unimplemented,
|
||||
@ -197,391 +200,392 @@ func (am *MockAccountManager) GetAccountByUserOrAccountID(
|
||||
}
|
||||
|
||||
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
|
||||
func (am *MockAccountManager) MarkPeerConnected(peerKey string, connected bool, realIP net.IP, account *server.Account) error {
|
||||
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, account *server.Account) error {
|
||||
if am.MarkPeerConnectedFunc != nil {
|
||||
return am.MarkPeerConnectedFunc(peerKey, connected, realIP)
|
||||
return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||
}
|
||||
|
||||
// GetAccountFromPAT mock implementation of GetAccountFromPAT from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetAccountFromPAT(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
|
||||
func (am *MockAccountManager) GetAccountFromPAT(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
|
||||
if am.GetAccountFromPATFunc != nil {
|
||||
return am.GetAccountFromPATFunc(pat)
|
||||
return am.GetAccountFromPATFunc(ctx, pat)
|
||||
}
|
||||
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromPAT is not implemented")
|
||||
}
|
||||
|
||||
// DeleteAccount mock implementation of DeleteAccount from server.AccountManager interface
|
||||
func (am *MockAccountManager) DeleteAccount(accountID, userID string) error {
|
||||
func (am *MockAccountManager) DeleteAccount(ctx context.Context, accountID, userID string) error {
|
||||
if am.DeleteAccountFunc != nil {
|
||||
return am.DeleteAccountFunc(accountID, userID)
|
||||
return am.DeleteAccountFunc(ctx, accountID, userID)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method DeleteAccount is not implemented")
|
||||
}
|
||||
|
||||
// MarkPATUsed mock implementation of MarkPATUsed from server.AccountManager interface
|
||||
func (am *MockAccountManager) MarkPATUsed(pat string) error {
|
||||
func (am *MockAccountManager) MarkPATUsed(ctx context.Context, pat string) error {
|
||||
if am.MarkPATUsedFunc != nil {
|
||||
return am.MarkPATUsedFunc(pat)
|
||||
return am.MarkPATUsedFunc(ctx, pat)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method MarkPATUsed is not implemented")
|
||||
}
|
||||
|
||||
// CreatePAT mock implementation of GetPAT from server.AccountManager interface
|
||||
func (am *MockAccountManager) CreatePAT(accountID string, initiatorUserID string, targetUserID string, name string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
|
||||
func (am *MockAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, name string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
|
||||
if am.CreatePATFunc != nil {
|
||||
return am.CreatePATFunc(accountID, initiatorUserID, targetUserID, name, expiresIn)
|
||||
return am.CreatePATFunc(ctx, accountID, initiatorUserID, targetUserID, name, expiresIn)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method CreatePAT is not implemented")
|
||||
}
|
||||
|
||||
// DeletePAT mock implementation of DeletePAT from server.AccountManager interface
|
||||
func (am *MockAccountManager) DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
|
||||
func (am *MockAccountManager) DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
|
||||
if am.DeletePATFunc != nil {
|
||||
return am.DeletePATFunc(accountID, initiatorUserID, targetUserID, tokenID)
|
||||
return am.DeletePATFunc(ctx, accountID, initiatorUserID, targetUserID, tokenID)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method DeletePAT is not implemented")
|
||||
}
|
||||
|
||||
// GetPAT mock implementation of GetPAT from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
|
||||
func (am *MockAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
|
||||
if am.GetPATFunc != nil {
|
||||
return am.GetPATFunc(accountID, initiatorUserID, targetUserID, tokenID)
|
||||
return am.GetPATFunc(ctx, accountID, initiatorUserID, targetUserID, tokenID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetPAT is not implemented")
|
||||
}
|
||||
|
||||
// GetAllPATs mock implementation of GetAllPATs from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
|
||||
func (am *MockAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
|
||||
if am.GetAllPATsFunc != nil {
|
||||
return am.GetAllPATsFunc(accountID, initiatorUserID, targetUserID)
|
||||
return am.GetAllPATsFunc(ctx, accountID, initiatorUserID, targetUserID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetAllPATs is not implemented")
|
||||
}
|
||||
|
||||
// GetNetworkMap mock implementation of GetNetworkMap from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetNetworkMap(peerKey string) (*server.NetworkMap, error) {
|
||||
func (am *MockAccountManager) GetNetworkMap(ctx context.Context, peerKey string) (*server.NetworkMap, error) {
|
||||
if am.GetNetworkMapFunc != nil {
|
||||
return am.GetNetworkMapFunc(peerKey)
|
||||
return am.GetNetworkMapFunc(ctx, peerKey)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetNetworkMap is not implemented")
|
||||
}
|
||||
|
||||
// GetPeerNetwork mock implementation of GetPeerNetwork from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetPeerNetwork(peerKey string) (*server.Network, error) {
|
||||
func (am *MockAccountManager) GetPeerNetwork(ctx context.Context, peerKey string) (*server.Network, error) {
|
||||
if am.GetPeerNetworkFunc != nil {
|
||||
return am.GetPeerNetworkFunc(peerKey)
|
||||
return am.GetPeerNetworkFunc(ctx, peerKey)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetPeerNetwork is not implemented")
|
||||
}
|
||||
|
||||
// AddPeer mock implementation of AddPeer from server.AccountManager interface
|
||||
func (am *MockAccountManager) AddPeer(
|
||||
ctx context.Context,
|
||||
setupKey string,
|
||||
userId string,
|
||||
peer *nbpeer.Peer,
|
||||
) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
|
||||
if am.AddPeerFunc != nil {
|
||||
return am.AddPeerFunc(setupKey, userId, peer)
|
||||
return am.AddPeerFunc(ctx, setupKey, userId, peer)
|
||||
}
|
||||
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method AddPeer is not implemented")
|
||||
}
|
||||
|
||||
// GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetGroupByName(accountID, groupName string) (*group.Group, error) {
|
||||
func (am *MockAccountManager) GetGroupByName(ctx context.Context, accountID, groupName string) (*group.Group, error) {
|
||||
if am.GetGroupFunc != nil {
|
||||
return am.GetGroupByNameFunc(accountID, groupName)
|
||||
return am.GetGroupByNameFunc(ctx, accountID, groupName)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetGroupByName is not implemented")
|
||||
}
|
||||
|
||||
// SaveGroup mock implementation of SaveGroup from server.AccountManager interface
|
||||
func (am *MockAccountManager) SaveGroup(accountID, userID string, group *group.Group) error {
|
||||
func (am *MockAccountManager) SaveGroup(ctx context.Context, accountID, userID string, group *group.Group) error {
|
||||
if am.SaveGroupFunc != nil {
|
||||
return am.SaveGroupFunc(accountID, userID, group)
|
||||
return am.SaveGroupFunc(ctx, accountID, userID, group)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method SaveGroup is not implemented")
|
||||
}
|
||||
|
||||
// DeleteGroup mock implementation of DeleteGroup from server.AccountManager interface
|
||||
func (am *MockAccountManager) DeleteGroup(accountId, userId, groupID string) error {
|
||||
func (am *MockAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error {
|
||||
if am.DeleteGroupFunc != nil {
|
||||
return am.DeleteGroupFunc(accountId, userId, groupID)
|
||||
return am.DeleteGroupFunc(ctx, accountId, userId, groupID)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method DeleteGroup is not implemented")
|
||||
}
|
||||
|
||||
// ListGroups mock implementation of ListGroups from server.AccountManager interface
|
||||
func (am *MockAccountManager) ListGroups(accountID string) ([]*group.Group, error) {
|
||||
func (am *MockAccountManager) ListGroups(ctx context.Context, accountID string) ([]*group.Group, error) {
|
||||
if am.ListGroupsFunc != nil {
|
||||
return am.ListGroupsFunc(accountID)
|
||||
return am.ListGroupsFunc(ctx, accountID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method ListGroups is not implemented")
|
||||
}
|
||||
|
||||
// GroupAddPeer mock implementation of GroupAddPeer from server.AccountManager interface
|
||||
func (am *MockAccountManager) GroupAddPeer(accountID, groupID, peerID string) error {
|
||||
func (am *MockAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
if am.GroupAddPeerFunc != nil {
|
||||
return am.GroupAddPeerFunc(accountID, groupID, peerID)
|
||||
return am.GroupAddPeerFunc(ctx, accountID, groupID, peerID)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method GroupAddPeer is not implemented")
|
||||
}
|
||||
|
||||
// GroupDeletePeer mock implementation of GroupDeletePeer from server.AccountManager interface
|
||||
func (am *MockAccountManager) GroupDeletePeer(accountID, groupID, peerID string) error {
|
||||
func (am *MockAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
if am.GroupDeletePeerFunc != nil {
|
||||
return am.GroupDeletePeerFunc(accountID, groupID, peerID)
|
||||
return am.GroupDeletePeerFunc(ctx, accountID, groupID, peerID)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method GroupDeletePeer is not implemented")
|
||||
}
|
||||
|
||||
// DeleteRule mock implementation of DeleteRule from server.AccountManager interface
|
||||
func (am *MockAccountManager) DeleteRule(accountID, ruleID, userID string) error {
|
||||
func (am *MockAccountManager) DeleteRule(ctx context.Context, accountID, ruleID, userID string) error {
|
||||
if am.DeleteRuleFunc != nil {
|
||||
return am.DeleteRuleFunc(accountID, ruleID, userID)
|
||||
return am.DeleteRuleFunc(ctx, accountID, ruleID, userID)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method DeleteRule is not implemented")
|
||||
}
|
||||
|
||||
// GetPolicy mock implementation of GetPolicy from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetPolicy(accountID, policyID, userID string) (*server.Policy, error) {
|
||||
func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error) {
|
||||
if am.GetPolicyFunc != nil {
|
||||
return am.GetPolicyFunc(accountID, policyID, userID)
|
||||
return am.GetPolicyFunc(ctx, accountID, policyID, userID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetPolicy is not implemented")
|
||||
}
|
||||
|
||||
// SavePolicy mock implementation of SavePolicy from server.AccountManager interface
|
||||
func (am *MockAccountManager) SavePolicy(accountID, userID string, policy *server.Policy) error {
|
||||
func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy) error {
|
||||
if am.SavePolicyFunc != nil {
|
||||
return am.SavePolicyFunc(accountID, userID, policy)
|
||||
return am.SavePolicyFunc(ctx, accountID, userID, policy)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented")
|
||||
}
|
||||
|
||||
// DeletePolicy mock implementation of DeletePolicy from server.AccountManager interface
|
||||
func (am *MockAccountManager) DeletePolicy(accountID, policyID, userID string) error {
|
||||
func (am *MockAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error {
|
||||
if am.DeletePolicyFunc != nil {
|
||||
return am.DeletePolicyFunc(accountID, policyID, userID)
|
||||
return am.DeletePolicyFunc(ctx, accountID, policyID, userID)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method DeletePolicy is not implemented")
|
||||
}
|
||||
|
||||
// ListPolicies mock implementation of ListPolicies from server.AccountManager interface
|
||||
func (am *MockAccountManager) ListPolicies(accountID, userID string) ([]*server.Policy, error) {
|
||||
func (am *MockAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*server.Policy, error) {
|
||||
if am.ListPoliciesFunc != nil {
|
||||
return am.ListPoliciesFunc(accountID, userID)
|
||||
return am.ListPoliciesFunc(ctx, accountID, userID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method ListPolicies is not implemented")
|
||||
}
|
||||
|
||||
// UpdatePeerMeta mock implementation of UpdatePeerMeta from server.AccountManager interface
|
||||
func (am *MockAccountManager) UpdatePeerMeta(peerID string, meta nbpeer.PeerSystemMeta) error {
|
||||
func (am *MockAccountManager) UpdatePeerMeta(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error {
|
||||
if am.UpdatePeerMetaFunc != nil {
|
||||
return am.UpdatePeerMetaFunc(peerID, meta)
|
||||
return am.UpdatePeerMetaFunc(ctx, peerID, meta)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method UpdatePeerMeta is not implemented")
|
||||
}
|
||||
|
||||
// GetUser mock implementation of GetUser from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) (*server.User, error) {
|
||||
func (am *MockAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) {
|
||||
if am.GetUserFunc != nil {
|
||||
return am.GetUserFunc(claims)
|
||||
return am.GetUserFunc(ctx, claims)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetUser is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) ListUsers(accountID string) ([]*server.User, error) {
|
||||
func (am *MockAccountManager) ListUsers(ctx context.Context, accountID string) ([]*server.User, error) {
|
||||
if am.ListUsersFunc != nil {
|
||||
return am.ListUsersFunc(accountID)
|
||||
return am.ListUsersFunc(ctx, accountID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method ListUsers is not implemented")
|
||||
}
|
||||
|
||||
// UpdatePeerSSHKey mocks UpdatePeerSSHKey function of the account manager
|
||||
func (am *MockAccountManager) UpdatePeerSSHKey(peerID string, sshKey string) error {
|
||||
func (am *MockAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error {
|
||||
if am.UpdatePeerSSHKeyFunc != nil {
|
||||
return am.UpdatePeerSSHKeyFunc(peerID, sshKey)
|
||||
return am.UpdatePeerSSHKeyFunc(ctx, peerID, sshKey)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method UpdatePeerSSHKey is not implemented")
|
||||
}
|
||||
|
||||
// UpdatePeer mocks UpdatePeerFunc function of the account manager
|
||||
func (am *MockAccountManager) UpdatePeer(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) {
|
||||
func (am *MockAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) {
|
||||
if am.UpdatePeerFunc != nil {
|
||||
return am.UpdatePeerFunc(accountID, userID, peer)
|
||||
return am.UpdatePeerFunc(ctx, accountID, userID, peer)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method UpdatePeer is not implemented")
|
||||
}
|
||||
|
||||
// CreateRoute mock implementation of CreateRoute from server.AccountManager interface
|
||||
func (am *MockAccountManager) CreateRoute(accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
|
||||
func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
|
||||
if am.CreateRouteFunc != nil {
|
||||
return am.CreateRouteFunc(accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, enabled, userID, keepRoute)
|
||||
return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, enabled, userID, keepRoute)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented")
|
||||
}
|
||||
|
||||
// GetRoute mock implementation of GetRoute from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetRoute(accountID string, routeID route.ID, userID string) (*route.Route, error) {
|
||||
func (am *MockAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) {
|
||||
if am.GetRouteFunc != nil {
|
||||
return am.GetRouteFunc(accountID, routeID, userID)
|
||||
return am.GetRouteFunc(ctx, accountID, routeID, userID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetRoute is not implemented")
|
||||
}
|
||||
|
||||
// SaveRoute mock implementation of SaveRoute from server.AccountManager interface
|
||||
func (am *MockAccountManager) SaveRoute(accountID string, userID string, route *route.Route) error {
|
||||
func (am *MockAccountManager) SaveRoute(ctx context.Context, accountID string, userID string, route *route.Route) error {
|
||||
if am.SaveRouteFunc != nil {
|
||||
return am.SaveRouteFunc(accountID, userID, route)
|
||||
return am.SaveRouteFunc(ctx, accountID, userID, route)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method SaveRoute is not implemented")
|
||||
}
|
||||
|
||||
// DeleteRoute mock implementation of DeleteRoute from server.AccountManager interface
|
||||
func (am *MockAccountManager) DeleteRoute(accountID string, routeID route.ID, userID string) error {
|
||||
func (am *MockAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error {
|
||||
if am.DeleteRouteFunc != nil {
|
||||
return am.DeleteRouteFunc(accountID, routeID, userID)
|
||||
return am.DeleteRouteFunc(ctx, accountID, routeID, userID)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method DeleteRoute is not implemented")
|
||||
}
|
||||
|
||||
// ListRoutes mock implementation of ListRoutes from server.AccountManager interface
|
||||
func (am *MockAccountManager) ListRoutes(accountID, userID string) ([]*route.Route, error) {
|
||||
func (am *MockAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) {
|
||||
if am.ListRoutesFunc != nil {
|
||||
return am.ListRoutesFunc(accountID, userID)
|
||||
return am.ListRoutesFunc(ctx, accountID, userID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method ListRoutes is not implemented")
|
||||
}
|
||||
|
||||
// SaveSetupKey mocks SaveSetupKey of the AccountManager interface
|
||||
func (am *MockAccountManager) SaveSetupKey(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) {
|
||||
func (am *MockAccountManager) SaveSetupKey(ctx context.Context, accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) {
|
||||
if am.SaveSetupKeyFunc != nil {
|
||||
return am.SaveSetupKeyFunc(accountID, key, userID)
|
||||
return am.SaveSetupKeyFunc(ctx, accountID, key, userID)
|
||||
}
|
||||
|
||||
return nil, status.Errorf(codes.Unimplemented, "method SaveSetupKey is not implemented")
|
||||
}
|
||||
|
||||
// GetSetupKey mocks GetSetupKey of the AccountManager interface
|
||||
func (am *MockAccountManager) GetSetupKey(accountID, userID, keyID string) (*server.SetupKey, error) {
|
||||
func (am *MockAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error) {
|
||||
if am.GetSetupKeyFunc != nil {
|
||||
return am.GetSetupKeyFunc(accountID, userID, keyID)
|
||||
return am.GetSetupKeyFunc(ctx, accountID, userID, keyID)
|
||||
}
|
||||
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetSetupKey is not implemented")
|
||||
}
|
||||
|
||||
// ListSetupKeys mocks ListSetupKeys of the AccountManager interface
|
||||
func (am *MockAccountManager) ListSetupKeys(accountID, userID string) ([]*server.SetupKey, error) {
|
||||
func (am *MockAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*server.SetupKey, error) {
|
||||
if am.ListSetupKeysFunc != nil {
|
||||
return am.ListSetupKeysFunc(accountID, userID)
|
||||
return am.ListSetupKeysFunc(ctx, accountID, userID)
|
||||
}
|
||||
|
||||
return nil, status.Errorf(codes.Unimplemented, "method ListSetupKeys is not implemented")
|
||||
}
|
||||
|
||||
// SaveUser mocks SaveUser of the AccountManager interface
|
||||
func (am *MockAccountManager) SaveUser(accountID, userID string, user *server.User) (*server.UserInfo, error) {
|
||||
func (am *MockAccountManager) SaveUser(ctx context.Context, accountID, userID string, user *server.User) (*server.UserInfo, error) {
|
||||
if am.SaveUserFunc != nil {
|
||||
return am.SaveUserFunc(accountID, userID, user)
|
||||
return am.SaveUserFunc(ctx, accountID, userID, user)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method SaveUser is not implemented")
|
||||
}
|
||||
|
||||
// SaveOrAddUser mocks SaveOrAddUser of the AccountManager interface
|
||||
func (am *MockAccountManager) SaveOrAddUser(accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error) {
|
||||
func (am *MockAccountManager) SaveOrAddUser(ctx context.Context, accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error) {
|
||||
if am.SaveOrAddUserFunc != nil {
|
||||
return am.SaveOrAddUserFunc(accountID, userID, user, addIfNotExists)
|
||||
return am.SaveOrAddUserFunc(ctx, accountID, userID, user, addIfNotExists)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method SaveOrAddUser is not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser mocks DeleteUser of the AccountManager interface
|
||||
func (am *MockAccountManager) DeleteUser(accountID string, initiatorUserID string, targetUserID string) error {
|
||||
func (am *MockAccountManager) DeleteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error {
|
||||
if am.DeleteUserFunc != nil {
|
||||
return am.DeleteUserFunc(accountID, initiatorUserID, targetUserID)
|
||||
return am.DeleteUserFunc(ctx, accountID, initiatorUserID, targetUserID)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method DeleteUser is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) InviteUser(accountID string, initiatorUserID string, targetUserID string) error {
|
||||
func (am *MockAccountManager) InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error {
|
||||
if am.InviteUserFunc != nil {
|
||||
return am.InviteUserFunc(accountID, initiatorUserID, targetUserID)
|
||||
return am.InviteUserFunc(ctx, accountID, initiatorUserID, targetUserID)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method InviteUser is not implemented")
|
||||
}
|
||||
|
||||
// GetNameServerGroup mocks GetNameServerGroup of the AccountManager interface
|
||||
func (am *MockAccountManager) GetNameServerGroup(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
||||
func (am *MockAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
||||
if am.GetNameServerGroupFunc != nil {
|
||||
return am.GetNameServerGroupFunc(accountID, userID, nsGroupID)
|
||||
return am.GetNameServerGroupFunc(ctx, accountID, userID, nsGroupID)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// CreateNameServerGroup mocks CreateNameServerGroup of the AccountManager interface
|
||||
func (am *MockAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) {
|
||||
func (am *MockAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) {
|
||||
if am.CreateNameServerGroupFunc != nil {
|
||||
return am.CreateNameServerGroupFunc(accountID, name, description, nameServerList, groups, primary, domains, enabled, userID, searchDomainsEnabled)
|
||||
return am.CreateNameServerGroupFunc(ctx, accountID, name, description, nameServerList, groups, primary, domains, enabled, userID, searchDomainsEnabled)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// SaveNameServerGroup mocks SaveNameServerGroup of the AccountManager interface
|
||||
func (am *MockAccountManager) SaveNameServerGroup(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error {
|
||||
func (am *MockAccountManager) SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error {
|
||||
if am.SaveNameServerGroupFunc != nil {
|
||||
return am.SaveNameServerGroupFunc(accountID, userID, nsGroupToSave)
|
||||
return am.SaveNameServerGroupFunc(ctx, accountID, userID, nsGroupToSave)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteNameServerGroup mocks DeleteNameServerGroup of the AccountManager interface
|
||||
func (am *MockAccountManager) DeleteNameServerGroup(accountID, nsGroupID, userID string) error {
|
||||
func (am *MockAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error {
|
||||
if am.DeleteNameServerGroupFunc != nil {
|
||||
return am.DeleteNameServerGroupFunc(accountID, nsGroupID, userID)
|
||||
return am.DeleteNameServerGroupFunc(ctx, accountID, nsGroupID, userID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListNameServerGroups mocks ListNameServerGroups of the AccountManager interface
|
||||
func (am *MockAccountManager) ListNameServerGroups(accountID string, userID string) ([]*nbdns.NameServerGroup, error) {
|
||||
func (am *MockAccountManager) ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) {
|
||||
if am.ListNameServerGroupsFunc != nil {
|
||||
return am.ListNameServerGroupsFunc(accountID, userID)
|
||||
return am.ListNameServerGroupsFunc(ctx, accountID, userID)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// CreateUser mocks CreateUser of the AccountManager interface
|
||||
func (am *MockAccountManager) CreateUser(accountID, userID string, invite *server.UserInfo) (*server.UserInfo, error) {
|
||||
func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID string, invite *server.UserInfo) (*server.UserInfo, error) {
|
||||
if am.CreateUserFunc != nil {
|
||||
return am.CreateUserFunc(accountID, userID, invite)
|
||||
return am.CreateUserFunc(ctx, accountID, userID, invite)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented")
|
||||
}
|
||||
|
||||
// GetAccountFromToken mocks GetAccountFromToken of the AccountManager interface
|
||||
func (am *MockAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User,
|
||||
func (am *MockAccountManager) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User,
|
||||
error,
|
||||
) {
|
||||
if am.GetAccountFromTokenFunc != nil {
|
||||
return am.GetAccountFromTokenFunc(claims)
|
||||
return am.GetAccountFromTokenFunc(ctx, claims)
|
||||
}
|
||||
return nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) CheckUserAccessByJWTGroups(claims jwtclaims.AuthorizationClaims) error {
|
||||
func (am *MockAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error {
|
||||
if am.CheckUserAccessByJWTGroupsFunc != nil {
|
||||
return am.CheckUserAccessByJWTGroupsFunc(claims)
|
||||
return am.CheckUserAccessByJWTGroupsFunc(ctx, claims)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method CheckUserAccessByJWTGroups is not implemented")
|
||||
}
|
||||
|
||||
// GetPeers mocks GetPeers of the AccountManager interface
|
||||
func (am *MockAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.Peer, error) {
|
||||
func (am *MockAccountManager) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
|
||||
if am.GetPeersFunc != nil {
|
||||
return am.GetPeersFunc(accountID, userID)
|
||||
return am.GetPeersFunc(ctx, accountID, userID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetPeers is not implemented")
|
||||
}
|
||||
@ -595,57 +599,57 @@ func (am *MockAccountManager) GetDNSDomain() string {
|
||||
}
|
||||
|
||||
// GetEvents mocks GetEvents of the AccountManager interface
|
||||
func (am *MockAccountManager) GetEvents(accountID, userID string) ([]*activity.Event, error) {
|
||||
func (am *MockAccountManager) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) {
|
||||
if am.GetEventsFunc != nil {
|
||||
return am.GetEventsFunc(accountID, userID)
|
||||
return am.GetEventsFunc(ctx, accountID, userID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetEvents is not implemented")
|
||||
}
|
||||
|
||||
// GetDNSSettings mocks GetDNSSettings of the AccountManager interface
|
||||
func (am *MockAccountManager) GetDNSSettings(accountID string, userID string) (*server.DNSSettings, error) {
|
||||
func (am *MockAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*server.DNSSettings, error) {
|
||||
if am.GetDNSSettingsFunc != nil {
|
||||
return am.GetDNSSettingsFunc(accountID, userID)
|
||||
return am.GetDNSSettingsFunc(ctx, accountID, userID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetDNSSettings is not implemented")
|
||||
}
|
||||
|
||||
// SaveDNSSettings mocks SaveDNSSettings of the AccountManager interface
|
||||
func (am *MockAccountManager) SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *server.DNSSettings) error {
|
||||
func (am *MockAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *server.DNSSettings) error {
|
||||
if am.SaveDNSSettingsFunc != nil {
|
||||
return am.SaveDNSSettingsFunc(accountID, userID, dnsSettingsToSave)
|
||||
return am.SaveDNSSettingsFunc(ctx, accountID, userID, dnsSettingsToSave)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method SaveDNSSettings is not implemented")
|
||||
}
|
||||
|
||||
// GetPeer mocks GetPeer of the AccountManager interface
|
||||
func (am *MockAccountManager) GetPeer(accountID, peerID, userID string) (*nbpeer.Peer, error) {
|
||||
func (am *MockAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
|
||||
if am.GetPeerFunc != nil {
|
||||
return am.GetPeerFunc(accountID, peerID, userID)
|
||||
return am.GetPeerFunc(ctx, accountID, peerID, userID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetPeer is not implemented")
|
||||
}
|
||||
|
||||
// UpdateAccountSettings mocks UpdateAccountSettings of the AccountManager interface
|
||||
func (am *MockAccountManager) UpdateAccountSettings(accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
|
||||
func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
|
||||
if am.UpdateAccountSettingsFunc != nil {
|
||||
return am.UpdateAccountSettingsFunc(accountID, userID, newSettings)
|
||||
return am.UpdateAccountSettingsFunc(ctx, accountID, userID, newSettings)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method UpdateAccountSettings is not implemented")
|
||||
}
|
||||
|
||||
// LoginPeer mocks LoginPeer of the AccountManager interface
|
||||
func (am *MockAccountManager) LoginPeer(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
|
||||
func (am *MockAccountManager) LoginPeer(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
|
||||
if am.LoginPeerFunc != nil {
|
||||
return am.LoginPeerFunc(login)
|
||||
return am.LoginPeerFunc(ctx, login)
|
||||
}
|
||||
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method LoginPeer is not implemented")
|
||||
}
|
||||
|
||||
// SyncPeer mocks SyncPeer of the AccountManager interface
|
||||
func (am *MockAccountManager) SyncPeer(sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
|
||||
func (am *MockAccountManager) SyncPeer(ctx context.Context, sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
|
||||
if am.SyncPeerFunc != nil {
|
||||
return am.SyncPeerFunc(sync, account)
|
||||
return am.SyncPeerFunc(ctx, sync, account)
|
||||
}
|
||||
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented")
|
||||
}
|
||||
@ -667,9 +671,9 @@ func (am *MockAccountManager) HasConnectedChannel(peerID string) bool {
|
||||
}
|
||||
|
||||
// StoreEvent mocks StoreEvent of the AccountManager interface
|
||||
func (am *MockAccountManager) StoreEvent(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
func (am *MockAccountManager) StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
|
||||
if am.StoreEventFunc != nil {
|
||||
am.StoreEventFunc(initiatorID, targetID, accountID, activityID, meta)
|
||||
am.StoreEventFunc(ctx, initiatorID, targetID, accountID, activityID, meta)
|
||||
}
|
||||
}
|
||||
|
||||
@ -682,35 +686,35 @@ func (am *MockAccountManager) GetExternalCacheManager() server.ExternalCacheMana
|
||||
}
|
||||
|
||||
// GetPostureChecks mocks GetPostureChecks of the AccountManager interface
|
||||
func (am *MockAccountManager) GetPostureChecks(accountID, postureChecksID, userID string) (*posture.Checks, error) {
|
||||
func (am *MockAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
|
||||
if am.GetPostureChecksFunc != nil {
|
||||
return am.GetPostureChecksFunc(accountID, postureChecksID, userID)
|
||||
return am.GetPostureChecksFunc(ctx, accountID, postureChecksID, userID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetPostureChecks is not implemented")
|
||||
|
||||
}
|
||||
|
||||
// SavePostureChecks mocks SavePostureChecks of the AccountManager interface
|
||||
func (am *MockAccountManager) SavePostureChecks(accountID, userID string, postureChecks *posture.Checks) error {
|
||||
func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
|
||||
if am.SavePostureChecksFunc != nil {
|
||||
return am.SavePostureChecksFunc(accountID, userID, postureChecks)
|
||||
return am.SavePostureChecksFunc(ctx, accountID, userID, postureChecks)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented")
|
||||
}
|
||||
|
||||
// DeletePostureChecks mocks DeletePostureChecks of the AccountManager interface
|
||||
func (am *MockAccountManager) DeletePostureChecks(accountID, postureChecksID, userID string) error {
|
||||
func (am *MockAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error {
|
||||
if am.DeletePostureChecksFunc != nil {
|
||||
return am.DeletePostureChecksFunc(accountID, postureChecksID, userID)
|
||||
return am.DeletePostureChecksFunc(ctx, accountID, postureChecksID, userID)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method DeletePostureChecks is not implemented")
|
||||
|
||||
}
|
||||
|
||||
// ListPostureChecks mocks ListPostureChecks of the AccountManager interface
|
||||
func (am *MockAccountManager) ListPostureChecks(accountID, userID string) ([]*posture.Checks, error) {
|
||||
func (am *MockAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) {
|
||||
if am.ListPostureChecksFunc != nil {
|
||||
return am.ListPostureChecksFunc(accountID, userID)
|
||||
return am.ListPostureChecksFunc(ctx, accountID, userID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method ListPostureChecks is not implemented")
|
||||
}
|
||||
@ -724,25 +728,25 @@ func (am *MockAccountManager) GetIdpManager() idp.Manager {
|
||||
}
|
||||
|
||||
// UpdateIntegratedValidatorGroups mocks UpdateIntegratedApprovalGroups of the AccountManager interface
|
||||
func (am *MockAccountManager) UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error {
|
||||
func (am *MockAccountManager) UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error {
|
||||
if am.UpdateIntegratedValidatorGroupsFunc != nil {
|
||||
return am.UpdateIntegratedValidatorGroupsFunc(accountID, userID, groups)
|
||||
return am.UpdateIntegratedValidatorGroupsFunc(ctx, accountID, userID, groups)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method UpdateIntegratedValidatorGroups is not implemented")
|
||||
}
|
||||
|
||||
// GroupValidation mocks GroupValidation of the AccountManager interface
|
||||
func (am *MockAccountManager) GroupValidation(accountId string, groups []string) (bool, error) {
|
||||
func (am *MockAccountManager) GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) {
|
||||
if am.GroupValidationFunc != nil {
|
||||
return am.GroupValidationFunc(accountId, groups)
|
||||
return am.GroupValidationFunc(ctx, accountId, groups)
|
||||
}
|
||||
return false, status.Errorf(codes.Unimplemented, "method GroupValidation is not implemented")
|
||||
}
|
||||
|
||||
// SyncPeerMeta mocks SyncPeerMeta of the AccountManager interface
|
||||
func (am *MockAccountManager) SyncPeerMeta(peerPubKey string, meta nbpeer.PeerSystemMeta) error {
|
||||
func (am *MockAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error {
|
||||
if am.SyncPeerMetaFunc != nil {
|
||||
return am.SyncPeerMetaFunc(peerPubKey, meta)
|
||||
return am.SyncPeerMetaFunc(ctx, peerPubKey, meta)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method SyncPeerMeta is not implemented")
|
||||
}
|
||||
@ -754,3 +758,11 @@ func (am *MockAccountManager) FindExistingPostureCheck(accountID string, checks
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method FindExistingPostureCheck is not implemented")
|
||||
}
|
||||
|
||||
// GetAccountIDForPeerKey mocks GetAccountIDForPeerKey of the AccountManager interface
|
||||
func (am *MockAccountManager) GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error) {
|
||||
if am.GetAccountIDForPeerKeyFunc != nil {
|
||||
return am.GetAccountIDForPeerKeyFunc(ctx, peerKey)
|
||||
}
|
||||
return "", status.Errorf(codes.Unimplemented, "method GetAccountIDForPeerKey is not implemented")
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"regexp"
|
||||
"unicode/utf8"
|
||||
@ -17,12 +18,12 @@ import (
|
||||
const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$`
|
||||
|
||||
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
|
||||
func (am *DefaultAccountManager) GetNameServerGroup(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
||||
func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
||||
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -45,12 +46,12 @@ func (am *DefaultAccountManager) GetNameServerGroup(accountID, userID, nsGroupID
|
||||
}
|
||||
|
||||
// CreateNameServerGroup creates and saves a new nameserver group
|
||||
func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) {
|
||||
func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) {
|
||||
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -79,29 +80,29 @@ func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, d
|
||||
account.NameServerGroups[newNSGroup.ID] = newNSGroup
|
||||
|
||||
account.Network.IncSerial()
|
||||
err = am.Store.SaveAccount(account)
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
am.updateAccountPeers(account)
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
am.StoreEvent(userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
|
||||
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
|
||||
|
||||
return newNSGroup.Copy(), nil
|
||||
}
|
||||
|
||||
// SaveNameServerGroup saves nameserver group
|
||||
func (am *DefaultAccountManager) SaveNameServerGroup(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error {
|
||||
func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error {
|
||||
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
if nsGroupToSave == nil {
|
||||
return status.Errorf(status.InvalidArgument, "nameserver group provided is nil")
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -114,25 +115,25 @@ func (am *DefaultAccountManager) SaveNameServerGroup(accountID, userID string, n
|
||||
account.NameServerGroups[nsGroupToSave.ID] = nsGroupToSave
|
||||
|
||||
account.Network.IncSerial()
|
||||
err = am.Store.SaveAccount(account)
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.updateAccountPeers(account)
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
am.StoreEvent(userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
|
||||
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteNameServerGroup deletes nameserver group with nsGroupID
|
||||
func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID, userID string) error {
|
||||
func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error {
|
||||
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -144,25 +145,25 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID, use
|
||||
delete(account.NameServerGroups, nsGroupID)
|
||||
|
||||
account.Network.IncSerial()
|
||||
err = am.Store.SaveAccount(account)
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.updateAccountPeers(account)
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
am.StoreEvent(userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
|
||||
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListNameServerGroups returns a list of nameserver groups from account
|
||||
func (am *DefaultAccountManager) ListNameServerGroups(accountID string, userID string) ([]*nbdns.NameServerGroup, error) {
|
||||
func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) {
|
||||
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
@ -383,6 +384,7 @@ func TestCreateNameServerGroup(t *testing.T) {
|
||||
}
|
||||
|
||||
outNSGroup, err := am.CreateNameServerGroup(
|
||||
context.Background(),
|
||||
account.Id,
|
||||
testCase.inputArgs.name,
|
||||
testCase.inputArgs.description,
|
||||
@ -611,7 +613,7 @@ func TestSaveNameServerGroup(t *testing.T) {
|
||||
|
||||
account.NameServerGroups[testCase.existingNSGroup.ID] = testCase.existingNSGroup
|
||||
|
||||
err = am.Store.SaveAccount(account)
|
||||
err = am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Error("account should be saved")
|
||||
}
|
||||
@ -646,7 +648,7 @@ func TestSaveNameServerGroup(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
err = am.SaveNameServerGroup(account.Id, userID, nsGroupToSave)
|
||||
err = am.SaveNameServerGroup(context.Background(), account.Id, userID, nsGroupToSave)
|
||||
|
||||
testCase.errFunc(t, err)
|
||||
|
||||
@ -654,7 +656,7 @@ func TestSaveNameServerGroup(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
account, err = am.Store.GetAccount(account.Id)
|
||||
account, err = am.Store.GetAccount(context.Background(), account.Id)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -705,17 +707,17 @@ func TestDeleteNameServerGroup(t *testing.T) {
|
||||
|
||||
account.NameServerGroups[testingNSGroup.ID] = testingNSGroup
|
||||
|
||||
err = am.Store.SaveAccount(account)
|
||||
err = am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Error("failed to save account")
|
||||
}
|
||||
|
||||
err = am.DeleteNameServerGroup(account.Id, testingNSGroup.ID, userID)
|
||||
err = am.DeleteNameServerGroup(context.Background(), account.Id, testingNSGroup.ID, userID)
|
||||
if err != nil {
|
||||
t.Error("deleting nameserver group failed with error: ", err)
|
||||
}
|
||||
|
||||
savedAccount, err := am.Store.GetAccount(account.Id)
|
||||
savedAccount, err := am.Store.GetAccount(context.Background(), account.Id)
|
||||
if err != nil {
|
||||
t.Error("failed to retrieve saved account with error: ", err)
|
||||
}
|
||||
@ -738,7 +740,7 @@ func TestGetNameServerGroup(t *testing.T) {
|
||||
t.Error("failed to init testing account")
|
||||
}
|
||||
|
||||
foundGroup, err := am.GetNameServerGroup(account.Id, testUserID, existingNSGroupID)
|
||||
foundGroup, err := am.GetNameServerGroup(context.Background(), account.Id, testUserID, existingNSGroupID)
|
||||
if err != nil {
|
||||
t.Error("getting existing nameserver group failed with error: ", err)
|
||||
}
|
||||
@ -747,7 +749,7 @@ func TestGetNameServerGroup(t *testing.T) {
|
||||
t.Error("got a nil group while getting nameserver group with ID")
|
||||
}
|
||||
|
||||
_, err = am.GetNameServerGroup(account.Id, testUserID, "not existing")
|
||||
_, err = am.GetNameServerGroup(context.Background(), account.Id, testUserID, "not existing")
|
||||
if err == nil {
|
||||
t.Error("getting not existing nameserver group should return error, got nil")
|
||||
}
|
||||
@ -760,13 +762,13 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
return nil, err
|
||||
}
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{})
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{})
|
||||
}
|
||||
|
||||
func createNSStore(t *testing.T) (Store, error) {
|
||||
t.Helper()
|
||||
dataDir := t.TempDir()
|
||||
store, cleanUp, err := NewTestStoreFromJson(dataDir)
|
||||
store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -829,7 +831,7 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error
|
||||
userID := testUserID
|
||||
domain := "example.com"
|
||||
|
||||
account := newAccountWithId(accountID, userID, domain)
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain)
|
||||
|
||||
account.NameServerGroups[existingNSGroup.ID] = &existingNSGroup
|
||||
|
||||
@ -846,16 +848,16 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error
|
||||
account.Groups[newGroup1.ID] = newGroup1
|
||||
account.Groups[newGroup2.ID] = newGroup2
|
||||
|
||||
err := am.Store.SaveAccount(account)
|
||||
err := am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, _, _, err = am.AddPeer("", userID, peer1)
|
||||
_, _, _, err = am.AddPeer(context.Background(), "", userID, peer1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, _, _, err = am.AddPeer("", userID, peer2)
|
||||
_, _, _, err = am.AddPeer(context.Background(), "", userID, peer2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
@ -45,8 +46,8 @@ type PeerLogin struct {
|
||||
|
||||
// GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if
|
||||
// the current user is not an admin.
|
||||
func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.Peer, error) {
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -79,7 +80,7 @@ func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.P
|
||||
|
||||
// fetch all the peers that have access to the user's peers
|
||||
for _, peer := range peers {
|
||||
aclPeers, _ := account.getPeerConnectionResources(peer.ID, approvedPeersMap)
|
||||
aclPeers, _ := account.getPeerConnectionResources(ctx, peer.ID, approvedPeersMap)
|
||||
for _, p := range aclPeers {
|
||||
peersMap[p.ID] = p
|
||||
}
|
||||
@ -94,7 +95,7 @@ func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.P
|
||||
}
|
||||
|
||||
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
|
||||
func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected bool, realIP net.IP, account *Account) error {
|
||||
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, account *Account) error {
|
||||
peer, err := account.FindPeerByPubKey(peerPubKey)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -113,7 +114,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected
|
||||
if am.geo != nil && realIP != nil {
|
||||
location, err := am.geo.Lookup(realIP)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err)
|
||||
log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err)
|
||||
} else {
|
||||
peer.Location.ConnectionIP = realIP
|
||||
peer.Location.CountryCode = location.Country.ISOCode
|
||||
@ -121,7 +122,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected
|
||||
peer.Location.GeoNameID = location.City.GeonameID
|
||||
err = am.Store.SavePeerLocation(account.Id, peer)
|
||||
if err != nil {
|
||||
log.Warnf("could not store location for peer %s: %s", peer.ID, err)
|
||||
log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -134,24 +135,24 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected
|
||||
}
|
||||
|
||||
if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled {
|
||||
am.checkAndSchedulePeerLoginExpiration(account)
|
||||
am.checkAndSchedulePeerLoginExpiration(ctx, account)
|
||||
}
|
||||
|
||||
if oldStatus.LoginExpired {
|
||||
// we need to update other peers because when peer login expires all other peers are notified to disconnect from
|
||||
// the expired one. Here we notify them that connection is now allowed again.
|
||||
am.updateAccountPeers(account)
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, and Peer.LoginExpirationEnabled can be updated.
|
||||
func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -161,7 +162,7 @@ func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *nb
|
||||
return nil, status.Errorf(status.NotFound, "peer %s not found", update.ID)
|
||||
}
|
||||
|
||||
update, err = am.integratedPeerValidator.ValidatePeer(update, peer, userID, accountID, am.GetDNSDomain(), account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
|
||||
update, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -172,7 +173,7 @@ func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *nb
|
||||
if !update.SSHEnabled {
|
||||
event = activity.PeerSSHDisabled
|
||||
}
|
||||
am.StoreEvent(userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
|
||||
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
|
||||
}
|
||||
|
||||
if peer.Name != update.Name {
|
||||
@ -187,7 +188,7 @@ func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *nb
|
||||
|
||||
peer.DNSLabel = newLabel
|
||||
|
||||
am.StoreEvent(userID, peer.ID, accountID, activity.PeerRenamed, peer.EventMeta(am.GetDNSDomain()))
|
||||
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRenamed, peer.EventMeta(am.GetDNSDomain()))
|
||||
}
|
||||
|
||||
if peer.LoginExpirationEnabled != update.LoginExpirationEnabled {
|
||||
@ -202,27 +203,27 @@ func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *nb
|
||||
if !update.LoginExpirationEnabled {
|
||||
event = activity.PeerLoginExpirationDisabled
|
||||
}
|
||||
am.StoreEvent(userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
|
||||
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
|
||||
|
||||
if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled {
|
||||
am.checkAndSchedulePeerLoginExpiration(account)
|
||||
am.checkAndSchedulePeerLoginExpiration(ctx, account)
|
||||
}
|
||||
}
|
||||
|
||||
account.UpdatePeer(peer)
|
||||
|
||||
err = am.Store.SaveAccount(account)
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
am.updateAccountPeers(account)
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
return peer, nil
|
||||
}
|
||||
|
||||
// deletePeers will delete all specified peers and send updates to the remote peers. Don't call without acquiring account lock
|
||||
func (am *DefaultAccountManager) deletePeers(account *Account, peerIDs []string, userID string) error {
|
||||
func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Account, peerIDs []string, userID string) error {
|
||||
|
||||
// the first loop is needed to ensure all peers present under the account before modifying, otherwise
|
||||
// we might have some inconsistencies
|
||||
@ -239,13 +240,13 @@ func (am *DefaultAccountManager) deletePeers(account *Account, peerIDs []string,
|
||||
// the 2nd loop performs the actual modification
|
||||
for _, peer := range peers {
|
||||
|
||||
err := am.integratedPeerValidator.PeerDeleted(account.Id, peer.ID)
|
||||
err := am.integratedPeerValidator.PeerDeleted(ctx, account.Id, peer.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
account.DeletePeer(peer.ID)
|
||||
am.peersUpdateManager.SendUpdate(peer.ID,
|
||||
am.peersUpdateManager.SendUpdate(ctx, peer.ID,
|
||||
&UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
// fill those field for backward compatibility
|
||||
@ -261,41 +262,41 @@ func (am *DefaultAccountManager) deletePeers(account *Account, peerIDs []string,
|
||||
},
|
||||
},
|
||||
})
|
||||
am.peersUpdateManager.CloseChannel(peer.ID)
|
||||
am.StoreEvent(userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain()))
|
||||
am.peersUpdateManager.CloseChannel(ctx, peer.ID)
|
||||
am.StoreEvent(ctx, userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain()))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeletePeer removes peer from the account by its IP
|
||||
func (am *DefaultAccountManager) DeletePeer(accountID, peerID, userID string) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peerID, userID string) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = am.deletePeers(account, []string{peerID}, userID)
|
||||
err = am.deletePeers(ctx, account, []string{peerID}, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = am.Store.SaveAccount(account)
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.updateAccountPeers(account)
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)
|
||||
func (am *DefaultAccountManager) GetNetworkMap(peerID string) (*NetworkMap, error) {
|
||||
account, err := am.Store.GetAccountByPeerID(peerID)
|
||||
func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID string) (*NetworkMap, error) {
|
||||
account, err := am.Store.GetAccountByPeerID(ctx, peerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -314,12 +315,12 @@ func (am *DefaultAccountManager) GetNetworkMap(peerID string) (*NetworkMap, erro
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return account.GetPeerNetworkMap(peer.ID, am.dnsDomain, validatedPeers), nil
|
||||
return account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, validatedPeers), nil
|
||||
}
|
||||
|
||||
// GetPeerNetwork returns the Network for a given peer
|
||||
func (am *DefaultAccountManager) GetPeerNetwork(peerID string) (*Network, error) {
|
||||
account, err := am.Store.GetAccountByPeerID(peerID)
|
||||
func (am *DefaultAccountManager) GetPeerNetwork(ctx context.Context, peerID string) (*Network, error) {
|
||||
account, err := am.Store.GetAccountByPeerID(ctx, peerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -334,7 +335,7 @@ func (am *DefaultAccountManager) GetPeerNetwork(peerID string) (*Network, error)
|
||||
// to it. We also add the User ID to the peer metadata to identify registrant. If no userID provided, then fail with status.PermissionDenied
|
||||
// Each new Peer will be assigned a new next net.IP from the Account.Network and Account.Network.LastIP will be updated (IP's are not reused).
|
||||
// The peer property is just a placeholder for the Peer properties to pass further
|
||||
func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
|
||||
func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
|
||||
if setupKey == "" && userID == "" {
|
||||
// no auth method provided => reject access
|
||||
return nil, nil, nil, status.Errorf(status.Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login")
|
||||
@ -348,13 +349,13 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
|
||||
addedByUser = true
|
||||
accountID, err = am.Store.GetAccountIDByUserID(userID)
|
||||
} else {
|
||||
accountID, err = am.Store.GetAccountIDBySetupKey(setupKey)
|
||||
accountID, err = am.Store.GetAccountIDBySetupKey(ctx, setupKey)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found")
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
defer func() {
|
||||
if unlock != nil {
|
||||
unlock()
|
||||
@ -363,14 +364,14 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
|
||||
|
||||
var account *Account
|
||||
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
|
||||
account, err = am.Store.GetAccount(accountID)
|
||||
account, err = am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" {
|
||||
if am.idpManager != nil {
|
||||
userdata, err := am.lookupUserInCache(userID, account)
|
||||
userdata, err := am.lookupUserInCache(ctx, userID, account)
|
||||
if err == nil && userdata != nil {
|
||||
peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0])
|
||||
}
|
||||
@ -479,7 +480,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
|
||||
}
|
||||
}
|
||||
|
||||
newPeer = am.integratedPeerValidator.PreparePeer(account.Id, newPeer, account.GetPeerGroupsList(newPeer.ID), account.Settings.Extra)
|
||||
newPeer = am.integratedPeerValidator.PreparePeer(ctx, account.Id, newPeer, account.GetPeerGroupsList(newPeer.ID), account.Settings.Extra)
|
||||
|
||||
if addedByUser {
|
||||
user, err := account.FindUser(userID)
|
||||
@ -491,7 +492,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
|
||||
|
||||
account.Peers[newPeer.ID] = newPeer
|
||||
account.Network.IncSerial()
|
||||
err = am.Store.SaveAccount(account)
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
@ -506,9 +507,9 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
|
||||
opEvent.Meta["setup_key_name"] = setupKeyName
|
||||
}
|
||||
|
||||
am.StoreEvent(opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
|
||||
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
|
||||
|
||||
am.updateAccountPeers(account)
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
approvedPeersMap, err := am.GetValidatedPeers(account)
|
||||
if err != nil {
|
||||
@ -516,12 +517,12 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
|
||||
}
|
||||
|
||||
postureChecks := am.getPeerPostureChecks(account, peer)
|
||||
networkMap := account.GetPeerNetworkMap(newPeer.ID, am.dnsDomain, approvedPeersMap)
|
||||
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, am.dnsDomain, approvedPeersMap)
|
||||
return newPeer, networkMap, postureChecks, nil
|
||||
}
|
||||
|
||||
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
|
||||
func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
|
||||
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
|
||||
peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey)
|
||||
if err != nil {
|
||||
return nil, nil, nil, status.NewPeerNotRegisteredError()
|
||||
@ -532,23 +533,23 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbp
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
if peerLoginExpired(peer, account.Settings) {
|
||||
if peerLoginExpired(ctx, peer, account.Settings) {
|
||||
return nil, nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
|
||||
}
|
||||
|
||||
peer, updated := updatePeerMeta(peer, sync.Meta, account)
|
||||
if updated {
|
||||
err = am.Store.SaveAccount(account)
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
if sync.UpdateAccountPeers {
|
||||
am.updateAccountPeers(account)
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
}
|
||||
|
||||
peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
|
||||
peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
@ -563,7 +564,7 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbp
|
||||
}
|
||||
|
||||
if isStatusChanged {
|
||||
am.updateAccountPeers(account)
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
validPeersMap, err := am.GetValidatedPeers(account)
|
||||
@ -572,13 +573,13 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbp
|
||||
}
|
||||
postureChecks = am.getPeerPostureChecks(account, peer)
|
||||
|
||||
return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, validPeersMap), postureChecks, nil
|
||||
return peer, account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, validPeersMap), postureChecks, nil
|
||||
}
|
||||
|
||||
// LoginPeer logs in or registers a peer.
|
||||
// If peer doesn't exist the function checks whether a setup key or a user is present and registers a new peer if so.
|
||||
func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
|
||||
accountID, err := am.Store.GetAccountIDByPeerPubKey(login.WireGuardPubKey)
|
||||
func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
|
||||
accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, login.WireGuardPubKey)
|
||||
if err != nil {
|
||||
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
|
||||
// we couldn't find this peer by its public key which can mean that peer hasn't been registered yet.
|
||||
@ -591,7 +592,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
|
||||
if am.geo != nil && login.ConnectionIP != nil {
|
||||
location, err := am.geo.Lookup(login.ConnectionIP)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get location for new peer realip: [%s]: %v", login.ConnectionIP.String(), err)
|
||||
log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", login.ConnectionIP.String(), err)
|
||||
} else {
|
||||
newPeer.Location.ConnectionIP = login.ConnectionIP
|
||||
newPeer.Location.CountryCode = location.Country.ISOCode
|
||||
@ -601,19 +602,19 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
|
||||
}
|
||||
}
|
||||
|
||||
return am.AddPeer(login.SetupKey, login.UserID, newPeer)
|
||||
return am.AddPeer(ctx, login.SetupKey, login.UserID, newPeer)
|
||||
}
|
||||
|
||||
log.Errorf("failed while logging in peer %s: %v", login.WireGuardPubKey, err)
|
||||
log.WithContext(ctx).Errorf("failed while logging in peer %s: %v", login.WireGuardPubKey, err)
|
||||
return nil, nil, nil, status.Errorf(status.Internal, "failed while logging in peer")
|
||||
}
|
||||
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(login.WireGuardPubKey)
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey)
|
||||
if err != nil {
|
||||
return nil, nil, nil, status.NewPeerNotRegisteredError()
|
||||
}
|
||||
|
||||
accSettings, err := am.Store.GetAccountSettings(accountID)
|
||||
accSettings, err := am.Store.GetAccountSettings(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, status.Errorf(status.Internal, "failed to get account settings: %s", err)
|
||||
}
|
||||
@ -621,30 +622,30 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
|
||||
var isWriteLock bool
|
||||
|
||||
// duplicated logic from after the lock to have an early exit
|
||||
expired := peerLoginExpired(peer, accSettings)
|
||||
expired := peerLoginExpired(ctx, peer, accSettings)
|
||||
switch {
|
||||
case expired:
|
||||
if err := checkAuth(login.UserID, peer); err != nil {
|
||||
if err := checkAuth(ctx, login.UserID, peer); err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
isWriteLock = true
|
||||
log.Debugf("peer login expired, acquiring write lock")
|
||||
log.WithContext(ctx).Debugf("peer login expired, acquiring write lock")
|
||||
|
||||
case peer.UpdateMetaIfNew(login.Meta):
|
||||
isWriteLock = true
|
||||
log.Debugf("peer changed meta, acquiring write lock")
|
||||
log.WithContext(ctx).Debugf("peer changed meta, acquiring write lock")
|
||||
|
||||
default:
|
||||
isWriteLock = false
|
||||
log.Debugf("peer meta is the same, acquiring read lock")
|
||||
log.WithContext(ctx).Debugf("peer meta is the same, acquiring read lock")
|
||||
}
|
||||
|
||||
var unlock func()
|
||||
|
||||
if isWriteLock {
|
||||
unlock = am.Store.AcquireAccountWriteLock(accountID)
|
||||
unlock = am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
} else {
|
||||
unlock = am.Store.AcquireAccountReadLock(accountID)
|
||||
unlock = am.Store.AcquireAccountReadLock(ctx, accountID)
|
||||
}
|
||||
defer func() {
|
||||
if unlock != nil {
|
||||
@ -653,7 +654,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
|
||||
}()
|
||||
|
||||
// fetch the account from the store once more after acquiring lock to avoid concurrent updates inconsistencies
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
@ -671,8 +672,8 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
|
||||
// this flag prevents unnecessary calls to the persistent store.
|
||||
shouldStoreAccount := false
|
||||
updateRemotePeers := false
|
||||
if peerLoginExpired(peer, account.Settings) {
|
||||
err = checkAuth(login.UserID, peer)
|
||||
if peerLoginExpired(ctx, peer, account.Settings) {
|
||||
err = checkAuth(ctx, login.UserID, peer)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
@ -689,10 +690,10 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
|
||||
}
|
||||
user.updateLastLogin(peer.LastLogin)
|
||||
|
||||
am.StoreEvent(login.UserID, peer.ID, account.Id, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain()))
|
||||
am.StoreEvent(ctx, login.UserID, peer.ID, account.Id, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain()))
|
||||
}
|
||||
|
||||
isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
|
||||
isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
@ -701,17 +702,17 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
|
||||
shouldStoreAccount = true
|
||||
}
|
||||
|
||||
peer, err = am.checkAndUpdatePeerSSHKey(peer, account, login.SSHKey)
|
||||
peer, err = am.checkAndUpdatePeerSSHKey(ctx, peer, account, login.SSHKey)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
if shouldStoreAccount {
|
||||
if !isWriteLock {
|
||||
log.Errorf("account %s should be stored but is not write locked", accountID)
|
||||
log.WithContext(ctx).Errorf("account %s should be stored but is not write locked", accountID)
|
||||
return nil, nil, nil, status.Errorf(status.Internal, "account should be stored but is not write locked")
|
||||
}
|
||||
err = am.Store.SaveAccount(account)
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
@ -720,7 +721,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
|
||||
unlock = nil
|
||||
|
||||
if updateRemotePeers || isStatusChanged {
|
||||
am.updateAccountPeers(account)
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
var postureChecks []*posture.Checks
|
||||
@ -738,7 +739,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw
|
||||
}
|
||||
postureChecks = am.getPeerPostureChecks(account, peer)
|
||||
|
||||
return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap), postureChecks, nil
|
||||
return peer, account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, approvedPeersMap), postureChecks, nil
|
||||
}
|
||||
|
||||
func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, account *Account) error {
|
||||
@ -754,23 +755,23 @@ func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, account *Account) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkAuth(loginUserID string, peer *nbpeer.Peer) error {
|
||||
func checkAuth(ctx context.Context, loginUserID string, peer *nbpeer.Peer) error {
|
||||
if loginUserID == "" {
|
||||
// absence of a user ID indicates that JWT wasn't provided.
|
||||
return status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
|
||||
}
|
||||
if peer.UserID != loginUserID {
|
||||
log.Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, loginUserID)
|
||||
log.WithContext(ctx).Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, loginUserID)
|
||||
return status.Errorf(status.Unauthenticated, "can't login")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func peerLoginExpired(peer *nbpeer.Peer, settings *Settings) bool {
|
||||
func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *Settings) bool {
|
||||
expired, expiresIn := peer.LoginExpired(settings.PeerLoginExpiration)
|
||||
expired = settings.PeerLoginExpirationEnabled && expired
|
||||
if expired || peer.Status.LoginExpired {
|
||||
log.Debugf("peer's %s login expired %v ago", peer.ID, expiresIn)
|
||||
log.WithContext(ctx).Debugf("peer's %s login expired %v ago", peer.ID, expiresIn)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
@ -781,48 +782,48 @@ func updatePeerLastLogin(peer *nbpeer.Peer, account *Account) {
|
||||
account.UpdatePeer(peer)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) checkAndUpdatePeerSSHKey(peer *nbpeer.Peer, account *Account, newSSHKey string) (*nbpeer.Peer, error) {
|
||||
func (am *DefaultAccountManager) checkAndUpdatePeerSSHKey(ctx context.Context, peer *nbpeer.Peer, account *Account, newSSHKey string) (*nbpeer.Peer, error) {
|
||||
if len(newSSHKey) == 0 {
|
||||
log.Debugf("no new SSH key provided for peer %s, skipping update", peer.ID)
|
||||
log.WithContext(ctx).Debugf("no new SSH key provided for peer %s, skipping update", peer.ID)
|
||||
return peer, nil
|
||||
}
|
||||
|
||||
if peer.SSHKey == newSSHKey {
|
||||
log.Debugf("same SSH key provided for peer %s, skipping update", peer.ID)
|
||||
log.WithContext(ctx).Debugf("same SSH key provided for peer %s, skipping update", peer.ID)
|
||||
return peer, nil
|
||||
}
|
||||
|
||||
peer.SSHKey = newSSHKey
|
||||
account.UpdatePeer(peer)
|
||||
|
||||
err := am.Store.SaveAccount(account)
|
||||
err := am.Store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// trigger network map update
|
||||
am.updateAccountPeers(account)
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
return peer, nil
|
||||
}
|
||||
|
||||
// UpdatePeerSSHKey updates peer's public SSH key
|
||||
func (am *DefaultAccountManager) UpdatePeerSSHKey(peerID string, sshKey string) error {
|
||||
func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error {
|
||||
if sshKey == "" {
|
||||
log.Debugf("empty SSH key provided for peer %s, skipping update", peerID)
|
||||
log.WithContext(ctx).Debugf("empty SSH key provided for peer %s, skipping update", peerID)
|
||||
return nil
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccountByPeerID(peerID)
|
||||
account, err := am.Store.GetAccountByPeerID(ctx, peerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountWriteLock(account.Id)
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, account.Id)
|
||||
defer unlock()
|
||||
|
||||
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
|
||||
account, err = am.Store.GetAccount(account.Id)
|
||||
account, err = am.Store.GetAccount(ctx, account.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -833,30 +834,30 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(peerID string, sshKey string)
|
||||
}
|
||||
|
||||
if peer.SSHKey == sshKey {
|
||||
log.Debugf("same SSH key provided for peer %s, skipping update", peerID)
|
||||
log.WithContext(ctx).Debugf("same SSH key provided for peer %s, skipping update", peerID)
|
||||
return nil
|
||||
}
|
||||
|
||||
peer.SSHKey = sshKey
|
||||
account.UpdatePeer(peer)
|
||||
|
||||
err = am.Store.SaveAccount(account)
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// trigger network map update
|
||||
am.updateAccountPeers(account)
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPeer for a given accountID, peerID and userID error if not found.
|
||||
func (am *DefaultAccountManager) GetPeer(accountID, peerID, userID string) (*nbpeer.Peer, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -893,7 +894,7 @@ func (am *DefaultAccountManager) GetPeer(accountID, peerID, userID string) (*nbp
|
||||
}
|
||||
|
||||
for _, p := range userPeers {
|
||||
aclPeers, _ := account.getPeerConnectionResources(p.ID, approvedPeersMap)
|
||||
aclPeers, _ := account.getPeerConnectionResources(ctx, p.ID, approvedPeersMap)
|
||||
for _, aclPeer := range aclPeers {
|
||||
if aclPeer.ID == peerID {
|
||||
return peer, nil
|
||||
@ -914,23 +915,23 @@ func updatePeerMeta(peer *nbpeer.Peer, meta nbpeer.PeerSystemMeta, account *Acco
|
||||
|
||||
// updateAccountPeers updates all peers that belong to an account.
|
||||
// Should be called when changes have to be synced to peers.
|
||||
func (am *DefaultAccountManager) updateAccountPeers(account *Account) {
|
||||
func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account *Account) {
|
||||
peers := account.GetPeers()
|
||||
|
||||
approvedPeersMap, err := am.GetValidatedPeers(account)
|
||||
if err != nil {
|
||||
log.Errorf("failed send out updates to peers, failed to validate peer: %v", err)
|
||||
log.WithContext(ctx).Errorf("failed send out updates to peers, failed to validate peer: %v", err)
|
||||
return
|
||||
}
|
||||
for _, peer := range peers {
|
||||
if !am.peersUpdateManager.HasChannel(peer.ID) {
|
||||
log.Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
|
||||
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
postureChecks := am.getPeerPostureChecks(account, peer)
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap)
|
||||
update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks)
|
||||
am.peersUpdateManager.SendUpdate(peer.ID, &UpdateMessage{Update: update})
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, approvedPeersMap)
|
||||
update := toSyncResponse(ctx, nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks)
|
||||
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update})
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -80,7 +81,7 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userId, false)
|
||||
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userId, false)
|
||||
if err != nil {
|
||||
t.Fatal("error creating setup key")
|
||||
return
|
||||
@ -92,7 +93,7 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
peer1, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
|
||||
peer1, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
|
||||
Key: peerKey1.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
|
||||
})
|
||||
@ -106,7 +107,7 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
_, _, _, err = manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
|
||||
_, _, _, err = manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
|
||||
Key: peerKey2.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
|
||||
})
|
||||
@ -116,7 +117,7 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
networkMap, err := manager.GetNetworkMap(peer1.ID)
|
||||
networkMap, err := manager.GetNetworkMap(context.Background(), peer1.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@ -165,7 +166,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
peer1, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
|
||||
peer1, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
|
||||
Key: peerKey1.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
|
||||
})
|
||||
@ -179,7 +180,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
peer2, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
|
||||
peer2, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
|
||||
Key: peerKey2.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
|
||||
})
|
||||
@ -188,13 +189,13 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
policies, err := manager.ListPolicies(account.Id, userID)
|
||||
policies, err := manager.ListPolicies(context.Background(), account.Id, userID)
|
||||
if err != nil {
|
||||
t.Errorf("expecting to get a list of rules, got failure %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = manager.DeletePolicy(account.Id, policies[0].ID, userID)
|
||||
err = manager.DeletePolicy(context.Background(), account.Id, policies[0].ID, userID)
|
||||
if err != nil {
|
||||
t.Errorf("expecting to delete 1 group, got failure %v", err)
|
||||
return
|
||||
@ -213,12 +214,12 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
|
||||
group1.Peers = append(group1.Peers, peer1.ID)
|
||||
group2.Peers = append(group2.Peers, peer2.ID)
|
||||
|
||||
err = manager.SaveGroup(account.Id, userID, &group1)
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &group1)
|
||||
if err != nil {
|
||||
t.Errorf("expecting group1 to be added, got failure %v", err)
|
||||
return
|
||||
}
|
||||
err = manager.SaveGroup(account.Id, userID, &group2)
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &group2)
|
||||
if err != nil {
|
||||
t.Errorf("expecting group2 to be added, got failure %v", err)
|
||||
return
|
||||
@ -235,13 +236,13 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
}
|
||||
err = manager.SavePolicy(account.Id, userID, &policy)
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy)
|
||||
if err != nil {
|
||||
t.Errorf("expecting rule to be added, got failure %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
networkMap1, err := manager.GetNetworkMap(peer1.ID)
|
||||
networkMap1, err := manager.GetNetworkMap(context.Background(), peer1.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@ -264,7 +265,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
|
||||
)
|
||||
}
|
||||
|
||||
networkMap2, err := manager.GetNetworkMap(peer2.ID)
|
||||
networkMap2, err := manager.GetNetworkMap(context.Background(), peer2.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@ -283,13 +284,13 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
|
||||
}
|
||||
|
||||
policy.Enabled = false
|
||||
err = manager.SavePolicy(account.Id, userID, &policy)
|
||||
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy)
|
||||
if err != nil {
|
||||
t.Errorf("expecting rule to be added, got failure %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
networkMap1, err = manager.GetNetworkMap(peer1.ID)
|
||||
networkMap1, err = manager.GetNetworkMap(context.Background(), peer1.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@ -304,7 +305,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
networkMap2, err = manager.GetNetworkMap(peer2.ID)
|
||||
networkMap2, err = manager.GetNetworkMap(context.Background(), peer2.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@ -329,7 +330,7 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userId, false)
|
||||
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userId, false)
|
||||
if err != nil {
|
||||
t.Fatal("error creating setup key")
|
||||
return
|
||||
@ -341,7 +342,7 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
peer1, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
|
||||
peer1, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
|
||||
Key: peerKey1.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
|
||||
})
|
||||
@ -355,7 +356,7 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
_, _, _, err = manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
|
||||
_, _, _, err = manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
|
||||
Key: peerKey2.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
|
||||
})
|
||||
@ -365,7 +366,7 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
network, err := manager.GetPeerNetwork(peer1.ID)
|
||||
network, err := manager.GetPeerNetwork(context.Background(), peer1.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@ -387,21 +388,21 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
|
||||
accountID := "test_account"
|
||||
adminUser := "account_creator"
|
||||
someUser := "some_user"
|
||||
account := newAccountWithId(accountID, adminUser, "")
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "")
|
||||
account.Users[someUser] = &User{
|
||||
Id: someUser,
|
||||
Role: UserRoleUser,
|
||||
}
|
||||
account.Settings.RegularUsersViewBlocked = false
|
||||
|
||||
err = manager.Store.SaveAccount(account)
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
// two peers one added by a regular user and one with a setup key
|
||||
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, adminUser, false)
|
||||
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, adminUser, false)
|
||||
if err != nil {
|
||||
t.Fatal("error creating setup key")
|
||||
return
|
||||
@ -413,7 +414,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
peer1, _, _, err := manager.AddPeer("", someUser, &nbpeer.Peer{
|
||||
peer1, _, _, err := manager.AddPeer(context.Background(), "", someUser, &nbpeer.Peer{
|
||||
Key: peerKey1.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
|
||||
})
|
||||
@ -429,7 +430,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
|
||||
}
|
||||
|
||||
// the second peer added with a setup key
|
||||
peer2, _, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
|
||||
peer2, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
|
||||
Key: peerKey2.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
|
||||
})
|
||||
@ -439,7 +440,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
|
||||
}
|
||||
|
||||
// the user can see its own peer
|
||||
peer, err := manager.GetPeer(accountID, peer1.ID, someUser)
|
||||
peer, err := manager.GetPeer(context.Background(), accountID, peer1.ID, someUser)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@ -447,7 +448,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
|
||||
assert.NotNil(t, peer)
|
||||
|
||||
// the user can see peer2 because peer1 of the user has access to peer2 due to the All group and the default rule 0 all-to-all access
|
||||
peer, err = manager.GetPeer(accountID, peer2.ID, someUser)
|
||||
peer, err = manager.GetPeer(context.Background(), accountID, peer2.ID, someUser)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@ -456,7 +457,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
|
||||
|
||||
// delete the all-to-all policy so that user's peer1 has no access to peer2
|
||||
for _, policy := range account.Policies {
|
||||
err = manager.DeletePolicy(accountID, policy.ID, adminUser)
|
||||
err = manager.DeletePolicy(context.Background(), accountID, policy.ID, adminUser)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@ -464,18 +465,18 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
|
||||
}
|
||||
|
||||
// at this point the user can't see the details of peer2
|
||||
peer, err = manager.GetPeer(accountID, peer2.ID, someUser) //nolint
|
||||
peer, err = manager.GetPeer(context.Background(), accountID, peer2.ID, someUser) //nolint
|
||||
assert.Error(t, err)
|
||||
|
||||
// admin users can always access all the peers
|
||||
peer, err = manager.GetPeer(accountID, peer1.ID, adminUser)
|
||||
peer, err = manager.GetPeer(context.Background(), accountID, peer1.ID, adminUser)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
assert.NotNil(t, peer)
|
||||
|
||||
peer, err = manager.GetPeer(accountID, peer2.ID, adminUser)
|
||||
peer, err = manager.GetPeer(context.Background(), accountID, peer2.ID, adminUser)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@ -574,7 +575,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
|
||||
accountID := "test_account"
|
||||
adminUser := "account_creator"
|
||||
someUser := "some_user"
|
||||
account := newAccountWithId(accountID, adminUser, "")
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "")
|
||||
account.Users[someUser] = &User{
|
||||
Id: someUser,
|
||||
Role: testCase.role,
|
||||
@ -583,7 +584,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
|
||||
account.Policies = []*Policy{}
|
||||
account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings
|
||||
|
||||
err = manager.Store.SaveAccount(account)
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@ -601,7 +602,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
_, _, _, err = manager.AddPeer("", someUser, &nbpeer.Peer{
|
||||
_, _, _, err = manager.AddPeer(context.Background(), "", someUser, &nbpeer.Peer{
|
||||
Key: peerKey1.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
|
||||
})
|
||||
@ -610,7 +611,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
_, _, _, err = manager.AddPeer("", adminUser, &nbpeer.Peer{
|
||||
_, _, _, err = manager.AddPeer(context.Background(), "", adminUser, &nbpeer.Peer{
|
||||
Key: peerKey2.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
|
||||
})
|
||||
@ -619,7 +620,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
peers, err := manager.GetPeers(accountID, someUser)
|
||||
peers, err := manager.GetPeers(context.Background(), accountID, someUser)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
|
@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
_ "embed"
|
||||
"strconv"
|
||||
"strings"
|
||||
@ -211,9 +212,9 @@ type FirewallRule struct {
|
||||
// getPeerConnectionResources for a given peer
|
||||
//
|
||||
// This function returns the list of peers and firewall rules that are applicable to a given peer.
|
||||
func (a *Account) getPeerConnectionResources(peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
|
||||
func (a *Account) getPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
|
||||
|
||||
generateResources, getAccumulatedResources := a.connResourcesGenerator()
|
||||
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx)
|
||||
for _, policy := range a.Policies {
|
||||
if !policy.Enabled {
|
||||
continue
|
||||
@ -224,8 +225,8 @@ func (a *Account) getPeerConnectionResources(peerID string, validatedPeersMap ma
|
||||
continue
|
||||
}
|
||||
|
||||
sourcePeers, peerInSources := getAllPeersFromGroups(a, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap)
|
||||
destinationPeers, peerInDestinations := getAllPeersFromGroups(a, rule.Destinations, peerID, nil, validatedPeersMap)
|
||||
sourcePeers, peerInSources := getAllPeersFromGroups(ctx, a, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap)
|
||||
destinationPeers, peerInDestinations := getAllPeersFromGroups(ctx, a, rule.Destinations, peerID, nil, validatedPeersMap)
|
||||
|
||||
if rule.Bidirectional {
|
||||
if peerInSources {
|
||||
@ -254,7 +255,7 @@ func (a *Account) getPeerConnectionResources(peerID string, validatedPeersMap ma
|
||||
// The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer.
|
||||
// It safe to call the generator function multiple times for same peer and different rules no duplicates will be
|
||||
// generated. The accumulator function returns the result of all the generator calls.
|
||||
func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) {
|
||||
func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) {
|
||||
rulesExists := make(map[string]struct{})
|
||||
peersExists := make(map[string]struct{})
|
||||
rules := make([]*FirewallRule, 0)
|
||||
@ -262,7 +263,7 @@ func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*nbpeer.Peer, in
|
||||
|
||||
all, err := a.GetGroupAll()
|
||||
if err != nil {
|
||||
log.Errorf("failed to get group all: %v", err)
|
||||
log.WithContext(ctx).Errorf("failed to get group all: %v", err)
|
||||
all = &nbgroup.Group{}
|
||||
}
|
||||
|
||||
@ -313,11 +314,11 @@ func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*nbpeer.Peer, in
|
||||
}
|
||||
|
||||
// GetPolicy from the store
|
||||
func (am *DefaultAccountManager) GetPolicy(accountID, policyID, userID string) (*Policy, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -341,11 +342,11 @@ func (am *DefaultAccountManager) GetPolicy(accountID, policyID, userID string) (
|
||||
}
|
||||
|
||||
// SavePolicy in the store
|
||||
func (am *DefaultAccountManager) SavePolicy(accountID, userID string, policy *Policy) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -353,7 +354,7 @@ func (am *DefaultAccountManager) SavePolicy(accountID, userID string, policy *Po
|
||||
exists := am.savePolicy(account, policy)
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(account); err != nil {
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -361,19 +362,19 @@ func (am *DefaultAccountManager) SavePolicy(accountID, userID string, policy *Po
|
||||
if exists {
|
||||
action = activity.PolicyUpdated
|
||||
}
|
||||
am.StoreEvent(userID, policy.ID, accountID, action, policy.EventMeta())
|
||||
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
|
||||
|
||||
am.updateAccountPeers(account)
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeletePolicy from the store
|
||||
func (am *DefaultAccountManager) DeletePolicy(accountID, policyID, userID string) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -384,23 +385,23 @@ func (am *DefaultAccountManager) DeletePolicy(accountID, policyID, userID string
|
||||
}
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(account); err != nil {
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
|
||||
am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
|
||||
|
||||
am.updateAccountPeers(account)
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListPolicies from the store
|
||||
func (am *DefaultAccountManager) ListPolicies(accountID, userID string) ([]*Policy, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(accountID)
|
||||
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) {
|
||||
unlock := am.Store.AcquireAccountWriteLock(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -490,7 +491,7 @@ func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule {
|
||||
//
|
||||
// Important: Posture checks are applicable only to source group peers,
|
||||
// for destination group peers, call this method with an empty list of sourcePostureChecksIDs
|
||||
func getAllPeersFromGroups(account *Account, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) {
|
||||
func getAllPeersFromGroups(ctx context.Context, account *Account, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) {
|
||||
peerInGroups := false
|
||||
filteredPeers := make([]*nbpeer.Peer, 0, len(groups))
|
||||
for _, g := range groups {
|
||||
@ -506,7 +507,7 @@ func getAllPeersFromGroups(account *Account, groups []string, peerID string, sou
|
||||
}
|
||||
|
||||
// validate the peer based on policy posture checks applied
|
||||
isValid := account.validatePostureChecksOnPeer(sourcePostureChecksIDs, peer.ID)
|
||||
isValid := account.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
|
||||
if !isValid {
|
||||
continue
|
||||
}
|
||||
@ -527,7 +528,7 @@ func getAllPeersFromGroups(account *Account, groups []string, peerID string, sou
|
||||
}
|
||||
|
||||
// validatePostureChecksOnPeer validates the posture checks on a peer
|
||||
func (a *Account) validatePostureChecksOnPeer(sourcePostureChecksID []string, peerID string) bool {
|
||||
func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePostureChecksID []string, peerID string) bool {
|
||||
peer, ok := a.Peers[peerID]
|
||||
if !ok && peer == nil {
|
||||
return false
|
||||
@ -540,9 +541,9 @@ func (a *Account) validatePostureChecksOnPeer(sourcePostureChecksID []string, pe
|
||||
}
|
||||
|
||||
for _, check := range postureChecks.GetChecks() {
|
||||
isValid, err := check.Check(*peer)
|
||||
isValid, err := check.Check(ctx, *peer)
|
||||
if err != nil {
|
||||
log.Debugf("an error occurred check %s: on peer: %s :%s", check.Name(), peer.ID, err.Error())
|
||||
log.WithContext(ctx).Debugf("an error occurred check %s: on peer: %s :%s", check.Name(), peer.ID, err.Error())
|
||||
}
|
||||
if !isValid {
|
||||
return false
|
||||
|
@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
@ -143,14 +144,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
|
||||
t.Run("check that all peers get map", func(t *testing.T) {
|
||||
for _, p := range account.Peers {
|
||||
peers, firewallRules := account.getPeerConnectionResources(p.ID, validatedPeers)
|
||||
peers, firewallRules := account.getPeerConnectionResources(context.Background(), p.ID, validatedPeers)
|
||||
assert.GreaterOrEqual(t, len(peers), 2, "minimum number peers should present")
|
||||
assert.GreaterOrEqual(t, len(firewallRules), 2, "minimum number of firewall rules should present")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("check first peer map details", func(t *testing.T) {
|
||||
peers, firewallRules := account.getPeerConnectionResources("peerB", validatedPeers)
|
||||
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", validatedPeers)
|
||||
assert.Len(t, peers, 7)
|
||||
assert.Contains(t, peers, account.Peers["peerA"])
|
||||
assert.Contains(t, peers, account.Peers["peerC"])
|
||||
@ -386,7 +387,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run("check first peer map", func(t *testing.T) {
|
||||
peers, firewallRules := account.getPeerConnectionResources("peerB", approvedPeers)
|
||||
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers)
|
||||
assert.Contains(t, peers, account.Peers["peerC"])
|
||||
|
||||
epectedFirewallRules := []*FirewallRule{
|
||||
@ -414,7 +415,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("check second peer map", func(t *testing.T) {
|
||||
peers, firewallRules := account.getPeerConnectionResources("peerC", approvedPeers)
|
||||
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers)
|
||||
assert.Contains(t, peers, account.Peers["peerB"])
|
||||
|
||||
epectedFirewallRules := []*FirewallRule{
|
||||
@ -444,7 +445,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
||||
account.Policies[1].Rules[0].Bidirectional = false
|
||||
|
||||
t.Run("check first peer map directional only", func(t *testing.T) {
|
||||
peers, firewallRules := account.getPeerConnectionResources("peerB", approvedPeers)
|
||||
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers)
|
||||
assert.Contains(t, peers, account.Peers["peerC"])
|
||||
|
||||
epectedFirewallRules := []*FirewallRule{
|
||||
@ -465,7 +466,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("check second peer map directional only", func(t *testing.T) {
|
||||
peers, firewallRules := account.getPeerConnectionResources("peerC", approvedPeers)
|
||||
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers)
|
||||
assert.Contains(t, peers, account.Peers["peerB"])
|
||||
|
||||
epectedFirewallRules := []*FirewallRule{
|
||||
@ -662,7 +663,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
t.Run("verify peer's network map with default group peer list", func(t *testing.T) {
|
||||
// peerB doesn't fulfill the NB posture check but is included in the destination group Swarm,
|
||||
// will establish a connection with all source peers satisfying the NB posture check.
|
||||
peers, firewallRules := account.getPeerConnectionResources("peerB", approvedPeers)
|
||||
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers)
|
||||
assert.Len(t, peers, 4)
|
||||
assert.Len(t, firewallRules, 4)
|
||||
assert.Contains(t, peers, account.Peers["peerA"])
|
||||
@ -672,7 +673,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
|
||||
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
|
||||
// We expect a single permissive firewall rule which all outgoing connections
|
||||
peers, firewallRules = account.getPeerConnectionResources("peerC", approvedPeers)
|
||||
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers)
|
||||
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
|
||||
assert.Len(t, firewallRules, 1)
|
||||
expectedFirewallRules := []*FirewallRule{
|
||||
@ -688,7 +689,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
|
||||
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
|
||||
// all source group peers satisfying the NB posture check should establish connection
|
||||
peers, firewallRules = account.getPeerConnectionResources("peerE", approvedPeers)
|
||||
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerE", approvedPeers)
|
||||
assert.Len(t, peers, 4)
|
||||
assert.Len(t, firewallRules, 4)
|
||||
assert.Contains(t, peers, account.Peers["peerA"])
|
||||
@ -698,7 +699,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
|
||||
// peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm,
|
||||
// all source group peers satisfying the NB posture check should establish connection
|
||||
peers, firewallRules = account.getPeerConnectionResources("peerI", approvedPeers)
|
||||
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerI", approvedPeers)
|
||||
assert.Len(t, peers, 4)
|
||||
assert.Len(t, firewallRules, 4)
|
||||
assert.Contains(t, peers, account.Peers["peerA"])
|
||||
@ -713,19 +714,19 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
|
||||
// peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's
|
||||
// no connection should be established to any peer of destination group
|
||||
peers, firewallRules := account.getPeerConnectionResources("peerB", approvedPeers)
|
||||
peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers)
|
||||
assert.Len(t, peers, 0)
|
||||
assert.Len(t, firewallRules, 0)
|
||||
|
||||
// peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's
|
||||
// no connection should be established to any peer of destination group
|
||||
peers, firewallRules = account.getPeerConnectionResources("peerI", approvedPeers)
|
||||
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerI", approvedPeers)
|
||||
assert.Len(t, peers, 0)
|
||||
assert.Len(t, firewallRules, 0)
|
||||
|
||||
// peerC satisfy the NB posture check, should establish connection to all destination group peer's
|
||||
// We expect a single permissive firewall rule which all outgoing connections
|
||||
peers, firewallRules = account.getPeerConnectionResources("peerC", approvedPeers)
|
||||
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers)
|
||||
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
|
||||
assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers))
|
||||
|
||||
@ -740,14 +741,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
|
||||
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
|
||||
// all source group peers satisfying the NB posture check should establish connection
|
||||
peers, firewallRules = account.getPeerConnectionResources("peerE", approvedPeers)
|
||||
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerE", approvedPeers)
|
||||
assert.Len(t, peers, 3)
|
||||
assert.Len(t, firewallRules, 3)
|
||||
assert.Contains(t, peers, account.Peers["peerA"])
|
||||
assert.Contains(t, peers, account.Peers["peerC"])
|
||||
assert.Contains(t, peers, account.Peers["peerD"])
|
||||
|
||||
peers, firewallRules = account.getPeerConnectionResources("peerA", approvedPeers)
|
||||
peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerA", approvedPeers)
|
||||
assert.Len(t, peers, 5)
|
||||
// assert peers from Group Swarm
|
||||
assert.Contains(t, peers, account.Peers["peerD"])
|
||||
|
@ -1,6 +1,7 @@
|
||||
package posture
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/netip"
|
||||
"regexp"
|
||||
@ -31,7 +32,7 @@ var (
|
||||
// Check represents an interface for performing a check on a peer.
|
||||
type Check interface {
|
||||
Name() string
|
||||
Check(peer nbpeer.Peer) (bool, error)
|
||||
Check(ctx context.Context, peer nbpeer.Peer) (bool, error)
|
||||
Validate() error
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
package posture
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
@ -25,7 +26,7 @@ type GeoLocationCheck struct {
|
||||
Action string
|
||||
}
|
||||
|
||||
func (g *GeoLocationCheck) Check(peer nbpeer.Peer) (bool, error) {
|
||||
func (g *GeoLocationCheck) Check(_ context.Context, peer nbpeer.Peer) (bool, error) {
|
||||
// deny if the peer location is not evaluated
|
||||
if peer.Location.CountryCode == "" && peer.Location.CityName == "" {
|
||||
return false, fmt.Errorf("peer's location is not set")
|
||||
|
@ -1,6 +1,7 @@
|
||||
package posture
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/peer"
|
||||
@ -226,7 +227,7 @@ func TestGeoLocationCheck_Check(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid, err := tt.check.Check(tt.input)
|
||||
isValid, err := tt.check.Check(context.Background(), tt.input)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
|
@ -1,6 +1,7 @@
|
||||
package posture
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/go-version"
|
||||
@ -15,7 +16,7 @@ type NBVersionCheck struct {
|
||||
|
||||
var _ Check = (*NBVersionCheck)(nil)
|
||||
|
||||
func (n *NBVersionCheck) Check(peer nbpeer.Peer) (bool, error) {
|
||||
func (n *NBVersionCheck) Check(ctx context.Context, peer nbpeer.Peer) (bool, error) {
|
||||
peerNBVersion, err := version.NewVersion(peer.Meta.WtVersion)
|
||||
if err != nil {
|
||||
return false, err
|
||||
@ -30,7 +31,7 @@ func (n *NBVersionCheck) Check(peer nbpeer.Peer) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
log.Debugf("peer %s NB version %s is older than minimum allowed version %s",
|
||||
log.WithContext(ctx).Debugf("peer %s NB version %s is older than minimum allowed version %s",
|
||||
peer.ID, peer.Meta.WtVersion, n.MinVersion)
|
||||
|
||||
return false, nil
|
||||
|
@ -1,6 +1,7 @@
|
||||
package posture
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/peer"
|
||||
@ -98,7 +99,7 @@ func TestNBVersionCheck_Check(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid, err := tt.check.Check(tt.input)
|
||||
isValid, err := tt.check.Check(context.Background(), tt.input)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
|
@ -1,6 +1,7 @@
|
||||
package posture
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
@ -16,7 +17,7 @@ type PeerNetworkRangeCheck struct {
|
||||
|
||||
var _ Check = (*PeerNetworkRangeCheck)(nil)
|
||||
|
||||
func (p *PeerNetworkRangeCheck) Check(peer nbpeer.Peer) (bool, error) {
|
||||
func (p *PeerNetworkRangeCheck) Check(ctx context.Context, peer nbpeer.Peer) (bool, error) {
|
||||
if len(peer.Meta.NetworkAddresses) == 0 {
|
||||
return false, fmt.Errorf("peer's does not contain peer network range addresses")
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package posture
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
@ -137,7 +138,7 @@ func TestPeerNetworkRangeCheck_Check(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
isValid, err := tt.check.Check(tt.peer)
|
||||
isValid, err := tt.check.Check(context.Background(), tt.peer)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
|
@ -1,6 +1,7 @@
|
||||
package posture
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@ -28,20 +29,20 @@ type OSVersionCheck struct {
|
||||
|
||||
var _ Check = (*OSVersionCheck)(nil)
|
||||
|
||||
func (c *OSVersionCheck) Check(peer nbpeer.Peer) (bool, error) {
|
||||
func (c *OSVersionCheck) Check(ctx context.Context, peer nbpeer.Peer) (bool, error) {
|
||||
peerGoOS := peer.Meta.GoOS
|
||||
switch peerGoOS {
|
||||
case "android":
|
||||
return checkMinVersion(peerGoOS, peer.Meta.OSVersion, c.Android)
|
||||
return checkMinVersion(ctx, peerGoOS, peer.Meta.OSVersion, c.Android)
|
||||
case "darwin":
|
||||
return checkMinVersion(peerGoOS, peer.Meta.OSVersion, c.Darwin)
|
||||
return checkMinVersion(ctx, peerGoOS, peer.Meta.OSVersion, c.Darwin)
|
||||
case "ios":
|
||||
return checkMinVersion(peerGoOS, peer.Meta.OSVersion, c.Ios)
|
||||
return checkMinVersion(ctx, peerGoOS, peer.Meta.OSVersion, c.Ios)
|
||||
case "linux":
|
||||
kernelVersion := strings.Split(peer.Meta.KernelVersion, "-")[0]
|
||||
return checkMinKernelVersion(peerGoOS, kernelVersion, c.Linux)
|
||||
return checkMinKernelVersion(ctx, peerGoOS, kernelVersion, c.Linux)
|
||||
case "windows":
|
||||
return checkMinKernelVersion(peerGoOS, peer.Meta.KernelVersion, c.Windows)
|
||||
return checkMinKernelVersion(ctx, peerGoOS, peer.Meta.KernelVersion, c.Windows)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
@ -79,9 +80,9 @@ func (c *OSVersionCheck) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkMinVersion(peerGoOS, peerVersion string, check *MinVersionCheck) (bool, error) {
|
||||
func checkMinVersion(ctx context.Context, peerGoOS, peerVersion string, check *MinVersionCheck) (bool, error) {
|
||||
if check == nil {
|
||||
log.Debugf("peer %s OS is not allowed in the check", peerGoOS)
|
||||
log.WithContext(ctx).Debugf("peer %s OS is not allowed in the check", peerGoOS)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
@ -99,14 +100,14 @@ func checkMinVersion(peerGoOS, peerVersion string, check *MinVersionCheck) (bool
|
||||
return true, nil
|
||||
}
|
||||
|
||||
log.Debugf("peer %s OS version %s is older than minimum allowed version %s", peerGoOS, peerVersion, check.MinVersion)
|
||||
log.WithContext(ctx).Debugf("peer %s OS version %s is older than minimum allowed version %s", peerGoOS, peerVersion, check.MinVersion)
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func checkMinKernelVersion(peerGoOS, peerVersion string, check *MinKernelVersionCheck) (bool, error) {
|
||||
func checkMinKernelVersion(ctx context.Context, peerGoOS, peerVersion string, check *MinKernelVersionCheck) (bool, error) {
|
||||
if check == nil {
|
||||
log.Debugf("peer %s OS is not allowed in the check", peerGoOS)
|
||||
log.WithContext(ctx).Debugf("peer %s OS is not allowed in the check", peerGoOS)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
@ -124,7 +125,7 @@ func checkMinKernelVersion(peerGoOS, peerVersion string, check *MinKernelVersion
|
||||
return true, nil
|
||||
}
|
||||
|
||||
log.Debugf("peer %s kernel version %s is older than minimum allowed version %s", peerGoOS, peerVersion, check.MinKernelVersion)
|
||||
log.WithContext(ctx).Debugf("peer %s kernel version %s is older than minimum allowed version %s", peerGoOS, peerVersion, check.MinKernelVersion)
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user