Merge branch 'main' into fix/remove-ids-from-policy-creation

This commit is contained in:
Pascal Fischer 2024-12-11 13:34:51 +01:00
commit 3a95966ccc
43 changed files with 721 additions and 449 deletions

View File

@ -46,6 +46,7 @@ type ConfigInput struct {
ManagementURL string ManagementURL string
AdminURL string AdminURL string
ConfigPath string ConfigPath string
StateFilePath string
PreSharedKey *string PreSharedKey *string
ServerSSHAllowed *bool ServerSSHAllowed *bool
NATExternalIPs []string NATExternalIPs []string
@ -105,10 +106,10 @@ type Config struct {
// DNSRouteInterval is the interval in which the DNS routes are updated // DNSRouteInterval is the interval in which the DNS routes are updated
DNSRouteInterval time.Duration DNSRouteInterval time.Duration
//Path to a certificate used for mTLS authentication // Path to a certificate used for mTLS authentication
ClientCertPath string ClientCertPath string
//Path to corresponding private key of ClientCertPath // Path to corresponding private key of ClientCertPath
ClientCertKeyPath string ClientCertKeyPath string
ClientCertKeyPair *tls.Certificate `json:"-"` ClientCertKeyPair *tls.Certificate `json:"-"`
@ -116,7 +117,7 @@ type Config struct {
// ReadConfig read config file and return with Config. If it is not exists create a new with default values // ReadConfig read config file and return with Config. If it is not exists create a new with default values
func ReadConfig(configPath string) (*Config, error) { func ReadConfig(configPath string) (*Config, error) {
if configFileIsExists(configPath) { if fileExists(configPath) {
err := util.EnforcePermission(configPath) err := util.EnforcePermission(configPath)
if err != nil { if err != nil {
log.Errorf("failed to enforce permission on config dir: %v", err) log.Errorf("failed to enforce permission on config dir: %v", err)
@ -149,7 +150,7 @@ func ReadConfig(configPath string) (*Config, error) {
// UpdateConfig update existing configuration according to input configuration and return with the configuration // UpdateConfig update existing configuration according to input configuration and return with the configuration
func UpdateConfig(input ConfigInput) (*Config, error) { func UpdateConfig(input ConfigInput) (*Config, error) {
if !configFileIsExists(input.ConfigPath) { if !fileExists(input.ConfigPath) {
return nil, status.Errorf(codes.NotFound, "config file doesn't exist") return nil, status.Errorf(codes.NotFound, "config file doesn't exist")
} }
@ -158,7 +159,7 @@ func UpdateConfig(input ConfigInput) (*Config, error) {
// UpdateOrCreateConfig reads existing config or generates a new one // UpdateOrCreateConfig reads existing config or generates a new one
func UpdateOrCreateConfig(input ConfigInput) (*Config, error) { func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
if !configFileIsExists(input.ConfigPath) { if !fileExists(input.ConfigPath) {
log.Infof("generating new config %s", input.ConfigPath) log.Infof("generating new config %s", input.ConfigPath)
cfg, err := createNewConfig(input) cfg, err := createNewConfig(input)
if err != nil { if err != nil {
@ -472,11 +473,19 @@ func isPreSharedKeyHidden(preSharedKey *string) bool {
return false return false
} }
func configFileIsExists(path string) bool { func fileExists(path string) bool {
_, err := os.Stat(path) _, err := os.Stat(path)
return !os.IsNotExist(err) return !os.IsNotExist(err)
} }
func createFile(path string) error {
file, err := os.Create(path)
if err != nil {
return err
}
return file.Close()
}
// UpdateOldManagementURL checks whether client can switch to the new Management URL with port 443 and the management domain. // UpdateOldManagementURL checks whether client can switch to the new Management URL with port 443 and the management domain.
// If it can switch, then it updates the config and returns a new one. Otherwise, it returns the provided config. // If it can switch, then it updates the config and returns a new one. Otherwise, it returns the provided config.
// The check is performed only for the NetBird's managed version. // The check is performed only for the NetBird's managed version.

View File

@ -91,6 +91,7 @@ func (c *ConnectClient) RunOniOS(
fileDescriptor int32, fileDescriptor int32,
networkChangeListener listener.NetworkChangeListener, networkChangeListener listener.NetworkChangeListener,
dnsManager dns.IosDnsManager, dnsManager dns.IosDnsManager,
stateFilePath string,
) error { ) error {
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension. // Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
debug.SetGCPercent(5) debug.SetGCPercent(5)
@ -99,6 +100,7 @@ func (c *ConnectClient) RunOniOS(
FileDescriptor: fileDescriptor, FileDescriptor: fileDescriptor,
NetworkChangeListener: networkChangeListener, NetworkChangeListener: networkChangeListener,
DnsManager: dnsManager, DnsManager: dnsManager,
StateFilePath: stateFilePath,
} }
return c.run(mobileDependency, nil, nil) return c.run(mobileDependency, nil, nil)
} }

View File

@ -39,6 +39,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
nbssh "github.com/netbirdio/netbird/client/ssh" nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
@ -62,6 +63,7 @@ import (
const ( const (
PeerConnectionTimeoutMax = 45000 // ms PeerConnectionTimeoutMax = 45000 // ms
PeerConnectionTimeoutMin = 30000 // ms PeerConnectionTimeoutMin = 30000 // ms
connInitLimit = 200
) )
var ErrResetConnection = fmt.Errorf("reset connection") var ErrResetConnection = fmt.Errorf("reset connection")
@ -177,6 +179,7 @@ type Engine struct {
// Network map persistence // Network map persistence
persistNetworkMap bool persistNetworkMap bool
latestNetworkMap *mgmProto.NetworkMap latestNetworkMap *mgmProto.NetworkMap
connSemaphore *semaphoregroup.SemaphoreGroup
} }
// Peer is an instance of the Connection Peer // Peer is an instance of the Connection Peer
@ -242,6 +245,18 @@ func NewEngineWithProbes(
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
probes: probes, probes: probes,
checks: checks, checks: checks,
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
}
if runtime.GOOS == "ios" {
if !fileExists(mobileDep.StateFilePath) {
err := createFile(mobileDep.StateFilePath)
if err != nil {
log.Errorf("failed to create state file: %v", err)
// we are not exiting as we can run without the state manager
}
}
engine.stateManager = statemanager.New(mobileDep.StateFilePath)
} }
if path := statemanager.GetDefaultStatePath(); path != "" { if path := statemanager.GetDefaultStatePath(); path != "" {
engine.stateManager = statemanager.New(path) engine.stateManager = statemanager.New(path)
@ -1040,7 +1055,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
}, },
} }
peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager, e.srWatcher) peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager, e.srWatcher, e.connSemaphore)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -19,4 +19,5 @@ type MobileDependency struct {
// iOS only // iOS only
DnsManager dns.IosDnsManager DnsManager dns.IosDnsManager
FileDescriptor int32 FileDescriptor int32
StateFilePath string
} }

View File

@ -23,6 +23,7 @@ import (
relayClient "github.com/netbirdio/netbird/relay/client" relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
) )
type ConnPriority int type ConnPriority int
@ -104,12 +105,13 @@ type Conn struct {
wgProxyICE wgproxy.Proxy wgProxyICE wgproxy.Proxy
wgProxyRelay wgproxy.Proxy wgProxyRelay wgproxy.Proxy
guard *guard.Guard guard *guard.Guard
semaphore *semaphoregroup.SemaphoreGroup
} }
// NewConn creates a new not opened Conn to the remote peer. // NewConn creates a new not opened Conn to the remote peer.
// To establish a connection run Conn.Open // To establish a connection run Conn.Open
func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher) (*Conn, error) { func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher, semaphore *semaphoregroup.SemaphoreGroup) (*Conn, error) {
allowedIP, _, err := net.ParseCIDR(config.WgConfig.AllowedIps) allowedIP, _, err := net.ParseCIDR(config.WgConfig.AllowedIps)
if err != nil { if err != nil {
log.Errorf("failed to parse allowedIPS: %v", err) log.Errorf("failed to parse allowedIPS: %v", err)
@ -130,6 +132,7 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
allowedIP: allowedIP, allowedIP: allowedIP,
statusRelay: NewAtomicConnStatus(), statusRelay: NewAtomicConnStatus(),
statusICE: NewAtomicConnStatus(), statusICE: NewAtomicConnStatus(),
semaphore: semaphore,
} }
rFns := WorkerRelayCallbacks{ rFns := WorkerRelayCallbacks{
@ -169,6 +172,7 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will // It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
// be used. // be used.
func (conn *Conn) Open() { func (conn *Conn) Open() {
conn.semaphore.Add(conn.ctx)
conn.log.Debugf("open connection to peer") conn.log.Debugf("open connection to peer")
conn.mu.Lock() conn.mu.Lock()
@ -191,6 +195,7 @@ func (conn *Conn) Open() {
} }
func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) { func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) {
defer conn.semaphore.Done(conn.ctx)
conn.waitInitialRandomSleepTime(ctx) conn.waitInitialRandomSleepTime(ctx)
err := conn.handshaker.sendOffer() err := conn.handshaker.sendOffer()

View File

@ -14,6 +14,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
) )
var connConf = ConnConfig{ var connConf = ConnConfig{
@ -46,7 +47,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {
func TestConn_GetKey(t *testing.T) { func TestConn_GetKey(t *testing.T) {
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, nil, nil, nil, nil, swWatcher) conn, err := NewConn(context.Background(), connConf, nil, nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
if err != nil { if err != nil {
return return
} }
@ -58,7 +59,7 @@ func TestConn_GetKey(t *testing.T) {
func TestConn_OnRemoteOffer(t *testing.T) { func TestConn_OnRemoteOffer(t *testing.T) {
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher) conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
if err != nil { if err != nil {
return return
} }
@ -92,7 +93,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
func TestConn_OnRemoteAnswer(t *testing.T) { func TestConn_OnRemoteAnswer(t *testing.T) {
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher) conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
if err != nil { if err != nil {
return return
} }
@ -125,7 +126,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
} }
func TestConn_Status(t *testing.T) { func TestConn_Status(t *testing.T) {
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher) conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
if err != nil { if err != nil {
return return
} }

View File

@ -273,3 +273,88 @@ func TestRouteSelector_FilterSelected(t *testing.T) {
"route2|192.168.0.0/16": {}, "route2|192.168.0.0/16": {},
}, filtered) }, filtered)
} }
func TestRouteSelector_NewRoutesBehavior(t *testing.T) {
initialRoutes := []route.NetID{"route1", "route2", "route3"}
newRoutes := []route.NetID{"route1", "route2", "route3", "route4", "route5"}
tests := []struct {
name string
initialState func(rs *routeselector.RouteSelector) error // Setup initial state
wantNewSelected []route.NetID // Expected selected routes after new routes appear
}{
{
name: "New routes with initial selectAll state",
initialState: func(rs *routeselector.RouteSelector) error {
rs.SelectAllRoutes()
return nil
},
// When selectAll is true, all routes including new ones should be selected
wantNewSelected: []route.NetID{"route1", "route2", "route3", "route4", "route5"},
},
{
name: "New routes after specific selection",
initialState: func(rs *routeselector.RouteSelector) error {
return rs.SelectRoutes([]route.NetID{"route1", "route2"}, false, initialRoutes)
},
// When specific routes were selected, new routes should remain unselected
wantNewSelected: []route.NetID{"route1", "route2"},
},
{
name: "New routes after deselect all",
initialState: func(rs *routeselector.RouteSelector) error {
rs.DeselectAllRoutes()
return nil
},
// After deselect all, new routes should remain unselected
wantNewSelected: []route.NetID{},
},
{
name: "New routes after deselecting specific routes",
initialState: func(rs *routeselector.RouteSelector) error {
rs.SelectAllRoutes()
return rs.DeselectRoutes([]route.NetID{"route1"}, initialRoutes)
},
// After deselecting specific routes, new routes should remain unselected
wantNewSelected: []route.NetID{"route2", "route3"},
},
{
name: "New routes after selecting with append",
initialState: func(rs *routeselector.RouteSelector) error {
return rs.SelectRoutes([]route.NetID{"route1"}, true, initialRoutes)
},
// When routes were appended, new routes should remain unselected
wantNewSelected: []route.NetID{"route1"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rs := routeselector.NewRouteSelector()
// Setup initial state
err := tt.initialState(rs)
require.NoError(t, err)
// Verify selection state with new routes
for _, id := range newRoutes {
assert.Equal(t, rs.IsSelected(id), slices.Contains(tt.wantNewSelected, id),
"Route %s selection state incorrect", id)
}
// Additional verification using FilterSelected
routes := route.HAMap{
"route1|10.0.0.0/8": {},
"route2|192.168.0.0/16": {},
"route3|172.16.0.0/12": {},
"route4|10.10.0.0/16": {},
"route5|192.168.1.0/24": {},
}
filtered := rs.FilterSelected(routes)
expectedLen := len(tt.wantNewSelected)
assert.Equal(t, expectedLen, len(filtered),
"FilterSelected returned wrong number of routes, got %d want %d", len(filtered), expectedLen)
})
}
}

View File

@ -59,6 +59,7 @@ func init() {
// Client struct manage the life circle of background service // Client struct manage the life circle of background service
type Client struct { type Client struct {
cfgFile string cfgFile string
stateFile string
recorder *peer.Status recorder *peer.Status
ctxCancel context.CancelFunc ctxCancel context.CancelFunc
ctxCancelLock *sync.Mutex ctxCancelLock *sync.Mutex
@ -73,9 +74,10 @@ type Client struct {
} }
// NewClient instantiate a new Client // NewClient instantiate a new Client
func NewClient(cfgFile, deviceName string, osVersion string, osName string, networkChangeListener NetworkChangeListener, dnsManager DnsManager) *Client { func NewClient(cfgFile, stateFile, deviceName string, osVersion string, osName string, networkChangeListener NetworkChangeListener, dnsManager DnsManager) *Client {
return &Client{ return &Client{
cfgFile: cfgFile, cfgFile: cfgFile,
stateFile: stateFile,
deviceName: deviceName, deviceName: deviceName,
osName: osName, osName: osName,
osVersion: osVersion, osVersion: osVersion,
@ -91,7 +93,8 @@ func (c *Client) Run(fd int32, interfaceName string) error {
log.Infof("Starting NetBird client") log.Infof("Starting NetBird client")
log.Debugf("Tunnel uses interface: %s", interfaceName) log.Debugf("Tunnel uses interface: %s", interfaceName)
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{ cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
ConfigPath: c.cfgFile, ConfigPath: c.cfgFile,
StateFilePath: c.stateFile,
}) })
if err != nil { if err != nil {
return err return err
@ -124,7 +127,7 @@ func (c *Client) Run(fd int32, interfaceName string) error {
cfg.WgIface = interfaceName cfg.WgIface = interfaceName
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager) return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile)
} }
// Stop the internal client and free the resources // Stop the internal client and free the resources

View File

@ -10,9 +10,10 @@ type Preferences struct {
} }
// NewPreferences create new Preferences instance // NewPreferences create new Preferences instance
func NewPreferences(configPath string) *Preferences { func NewPreferences(configPath string, stateFilePath string) *Preferences {
ci := internal.ConfigInput{ ci := internal.ConfigInput{
ConfigPath: configPath, ConfigPath: configPath,
StateFilePath: stateFilePath,
} }
return &Preferences{ci} return &Preferences{ci}
} }

View File

@ -9,7 +9,8 @@ import (
func TestPreferences_DefaultValues(t *testing.T) { func TestPreferences_DefaultValues(t *testing.T) {
cfgFile := filepath.Join(t.TempDir(), "netbird.json") cfgFile := filepath.Join(t.TempDir(), "netbird.json")
p := NewPreferences(cfgFile) stateFile := filepath.Join(t.TempDir(), "state.json")
p := NewPreferences(cfgFile, stateFile)
defaultVar, err := p.GetAdminURL() defaultVar, err := p.GetAdminURL()
if err != nil { if err != nil {
t.Fatalf("failed to read default value: %s", err) t.Fatalf("failed to read default value: %s", err)
@ -42,7 +43,8 @@ func TestPreferences_DefaultValues(t *testing.T) {
func TestPreferences_ReadUncommitedValues(t *testing.T) { func TestPreferences_ReadUncommitedValues(t *testing.T) {
exampleString := "exampleString" exampleString := "exampleString"
cfgFile := filepath.Join(t.TempDir(), "netbird.json") cfgFile := filepath.Join(t.TempDir(), "netbird.json")
p := NewPreferences(cfgFile) stateFile := filepath.Join(t.TempDir(), "state.json")
p := NewPreferences(cfgFile, stateFile)
p.SetAdminURL(exampleString) p.SetAdminURL(exampleString)
resp, err := p.GetAdminURL() resp, err := p.GetAdminURL()
@ -79,7 +81,8 @@ func TestPreferences_Commit(t *testing.T) {
exampleURL := "https://myurl.com:443" exampleURL := "https://myurl.com:443"
examplePresharedKey := "topsecret" examplePresharedKey := "topsecret"
cfgFile := filepath.Join(t.TempDir(), "netbird.json") cfgFile := filepath.Join(t.TempDir(), "netbird.json")
p := NewPreferences(cfgFile) stateFile := filepath.Join(t.TempDir(), "state.json")
p := NewPreferences(cfgFile, stateFile)
p.SetAdminURL(exampleURL) p.SetAdminURL(exampleURL)
p.SetManagementURL(exampleURL) p.SetManagementURL(exampleURL)
@ -90,7 +93,7 @@ func TestPreferences_Commit(t *testing.T) {
t.Fatalf("failed to save changes: %s", err) t.Fatalf("failed to save changes: %s", err)
} }
p = NewPreferences(cfgFile) p = NewPreferences(cfgFile, stateFile)
resp, err := p.GetAdminURL() resp, err := p.GetAdminURL()
if err != nil { if err != nil {
t.Fatalf("failed to read admin url: %s", err) t.Fatalf("failed to read admin url: %s", err)

View File

@ -42,6 +42,7 @@ import (
nbContext "github.com/netbirdio/netbird/management/server/context" nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/geolocation"
httpapi "github.com/netbirdio/netbird/management/server/http" httpapi "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/metrics" "github.com/netbirdio/netbird/management/server/metrics"
@ -257,7 +258,7 @@ var (
return fmt.Errorf("failed creating JWT validator: %v", err) return fmt.Errorf("failed creating JWT validator: %v", err)
} }
httpAPIAuthCfg := httpapi.AuthCfg{ httpAPIAuthCfg := configs.AuthCfg{
Issuer: config.HttpConfig.AuthIssuer, Issuer: config.HttpConfig.AuthIssuer,
Audience: config.HttpConfig.AuthAudience, Audience: config.HttpConfig.AuthAudience,
UserIDClaim: config.HttpConfig.AuthUserIDClaim, UserIDClaim: config.HttpConfig.AuthUserIDClaim,

View File

@ -0,0 +1,9 @@
package configs
// AuthCfg contains parameters for authentication middleware
type AuthCfg struct {
Issuer string
Audience string
UserIDClaim string
KeysLocation string
}

View File

@ -12,6 +12,16 @@ import (
s "github.com/netbirdio/netbird/management/server" s "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/handlers/accounts"
"github.com/netbirdio/netbird/management/server/http/handlers/dns"
"github.com/netbirdio/netbird/management/server/http/handlers/events"
"github.com/netbirdio/netbird/management/server/http/handlers/groups"
"github.com/netbirdio/netbird/management/server/http/handlers/peers"
"github.com/netbirdio/netbird/management/server/http/handlers/policies"
"github.com/netbirdio/netbird/management/server/http/handlers/routes"
"github.com/netbirdio/netbird/management/server/http/handlers/setup_keys"
"github.com/netbirdio/netbird/management/server/http/handlers/users"
"github.com/netbirdio/netbird/management/server/http/middleware" "github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/integrated_validator" "github.com/netbirdio/netbird/management/server/integrated_validator"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
@ -20,27 +30,15 @@ import (
const apiPrefix = "/api" const apiPrefix = "/api"
// AuthCfg contains parameters for authentication middleware
type AuthCfg struct {
Issuer string
Audience string
UserIDClaim string
KeysLocation string
}
type apiHandler struct { type apiHandler struct {
Router *mux.Router Router *mux.Router
AccountManager s.AccountManager AccountManager s.AccountManager
geolocationManager *geolocation.Geolocation geolocationManager *geolocation.Geolocation
AuthCfg AuthCfg AuthCfg configs.AuthCfg
}
// EmptyObject is an empty struct used to return empty JSON object
type emptyObject struct {
} }
// APIHandler creates the Management service HTTP API handler registering all the available endpoints. // APIHandler creates the Management service HTTP API handler registering all the available endpoints.
func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) { func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg configs.AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) {
claimsExtractor := jwtclaims.NewClaimsExtractor( claimsExtractor := jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience), jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
@ -86,122 +84,15 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
return nil, fmt.Errorf("register integrations endpoints: %w", err) return nil, fmt.Errorf("register integrations endpoints: %w", err)
} }
api.addAccountsEndpoint() accounts.AddEndpoints(api.AccountManager, authCfg, router)
api.addPeersEndpoint() peers.AddEndpoints(api.AccountManager, authCfg, router)
api.addUsersEndpoint() users.AddEndpoints(api.AccountManager, authCfg, router)
api.addUsersTokensEndpoint() setup_keys.AddEndpoints(api.AccountManager, authCfg, router)
api.addSetupKeysEndpoint() policies.AddEndpoints(api.AccountManager, api.geolocationManager, authCfg, router)
api.addPoliciesEndpoint() groups.AddEndpoints(api.AccountManager, authCfg, router)
api.addGroupsEndpoint() routes.AddEndpoints(api.AccountManager, authCfg, router)
api.addRoutesEndpoint() dns.AddEndpoints(api.AccountManager, authCfg, router)
api.addDNSNameserversEndpoint() events.AddEndpoints(api.AccountManager, authCfg, router)
api.addDNSSettingEndpoint()
api.addEventsEndpoint()
api.addPostureCheckEndpoint()
api.addLocationsEndpoint()
return rootRouter, nil return rootRouter, nil
} }
func (apiHandler *apiHandler) addAccountsEndpoint() {
accountsHandler := NewAccountsHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/accounts/{accountId}", accountsHandler.UpdateAccount).Methods("PUT", "OPTIONS")
apiHandler.Router.HandleFunc("/accounts/{accountId}", accountsHandler.DeleteAccount).Methods("DELETE", "OPTIONS")
apiHandler.Router.HandleFunc("/accounts", accountsHandler.GetAllAccounts).Methods("GET", "OPTIONS")
}
func (apiHandler *apiHandler) addPeersEndpoint() {
peersHandler := NewPeersHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
Methods("GET", "PUT", "DELETE", "OPTIONS")
apiHandler.Router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS")
}
func (apiHandler *apiHandler) addUsersEndpoint() {
userHandler := NewUsersHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/users", userHandler.GetAllUsers).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/users/{userId}", userHandler.UpdateUser).Methods("PUT", "OPTIONS")
apiHandler.Router.HandleFunc("/users/{userId}", userHandler.DeleteUser).Methods("DELETE", "OPTIONS")
apiHandler.Router.HandleFunc("/users", userHandler.CreateUser).Methods("POST", "OPTIONS")
apiHandler.Router.HandleFunc("/users/{userId}/invite", userHandler.InviteUser).Methods("POST", "OPTIONS")
}
func (apiHandler *apiHandler) addUsersTokensEndpoint() {
tokenHandler := NewPATsHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/users/{userId}/tokens", tokenHandler.GetAllTokens).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/users/{userId}/tokens", tokenHandler.CreateToken).Methods("POST", "OPTIONS")
apiHandler.Router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.GetToken).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.DeleteToken).Methods("DELETE", "OPTIONS")
}
func (apiHandler *apiHandler) addSetupKeysEndpoint() {
keysHandler := NewSetupKeysHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/setup-keys", keysHandler.GetAllSetupKeys).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/setup-keys", keysHandler.CreateSetupKey).Methods("POST", "OPTIONS")
apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.GetSetupKey).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.UpdateSetupKey).Methods("PUT", "OPTIONS")
apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.DeleteSetupKey).Methods("DELETE", "OPTIONS")
}
func (apiHandler *apiHandler) addPoliciesEndpoint() {
policiesHandler := NewPoliciesHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/policies", policiesHandler.GetAllPolicies).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/policies", policiesHandler.CreatePolicy).Methods("POST", "OPTIONS")
apiHandler.Router.HandleFunc("/policies/{policyId}", policiesHandler.UpdatePolicy).Methods("PUT", "OPTIONS")
apiHandler.Router.HandleFunc("/policies/{policyId}", policiesHandler.GetPolicy).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/policies/{policyId}", policiesHandler.DeletePolicy).Methods("DELETE", "OPTIONS")
}
func (apiHandler *apiHandler) addGroupsEndpoint() {
groupsHandler := NewGroupsHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/groups", groupsHandler.GetAllGroups).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/groups", groupsHandler.CreateGroup).Methods("POST", "OPTIONS")
apiHandler.Router.HandleFunc("/groups/{groupId}", groupsHandler.UpdateGroup).Methods("PUT", "OPTIONS")
apiHandler.Router.HandleFunc("/groups/{groupId}", groupsHandler.GetGroup).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/groups/{groupId}", groupsHandler.DeleteGroup).Methods("DELETE", "OPTIONS")
}
func (apiHandler *apiHandler) addRoutesEndpoint() {
routesHandler := NewRoutesHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/routes", routesHandler.GetAllRoutes).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/routes", routesHandler.CreateRoute).Methods("POST", "OPTIONS")
apiHandler.Router.HandleFunc("/routes/{routeId}", routesHandler.UpdateRoute).Methods("PUT", "OPTIONS")
apiHandler.Router.HandleFunc("/routes/{routeId}", routesHandler.GetRoute).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/routes/{routeId}", routesHandler.DeleteRoute).Methods("DELETE", "OPTIONS")
}
func (apiHandler *apiHandler) addDNSNameserversEndpoint() {
nameserversHandler := NewNameserversHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/dns/nameservers", nameserversHandler.GetAllNameservers).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/dns/nameservers", nameserversHandler.CreateNameserverGroup).Methods("POST", "OPTIONS")
apiHandler.Router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.UpdateNameserverGroup).Methods("PUT", "OPTIONS")
apiHandler.Router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.GetNameserverGroup).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.DeleteNameserverGroup).Methods("DELETE", "OPTIONS")
}
func (apiHandler *apiHandler) addDNSSettingEndpoint() {
dnsSettingsHandler := NewDNSSettingsHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/dns/settings", dnsSettingsHandler.GetDNSSettings).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/dns/settings", dnsSettingsHandler.UpdateDNSSettings).Methods("PUT", "OPTIONS")
}
func (apiHandler *apiHandler) addEventsEndpoint() {
eventsHandler := NewEventsHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/events", eventsHandler.GetAllEvents).Methods("GET", "OPTIONS")
}
func (apiHandler *apiHandler) addPostureCheckEndpoint() {
postureCheckHandler := NewPostureChecksHandler(apiHandler.AccountManager, apiHandler.geolocationManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/posture-checks", postureCheckHandler.GetAllPostureChecks).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/posture-checks", postureCheckHandler.CreatePostureCheck).Methods("POST", "OPTIONS")
apiHandler.Router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.UpdatePostureCheck).Methods("PUT", "OPTIONS")
apiHandler.Router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.GetPostureCheck).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.DeletePostureCheck).Methods("DELETE", "OPTIONS")
}
func (apiHandler *apiHandler) addLocationsEndpoint() {
locationHandler := NewGeolocationsHandlerHandler(apiHandler.AccountManager, apiHandler.geolocationManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/locations/countries", locationHandler.GetAllCountries).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/locations/countries/{country}/cities", locationHandler.GetCitiesByCountry).Methods("GET", "OPTIONS")
}

View File

@ -1,4 +1,4 @@
package http package accounts
import ( import (
"encoding/json" "encoding/json"
@ -10,20 +10,28 @@ import (
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
) )
// AccountsHandler is a handler that handles the server.Account HTTP endpoints // handler is a handler that handles the server.Account HTTP endpoints
type AccountsHandler struct { type handler struct {
accountManager server.AccountManager accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor claimsExtractor *jwtclaims.ClaimsExtractor
} }
// NewAccountsHandler creates a new AccountsHandler HTTP handler func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
func NewAccountsHandler(accountManager server.AccountManager, authCfg AuthCfg) *AccountsHandler { accountsHandler := newHandler(accountManager, authCfg)
return &AccountsHandler{ router.HandleFunc("/accounts/{accountId}", accountsHandler.updateAccount).Methods("PUT", "OPTIONS")
router.HandleFunc("/accounts/{accountId}", accountsHandler.deleteAccount).Methods("DELETE", "OPTIONS")
router.HandleFunc("/accounts", accountsHandler.getAllAccounts).Methods("GET", "OPTIONS")
}
// newHandler creates a new handler HTTP handler
func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler {
return &handler{
accountManager: accountManager, accountManager: accountManager,
claimsExtractor: jwtclaims.NewClaimsExtractor( claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience), jwtclaims.WithAudience(authCfg.Audience),
@ -32,8 +40,8 @@ func NewAccountsHandler(accountManager server.AccountManager, authCfg AuthCfg) *
} }
} }
// GetAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account. // getAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account.
func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) { func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@ -51,8 +59,8 @@ func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request)
util.WriteJSONObject(r.Context(), w, []*api.Account{resp}) util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
} }
// UpdateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings) // updateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings)
func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) { func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
_, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) _, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@ -111,8 +119,8 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
util.WriteJSONObject(r.Context(), w, &resp) util.WriteJSONObject(r.Context(), w, &resp)
} }
// DeleteAccount is a HTTP DELETE handler to delete an account // deleteAccount is a HTTP DELETE handler to delete an account
func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) { func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
vars := mux.Vars(r) vars := mux.Vars(r)
targetAccountID := vars["accountId"] targetAccountID := vars["accountId"]
@ -127,7 +135,7 @@ func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request)
return return
} }
util.WriteJSONObject(r.Context(), w, emptyObject{}) util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
} }
func toAccountResponse(accountID string, settings *server.Settings) *api.Account { func toAccountResponse(accountID string, settings *server.Settings) *api.Account {

View File

@ -1,4 +1,4 @@
package http package accounts
import ( import (
"bytes" "bytes"
@ -20,8 +20,8 @@ import (
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
) )
func initAccountsTestData(account *server.Account, admin *server.User) *AccountsHandler { func initAccountsTestData(account *server.Account, admin *server.User) *handler {
return &AccountsHandler{ return &handler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetAccountIDFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { GetAccountIDFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return account.Id, admin.Id, nil return account.Id, admin.Id, nil
@ -89,7 +89,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
requestBody io.Reader requestBody io.Reader
}{ }{
{ {
name: "GetAllAccounts OK", name: "getAllAccounts OK",
expectedBody: true, expectedBody: true,
requestType: http.MethodGet, requestType: http.MethodGet,
requestPath: "/api/accounts", requestPath: "/api/accounts",
@ -189,8 +189,8 @@ func TestAccounts_AccountsHandler(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/accounts", handler.GetAllAccounts).Methods("GET") router.HandleFunc("/api/accounts", handler.getAllAccounts).Methods("GET")
router.HandleFunc("/api/accounts/{accountId}", handler.UpdateAccount).Methods("PUT") router.HandleFunc("/api/accounts/{accountId}", handler.updateAccount).Methods("PUT")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()

View File

@ -1,26 +1,39 @@
package http package dns
import ( import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
) )
// DNSSettingsHandler is a handler that returns the DNS settings of the account // dnsSettingsHandler is a handler that returns the DNS settings of the account
type DNSSettingsHandler struct { type dnsSettingsHandler struct {
accountManager server.AccountManager accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor claimsExtractor *jwtclaims.ClaimsExtractor
} }
// NewDNSSettingsHandler returns a new instance of DNSSettingsHandler handler func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
func NewDNSSettingsHandler(accountManager server.AccountManager, authCfg AuthCfg) *DNSSettingsHandler { addDNSSettingEndpoint(accountManager, authCfg, router)
return &DNSSettingsHandler{ addDNSNameserversEndpoint(accountManager, authCfg, router)
}
func addDNSSettingEndpoint(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
dnsSettingsHandler := newDNSSettingsHandler(accountManager, authCfg)
router.HandleFunc("/dns/settings", dnsSettingsHandler.getDNSSettings).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/settings", dnsSettingsHandler.updateDNSSettings).Methods("PUT", "OPTIONS")
}
// newDNSSettingsHandler returns a new instance of dnsSettingsHandler handler
func newDNSSettingsHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *dnsSettingsHandler {
return &dnsSettingsHandler{
accountManager: accountManager, accountManager: accountManager,
claimsExtractor: jwtclaims.NewClaimsExtractor( claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience), jwtclaims.WithAudience(authCfg.Audience),
@ -29,8 +42,8 @@ func NewDNSSettingsHandler(accountManager server.AccountManager, authCfg AuthCfg
} }
} }
// GetDNSSettings returns the DNS settings for the account // getDNSSettings returns the DNS settings for the account
func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Request) { func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@ -52,8 +65,8 @@ func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Reque
util.WriteJSONObject(r.Context(), w, apiDNSSettings) util.WriteJSONObject(r.Context(), w, apiDNSSettings)
} }
// UpdateDNSSettings handles update to DNS settings of an account // updateDNSSettings handles update to DNS settings of an account
func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) { func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {

View File

@ -1,4 +1,4 @@
package http package dns
import ( import (
"bytes" "bytes"
@ -40,8 +40,8 @@ var testingDNSSettingsAccount = &server.Account{
DNSSettings: baseExistingDNSSettings, DNSSettings: baseExistingDNSSettings,
} }
func initDNSSettingsTestData() *DNSSettingsHandler { func initDNSSettingsTestData() *dnsSettingsHandler {
return &DNSSettingsHandler{ return &dnsSettingsHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetDNSSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.DNSSettings, error) { GetDNSSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.DNSSettings, error) {
return &testingDNSSettingsAccount.DNSSettings, nil return &testingDNSSettingsAccount.DNSSettings, nil
@ -120,8 +120,8 @@ func TestDNSSettingsHandlers(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/dns/settings", p.GetDNSSettings).Methods("GET") router.HandleFunc("/api/dns/settings", p.getDNSSettings).Methods("GET")
router.HandleFunc("/api/dns/settings", p.UpdateDNSSettings).Methods("PUT") router.HandleFunc("/api/dns/settings", p.updateDNSSettings).Methods("PUT")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()

View File

@ -1,4 +1,4 @@
package http package dns
import ( import (
"encoding/json" "encoding/json"
@ -11,20 +11,30 @@ import (
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
) )
// NameserversHandler is the nameserver group handler of the account // nameserversHandler is the nameserver group handler of the account
type NameserversHandler struct { type nameserversHandler struct {
accountManager server.AccountManager accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor claimsExtractor *jwtclaims.ClaimsExtractor
} }
// NewNameserversHandler returns a new instance of NameserversHandler handler func addDNSNameserversEndpoint(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
func NewNameserversHandler(accountManager server.AccountManager, authCfg AuthCfg) *NameserversHandler { nameserversHandler := newNameserversHandler(accountManager, authCfg)
return &NameserversHandler{ router.HandleFunc("/dns/nameservers", nameserversHandler.getAllNameservers).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/nameservers", nameserversHandler.createNameserverGroup).Methods("POST", "OPTIONS")
router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.updateNameserverGroup).Methods("PUT", "OPTIONS")
router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.getNameserverGroup).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.deleteNameserverGroup).Methods("DELETE", "OPTIONS")
}
// newNameserversHandler returns a new instance of nameserversHandler handler
func newNameserversHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *nameserversHandler {
return &nameserversHandler{
accountManager: accountManager, accountManager: accountManager,
claimsExtractor: jwtclaims.NewClaimsExtractor( claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience), jwtclaims.WithAudience(authCfg.Audience),
@ -33,8 +43,8 @@ func NewNameserversHandler(accountManager server.AccountManager, authCfg AuthCfg
} }
} }
// GetAllNameservers returns the list of nameserver groups for the account // getAllNameservers returns the list of nameserver groups for the account
func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Request) { func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@ -57,8 +67,8 @@ func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Re
util.WriteJSONObject(r.Context(), w, apiNameservers) util.WriteJSONObject(r.Context(), w, apiNameservers)
} }
// CreateNameserverGroup handles nameserver group creation request // createNameserverGroup handles nameserver group creation request
func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *http.Request) { func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@ -90,8 +100,8 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt
util.WriteJSONObject(r.Context(), w, &resp) util.WriteJSONObject(r.Context(), w, &resp)
} }
// UpdateNameserverGroup handles update to a nameserver group identified by a given ID // updateNameserverGroup handles update to a nameserver group identified by a given ID
func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *http.Request) { func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@ -141,8 +151,8 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt
util.WriteJSONObject(r.Context(), w, &resp) util.WriteJSONObject(r.Context(), w, &resp)
} }
// DeleteNameserverGroup handles nameserver group deletion request // deleteNameserverGroup handles nameserver group deletion request
func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *http.Request) { func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@ -162,11 +172,11 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt
return return
} }
util.WriteJSONObject(r.Context(), w, emptyObject{}) util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
} }
// GetNameserverGroup handles a nameserver group Get request identified by ID // getNameserverGroup handles a nameserver group Get request identified by ID
func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.Request) { func (h *nameserversHandler) getNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {

View File

@ -1,4 +1,4 @@
package http package dns
import ( import (
"bytes" "bytes"
@ -50,8 +50,8 @@ var baseExistingNSGroup = &nbdns.NameServerGroup{
Enabled: true, Enabled: true,
} }
func initNameserversTestData() *NameserversHandler { func initNameserversTestData() *nameserversHandler {
return &NameserversHandler{ return &nameserversHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetNameServerGroupFunc: func(_ context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { GetNameServerGroupFunc: func(_ context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
if nsGroupID == existingNSGroupID { if nsGroupID == existingNSGroupID {
@ -206,10 +206,10 @@ func TestNameserversHandlers(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.GetNameserverGroup).Methods("GET") router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.getNameserverGroup).Methods("GET")
router.HandleFunc("/api/dns/nameservers", p.CreateNameserverGroup).Methods("POST") router.HandleFunc("/api/dns/nameservers", p.createNameserverGroup).Methods("POST")
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.DeleteNameserverGroup).Methods("DELETE") router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.deleteNameserverGroup).Methods("DELETE")
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.UpdateNameserverGroup).Methods("PUT") router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.updateNameserverGroup).Methods("PUT")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()

View File

@ -1,28 +1,35 @@
package http package events
import ( import (
"context" "context"
"fmt" "fmt"
"net/http" "net/http"
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
) )
// EventsHandler HTTP handler // handler HTTP handler
type EventsHandler struct { type handler struct {
accountManager server.AccountManager accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor claimsExtractor *jwtclaims.ClaimsExtractor
} }
// NewEventsHandler creates a new EventsHandler HTTP handler func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
func NewEventsHandler(accountManager server.AccountManager, authCfg AuthCfg) *EventsHandler { eventsHandler := newHandler(accountManager, authCfg)
return &EventsHandler{ router.HandleFunc("/events", eventsHandler.getAllEvents).Methods("GET", "OPTIONS")
}
// newHandler creates a new events handler
func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler {
return &handler{
accountManager: accountManager, accountManager: accountManager,
claimsExtractor: jwtclaims.NewClaimsExtractor( claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience), jwtclaims.WithAudience(authCfg.Audience),
@ -31,8 +38,8 @@ func NewEventsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ev
} }
} }
// GetAllEvents list of the given account // getAllEvents list of the given account
func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) { func (h *handler) getAllEvents(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@ -60,7 +67,7 @@ func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, events) util.WriteJSONObject(r.Context(), w, events)
} }
func (h *EventsHandler) fillEventsWithUserInfo(ctx context.Context, events []*api.Event, accountId, userId string) error { func (h *handler) fillEventsWithUserInfo(ctx context.Context, events []*api.Event, accountId, userId string) error {
// build email, name maps based on users // build email, name maps based on users
userInfos, err := h.accountManager.GetUsersFromAccount(ctx, accountId, userId) userInfos, err := h.accountManager.GetUsersFromAccount(ctx, accountId, userId)
if err != nil { if err != nil {

View File

@ -1,4 +1,4 @@
package http package events
import ( import (
"context" "context"
@ -20,8 +20,8 @@ import (
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
) )
func initEventsTestData(account string, events ...*activity.Event) *EventsHandler { func initEventsTestData(account string, events ...*activity.Event) *handler {
return &EventsHandler{ return &handler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetEventsFunc: func(_ context.Context, accountID, userID string) ([]*activity.Event, error) { GetEventsFunc: func(_ context.Context, accountID, userID string) ([]*activity.Event, error) {
if accountID == account { if accountID == account {
@ -183,7 +183,7 @@ func TestEvents_GetEvents(t *testing.T) {
requestBody io.Reader requestBody io.Reader
}{ }{
{ {
name: "GetAllEvents OK", name: "getAllEvents OK",
expectedBody: true, expectedBody: true,
requestType: http.MethodGet, requestType: http.MethodGet,
requestPath: "/api/events/", requestPath: "/api/events/",
@ -201,7 +201,7 @@ func TestEvents_GetEvents(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/events/", handler.GetAllEvents).Methods("GET") router.HandleFunc("/api/events/", handler.getAllEvents).Methods("GET")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()

View File

@ -1,13 +1,15 @@
package http package groups
import ( import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"github.com/gorilla/mux" "github.com/gorilla/mux"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/http/configs"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
nbgroup "github.com/netbirdio/netbird/management/server/group" nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
@ -16,15 +18,24 @@ import (
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
) )
// GroupsHandler is a handler that returns groups of the account // handler is a handler that returns groups of the account
type GroupsHandler struct { type handler struct {
accountManager server.AccountManager accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor claimsExtractor *jwtclaims.ClaimsExtractor
} }
// NewGroupsHandler creates a new GroupsHandler HTTP handler func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
func NewGroupsHandler(accountManager server.AccountManager, authCfg AuthCfg) *GroupsHandler { groupsHandler := newHandler(accountManager, authCfg)
return &GroupsHandler{ router.HandleFunc("/groups", groupsHandler.getAllGroups).Methods("GET", "OPTIONS")
router.HandleFunc("/groups", groupsHandler.createGroup).Methods("POST", "OPTIONS")
router.HandleFunc("/groups/{groupId}", groupsHandler.updateGroup).Methods("PUT", "OPTIONS")
router.HandleFunc("/groups/{groupId}", groupsHandler.getGroup).Methods("GET", "OPTIONS")
router.HandleFunc("/groups/{groupId}", groupsHandler.deleteGroup).Methods("DELETE", "OPTIONS")
}
// newHandler creates a new groups handler
func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler {
return &handler{
accountManager: accountManager, accountManager: accountManager,
claimsExtractor: jwtclaims.NewClaimsExtractor( claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience), jwtclaims.WithAudience(authCfg.Audience),
@ -33,8 +44,8 @@ func NewGroupsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Gr
} }
} }
// GetAllGroups list for the account // getAllGroups list for the account
func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) { func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@ -63,8 +74,8 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, groupsResponse) util.WriteJSONObject(r.Context(), w, groupsResponse)
} }
// UpdateGroup handles update to a group identified by a given ID // updateGroup handles update to a group identified by a given ID
func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@ -141,8 +152,8 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group)) util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group))
} }
// CreateGroup handles group creation request // createGroup handles group creation request
func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@ -189,8 +200,8 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group)) util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group))
} }
// DeleteGroup handles group deletion request // deleteGroup handles group deletion request
func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) { func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@ -215,11 +226,11 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
return return
} }
util.WriteJSONObject(r.Context(), w, emptyObject{}) util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
} }
// GetGroup returns a group // getGroup returns a group
func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) { func (h *handler) getGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {

View File

@ -1,4 +1,4 @@
package http package groups
import ( import (
"bytes" "bytes"
@ -31,8 +31,8 @@ var TestPeers = map[string]*nbpeer.Peer{
"B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")}, "B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")},
} }
func initGroupTestData(initGroups ...*nbgroup.Group) *GroupsHandler { func initGroupTestData(initGroups ...*nbgroup.Group) *handler {
return &GroupsHandler{ return &handler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
SaveGroupFunc: func(_ context.Context, accountID, userID string, group *nbgroup.Group) error { SaveGroupFunc: func(_ context.Context, accountID, userID string, group *nbgroup.Group) error {
if !strings.HasPrefix(group.ID, "id-") { if !strings.HasPrefix(group.ID, "id-") {
@ -106,14 +106,14 @@ func TestGetGroup(t *testing.T) {
requestBody io.Reader requestBody io.Reader
}{ }{
{ {
name: "GetGroup OK", name: "getGroup OK",
expectedBody: true, expectedBody: true,
requestType: http.MethodGet, requestType: http.MethodGet,
requestPath: "/api/groups/idofthegroup", requestPath: "/api/groups/idofthegroup",
expectedStatus: http.StatusOK, expectedStatus: http.StatusOK,
}, },
{ {
name: "GetGroup not found", name: "getGroup not found",
requestType: http.MethodGet, requestType: http.MethodGet,
requestPath: "/api/groups/notexists", requestPath: "/api/groups/notexists",
expectedStatus: http.StatusNotFound, expectedStatus: http.StatusNotFound,
@ -133,7 +133,7 @@ func TestGetGroup(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/groups/{groupId}", p.GetGroup).Methods("GET") router.HandleFunc("/api/groups/{groupId}", p.getGroup).Methods("GET")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()
@ -254,8 +254,8 @@ func TestWriteGroup(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/groups", p.CreateGroup).Methods("POST") router.HandleFunc("/api/groups", p.createGroup).Methods("POST")
router.HandleFunc("/api/groups/{groupId}", p.UpdateGroup).Methods("PUT") router.HandleFunc("/api/groups/{groupId}", p.updateGroup).Methods("PUT")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()
@ -331,7 +331,7 @@ func TestDeleteGroup(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/groups/{groupId}", p.DeleteGroup).Methods("DELETE") router.HandleFunc("/api/groups/{groupId}", p.deleteGroup).Methods("DELETE")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()

View File

@ -1,4 +1,4 @@
package http package peers
import ( import (
"context" "context"
@ -12,21 +12,30 @@ import (
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
nbgroup "github.com/netbirdio/netbird/management/server/group" nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
) )
// PeersHandler is a handler that returns peers of the account // Handler is a handler that returns peers of the account
type PeersHandler struct { type Handler struct {
accountManager server.AccountManager accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor claimsExtractor *jwtclaims.ClaimsExtractor
} }
// NewPeersHandler creates a new PeersHandler HTTP handler func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
func NewPeersHandler(accountManager server.AccountManager, authCfg AuthCfg) *PeersHandler { peersHandler := NewHandler(accountManager, authCfg)
return &PeersHandler{ router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
Methods("GET", "PUT", "DELETE", "OPTIONS")
router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS")
}
// NewHandler creates a new peers Handler
func NewHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *Handler {
return &Handler{
accountManager: accountManager, accountManager: accountManager,
claimsExtractor: jwtclaims.NewClaimsExtractor( claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience), jwtclaims.WithAudience(authCfg.Audience),
@ -35,7 +44,7 @@ func NewPeersHandler(accountManager server.AccountManager, authCfg AuthCfg) *Pee
} }
} }
func (h *PeersHandler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) { func (h *Handler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) {
peerToReturn := peer.Copy() peerToReturn := peer.Copy()
if peer.Status.Connected { if peer.Status.Connected {
// Although we have online status in store we do not yet have an updated channel so have to show it as disconnected // Although we have online status in store we do not yet have an updated channel so have to show it as disconnected
@ -48,7 +57,7 @@ func (h *PeersHandler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error)
return peerToReturn, nil return peerToReturn, nil
} }
func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, peerID, userID string, w http.ResponseWriter) { func (h *Handler) getPeer(ctx context.Context, account *server.Account, peerID, userID string, w http.ResponseWriter) {
peer, err := h.accountManager.GetPeer(ctx, account.Id, peerID, userID) peer, err := h.accountManager.GetPeer(ctx, account.Id, peerID, userID)
if err != nil { if err != nil {
util.WriteError(ctx, err, w) util.WriteError(ctx, err, w)
@ -75,7 +84,7 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid)) util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid))
} }
func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) { func (h *Handler) updatePeer(ctx context.Context, account *server.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) {
req := &api.PeerRequest{} req := &api.PeerRequest{}
err := json.NewDecoder(r.Body).Decode(&req) err := json.NewDecoder(r.Body).Decode(&req)
if err != nil { if err != nil {
@ -120,18 +129,18 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, valid)) util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, valid))
} }
func (h *PeersHandler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) { func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) {
err := h.accountManager.DeletePeer(ctx, accountID, peerID, userID) err := h.accountManager.DeletePeer(ctx, accountID, peerID, userID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to delete peer: %v", err) log.WithContext(ctx).Errorf("failed to delete peer: %v", err)
util.WriteError(ctx, err, w) util.WriteError(ctx, err, w)
return return
} }
util.WriteJSONObject(ctx, w, emptyObject{}) util.WriteJSONObject(ctx, w, util.EmptyObject{})
} }
// HandlePeer handles all peer requests for GET, PUT and DELETE operations // HandlePeer handles all peer requests for GET, PUT and DELETE operations
func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) { func (h *Handler) HandlePeer(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@ -168,7 +177,7 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
} }
// GetAllPeers returns a list of all peers associated with a provided account // GetAllPeers returns a list of all peers associated with a provided account
func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@ -219,7 +228,7 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, respBody) util.WriteJSONObject(r.Context(), w, respBody)
} }
func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) { func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) {
for _, peer := range respBody { for _, peer := range respBody {
_, ok := approvedPeersMap[peer.Id] _, ok := approvedPeersMap[peer.Id]
if !ok { if !ok {
@ -229,7 +238,7 @@ func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approv
} }
// GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network. // GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network.
func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {

View File

@ -1,4 +1,4 @@
package http package peers
import ( import (
"bytes" "bytes"
@ -38,8 +38,8 @@ const (
userIDKey ctxKey = "user_id" userIDKey ctxKey = "user_id"
) )
func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { func initTestMetaData(peers ...*nbpeer.Peer) *Handler {
return &PeersHandler{ return &Handler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
var p *nbpeer.Peer var p *nbpeer.Peer

View File

@ -1,4 +1,4 @@
package http package policies
import ( import (
"context" "context"
@ -11,9 +11,9 @@ import (
"testing" "testing"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
@ -21,12 +21,12 @@ import (
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
func initGeolocationTestData(t *testing.T) *GeolocationsHandler { func initGeolocationTestData(t *testing.T) *geolocationsHandler {
t.Helper() t.Helper()
var ( var (
mmdbPath = "../testdata/GeoLite2-City_20240305.mmdb" mmdbPath = "../../../testdata/GeoLite2-City_20240305.mmdb"
geonamesdbPath = "../testdata/geonames_20240305.db" geonamesdbPath = "../../../testdata/geonames_20240305.db"
) )
tempDir := t.TempDir() tempDir := t.TempDir()
@ -41,7 +41,7 @@ func initGeolocationTestData(t *testing.T) *GeolocationsHandler {
assert.NoError(t, err) assert.NoError(t, err)
t.Cleanup(func() { _ = geo.Stop() }) t.Cleanup(func() { _ = geo.Stop() })
return &GeolocationsHandler{ return &geolocationsHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil return claims.AccountId, claims.UserId, nil
@ -114,7 +114,7 @@ func TestGetCitiesByCountry(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/locations/countries/{country}/cities", geolocationHandler.GetCitiesByCountry).Methods("GET") router.HandleFunc("/api/locations/countries/{country}/cities", geolocationHandler.getCitiesByCountry).Methods("GET")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()
@ -202,7 +202,7 @@ func TestGetAllCountries(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/locations/countries", geolocationHandler.GetAllCountries).Methods("GET") router.HandleFunc("/api/locations/countries", geolocationHandler.getAllCountries).Methods("GET")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()

View File

@ -1,4 +1,4 @@
package http package policies
import ( import (
"net/http" "net/http"
@ -9,6 +9,7 @@ import (
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
@ -18,16 +19,22 @@ var (
countryCodeRegex = regexp.MustCompile("^[a-zA-Z]{2}$") countryCodeRegex = regexp.MustCompile("^[a-zA-Z]{2}$")
) )
// GeolocationsHandler is a handler that returns locations. // geolocationsHandler is a handler that returns locations.
type GeolocationsHandler struct { type geolocationsHandler struct {
accountManager server.AccountManager accountManager server.AccountManager
geolocationManager *geolocation.Geolocation geolocationManager *geolocation.Geolocation
claimsExtractor *jwtclaims.ClaimsExtractor claimsExtractor *jwtclaims.ClaimsExtractor
} }
// NewGeolocationsHandlerHandler creates a new Geolocations handler func addLocationsEndpoint(accountManager server.AccountManager, locationManager *geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) {
func NewGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg AuthCfg) *GeolocationsHandler { locationHandler := newGeolocationsHandlerHandler(accountManager, locationManager, authCfg)
return &GeolocationsHandler{ router.HandleFunc("/locations/countries", locationHandler.getAllCountries).Methods("GET", "OPTIONS")
router.HandleFunc("/locations/countries/{country}/cities", locationHandler.getCitiesByCountry).Methods("GET", "OPTIONS")
}
// newGeolocationsHandlerHandler creates a new Geolocations handler
func newGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg configs.AuthCfg) *geolocationsHandler {
return &geolocationsHandler{
accountManager: accountManager, accountManager: accountManager,
geolocationManager: geolocationManager, geolocationManager: geolocationManager,
claimsExtractor: jwtclaims.NewClaimsExtractor( claimsExtractor: jwtclaims.NewClaimsExtractor(
@ -37,8 +44,8 @@ func NewGeolocationsHandlerHandler(accountManager server.AccountManager, geoloca
} }
} }
// GetAllCountries retrieves a list of all countries // getAllCountries retrieves a list of all countries
func (l *GeolocationsHandler) GetAllCountries(w http.ResponseWriter, r *http.Request) { func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Request) {
if err := l.authenticateUser(r); err != nil { if err := l.authenticateUser(r); err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -63,8 +70,8 @@ func (l *GeolocationsHandler) GetAllCountries(w http.ResponseWriter, r *http.Req
util.WriteJSONObject(r.Context(), w, countries) util.WriteJSONObject(r.Context(), w, countries)
} }
// GetCitiesByCountry retrieves a list of cities based on the given country code // getCitiesByCountry retrieves a list of cities based on the given country code
func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http.Request) { func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.Request) {
if err := l.authenticateUser(r); err != nil { if err := l.authenticateUser(r); err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@ -96,7 +103,7 @@ func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http.
util.WriteJSONObject(r.Context(), w, cities) util.WriteJSONObject(r.Context(), w, cities)
} }
func (l *GeolocationsHandler) authenticateUser(r *http.Request) error { func (l *geolocationsHandler) authenticateUser(r *http.Request) error {
claims := l.claimsExtractor.FromRequestContext(r) claims := l.claimsExtractor.FromRequestContext(r)
_, userID, err := l.accountManager.GetAccountIDFromToken(r.Context(), claims) _, userID, err := l.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {

View File

@ -1,4 +1,4 @@
package http package policies
import ( import (
"encoding/json" "encoding/json"
@ -8,22 +8,34 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/geolocation"
nbgroup "github.com/netbirdio/netbird/management/server/group" nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
) )
// Policies is a handler that returns policy of the account // handler is a handler that returns policy of the account
type Policies struct { type handler struct {
accountManager server.AccountManager accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor claimsExtractor *jwtclaims.ClaimsExtractor
} }
// NewPoliciesHandler creates a new Policies handler func AddEndpoints(accountManager server.AccountManager, locationManager *geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) {
func NewPoliciesHandler(accountManager server.AccountManager, authCfg AuthCfg) *Policies { policiesHandler := newHandler(accountManager, authCfg)
return &Policies{ router.HandleFunc("/policies", policiesHandler.getAllPolicies).Methods("GET", "OPTIONS")
router.HandleFunc("/policies", policiesHandler.createPolicy).Methods("POST", "OPTIONS")
router.HandleFunc("/policies/{policyId}", policiesHandler.updatePolicy).Methods("PUT", "OPTIONS")
router.HandleFunc("/policies/{policyId}", policiesHandler.getPolicy).Methods("GET", "OPTIONS")
router.HandleFunc("/policies/{policyId}", policiesHandler.deletePolicy).Methods("DELETE", "OPTIONS")
addPostureCheckEndpoint(accountManager, locationManager, authCfg, router)
}
// newHandler creates a new policies handler
func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler {
return &handler{
accountManager: accountManager, accountManager: accountManager,
claimsExtractor: jwtclaims.NewClaimsExtractor( claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience), jwtclaims.WithAudience(authCfg.Audience),
@ -32,8 +44,8 @@ func NewPoliciesHandler(accountManager server.AccountManager, authCfg AuthCfg) *
} }
} }
// GetAllPolicies list for the account // getAllPolicies list for the account
func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) { func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@ -66,8 +78,8 @@ func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, policies) util.WriteJSONObject(r.Context(), w, policies)
} }
// UpdatePolicy handles update to a policy identified by a given ID // updatePolicy handles update to a policy identified by a given ID
func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) { func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@ -91,8 +103,8 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
h.savePolicy(w, r, accountID, userID, policyID) h.savePolicy(w, r, accountID, userID, policyID)
} }
// CreatePolicy handles policy creation request // createPolicy handles policy creation request
func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) { func (h *handler) createPolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@ -104,7 +116,7 @@ func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) {
} }
// savePolicy handles policy creation and update // savePolicy handles policy creation and update
func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID string, userID string, policyID string) { func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID string, userID string, policyID string) {
var req api.PutApiPoliciesPolicyIdJSONRequestBody var req api.PutApiPoliciesPolicyIdJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
@ -251,8 +263,8 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
util.WriteJSONObject(r.Context(), w, resp) util.WriteJSONObject(r.Context(), w, resp)
} }
// DeletePolicy handles policy deletion request // deletePolicy handles policy deletion request
func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) { func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@ -272,11 +284,11 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
return return
} }
util.WriteJSONObject(r.Context(), w, emptyObject{}) util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
} }
// GetPolicy handles a group Get request identified by ID // getPolicy handles a group Get request identified by ID
func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) { func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {

View File

@ -1,4 +1,4 @@
package http package policies
import ( import (
"bytes" "bytes"
@ -24,12 +24,12 @@ import (
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
) )
func initPoliciesTestData(policies ...*server.Policy) *Policies { func initPoliciesTestData(policies ...*server.Policy) *handler {
testPolicies := make(map[string]*server.Policy, len(policies)) testPolicies := make(map[string]*server.Policy, len(policies))
for _, policy := range policies { for _, policy := range policies {
testPolicies[policy.ID] = policy testPolicies[policy.ID] = policy
} }
return &Policies{ return &handler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetPolicyFunc: func(_ context.Context, _, policyID, _ string) (*server.Policy, error) { GetPolicyFunc: func(_ context.Context, _, policyID, _ string) (*server.Policy, error) {
policy, ok := testPolicies[policyID] policy, ok := testPolicies[policyID]
@ -91,14 +91,14 @@ func TestPoliciesGetPolicy(t *testing.T) {
requestBody io.Reader requestBody io.Reader
}{ }{
{ {
name: "GetPolicy OK", name: "getPolicy OK",
expectedBody: true, expectedBody: true,
requestType: http.MethodGet, requestType: http.MethodGet,
requestPath: "/api/policies/idofthepolicy", requestPath: "/api/policies/idofthepolicy",
expectedStatus: http.StatusOK, expectedStatus: http.StatusOK,
}, },
{ {
name: "GetPolicy not found", name: "getPolicy not found",
requestType: http.MethodGet, requestType: http.MethodGet,
requestPath: "/api/policies/notexists", requestPath: "/api/policies/notexists",
expectedStatus: http.StatusNotFound, expectedStatus: http.StatusNotFound,
@ -121,7 +121,7 @@ func TestPoliciesGetPolicy(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/policies/{policyId}", p.GetPolicy).Methods("GET") router.HandleFunc("/api/policies/{policyId}", p.getPolicy).Methods("GET")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()
@ -272,8 +272,8 @@ func TestPoliciesWritePolicy(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/policies", p.CreatePolicy).Methods("POST") router.HandleFunc("/api/policies", p.createPolicy).Methods("POST")
router.HandleFunc("/api/policies/{policyId}", p.UpdatePolicy).Methods("PUT") router.HandleFunc("/api/policies/{policyId}", p.updatePolicy).Methods("PUT")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()

View File

@ -1,4 +1,4 @@
package http package policies
import ( import (
"encoding/json" "encoding/json"
@ -9,22 +9,33 @@ import (
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
) )
// PostureChecksHandler is a handler that returns posture checks of the account. // postureChecksHandler is a handler that returns posture checks of the account.
type PostureChecksHandler struct { type postureChecksHandler struct {
accountManager server.AccountManager accountManager server.AccountManager
geolocationManager *geolocation.Geolocation geolocationManager *geolocation.Geolocation
claimsExtractor *jwtclaims.ClaimsExtractor claimsExtractor *jwtclaims.ClaimsExtractor
} }
// NewPostureChecksHandler creates a new PostureChecks handler func addPostureCheckEndpoint(accountManager server.AccountManager, locationManager *geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) {
func NewPostureChecksHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg AuthCfg) *PostureChecksHandler { postureCheckHandler := newPostureChecksHandler(accountManager, locationManager, authCfg)
return &PostureChecksHandler{ router.HandleFunc("/posture-checks", postureCheckHandler.getAllPostureChecks).Methods("GET", "OPTIONS")
router.HandleFunc("/posture-checks", postureCheckHandler.createPostureCheck).Methods("POST", "OPTIONS")
router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.updatePostureCheck).Methods("PUT", "OPTIONS")
router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.getPostureCheck).Methods("GET", "OPTIONS")
router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.deletePostureCheck).Methods("DELETE", "OPTIONS")
addLocationsEndpoint(accountManager, locationManager, authCfg, router)
}
// newPostureChecksHandler creates a new PostureChecks handler
func newPostureChecksHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg configs.AuthCfg) *postureChecksHandler {
return &postureChecksHandler{
accountManager: accountManager, accountManager: accountManager,
geolocationManager: geolocationManager, geolocationManager: geolocationManager,
claimsExtractor: jwtclaims.NewClaimsExtractor( claimsExtractor: jwtclaims.NewClaimsExtractor(
@ -34,8 +45,8 @@ func NewPostureChecksHandler(accountManager server.AccountManager, geolocationMa
} }
} }
// GetAllPostureChecks list for the account // getAllPostureChecks list for the account
func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *http.Request) { func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r) claims := p.claimsExtractor.FromRequestContext(r)
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@ -57,8 +68,8 @@ func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *htt
util.WriteJSONObject(r.Context(), w, postureChecks) util.WriteJSONObject(r.Context(), w, postureChecks)
} }
// UpdatePostureCheck handles update to a posture check identified by a given ID // updatePostureCheck handles update to a posture check identified by a given ID
func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http.Request) { func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r) claims := p.claimsExtractor.FromRequestContext(r)
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@ -82,8 +93,8 @@ func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http
p.savePostureChecks(w, r, accountID, userID, postureChecksID) p.savePostureChecks(w, r, accountID, userID, postureChecksID)
} }
// CreatePostureCheck handles posture check creation request // createPostureCheck handles posture check creation request
func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http.Request) { func (p *postureChecksHandler) createPostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r) claims := p.claimsExtractor.FromRequestContext(r)
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@ -94,8 +105,8 @@ func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http
p.savePostureChecks(w, r, accountID, userID, "") p.savePostureChecks(w, r, accountID, userID, "")
} }
// GetPostureCheck handles a posture check Get request identified by ID // getPostureCheck handles a posture check Get request identified by ID
func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Request) { func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r) claims := p.claimsExtractor.FromRequestContext(r)
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@ -119,8 +130,8 @@ func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Re
util.WriteJSONObject(r.Context(), w, postureChecks.ToAPIResponse()) util.WriteJSONObject(r.Context(), w, postureChecks.ToAPIResponse())
} }
// DeletePostureCheck handles posture check deletion request // deletePostureCheck handles posture check deletion request
func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http.Request) { func (p *postureChecksHandler) deletePostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r) claims := p.claimsExtractor.FromRequestContext(r)
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@ -140,11 +151,11 @@ func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http
return return
} }
util.WriteJSONObject(r.Context(), w, emptyObject{}) util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
} }
// savePostureChecks handles posture checks create and update // savePostureChecks handles posture checks create and update
func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.Request, accountID, userID, postureChecksID string) { func (p *postureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.Request, accountID, userID, postureChecksID string) {
var ( var (
err error err error
req api.PostureCheckUpdate req api.PostureCheckUpdate

View File

@ -1,4 +1,4 @@
package http package policies
import ( import (
"bytes" "bytes"
@ -25,13 +25,13 @@ import (
var berlin = "Berlin" var berlin = "Berlin"
var losAngeles = "Los Angeles" var losAngeles = "Los Angeles"
func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksHandler { func initPostureChecksTestData(postureChecks ...*posture.Checks) *postureChecksHandler {
testPostureChecks := make(map[string]*posture.Checks, len(postureChecks)) testPostureChecks := make(map[string]*posture.Checks, len(postureChecks))
for _, postureCheck := range postureChecks { for _, postureCheck := range postureChecks {
testPostureChecks[postureCheck.ID] = postureCheck testPostureChecks[postureCheck.ID] = postureCheck
} }
return &PostureChecksHandler{ return &postureChecksHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetPostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) { GetPostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
p, ok := testPostureChecks[postureChecksID] p, ok := testPostureChecks[postureChecksID]
@ -147,35 +147,35 @@ func TestGetPostureCheck(t *testing.T) {
requestBody io.Reader requestBody io.Reader
}{ }{
{ {
name: "GetPostureCheck NBVersion OK", name: "getPostureCheck NBVersion OK",
expectedBody: true, expectedBody: true,
id: postureCheck.ID, id: postureCheck.ID,
checkName: postureCheck.Name, checkName: postureCheck.Name,
expectedStatus: http.StatusOK, expectedStatus: http.StatusOK,
}, },
{ {
name: "GetPostureCheck OSVersion OK", name: "getPostureCheck OSVersion OK",
expectedBody: true, expectedBody: true,
id: osPostureCheck.ID, id: osPostureCheck.ID,
checkName: osPostureCheck.Name, checkName: osPostureCheck.Name,
expectedStatus: http.StatusOK, expectedStatus: http.StatusOK,
}, },
{ {
name: "GetPostureCheck GeoLocation OK", name: "getPostureCheck GeoLocation OK",
expectedBody: true, expectedBody: true,
id: geoPostureCheck.ID, id: geoPostureCheck.ID,
checkName: geoPostureCheck.Name, checkName: geoPostureCheck.Name,
expectedStatus: http.StatusOK, expectedStatus: http.StatusOK,
}, },
{ {
name: "GetPostureCheck PrivateNetwork OK", name: "getPostureCheck PrivateNetwork OK",
expectedBody: true, expectedBody: true,
id: privateNetworkCheck.ID, id: privateNetworkCheck.ID,
checkName: privateNetworkCheck.Name, checkName: privateNetworkCheck.Name,
expectedStatus: http.StatusOK, expectedStatus: http.StatusOK,
}, },
{ {
name: "GetPostureCheck Not Found", name: "getPostureCheck Not Found",
id: "not-exists", id: "not-exists",
expectedStatus: http.StatusNotFound, expectedStatus: http.StatusNotFound,
}, },
@ -189,7 +189,7 @@ func TestGetPostureCheck(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/api/posture-checks/"+tc.id, tc.requestBody) req := httptest.NewRequest(http.MethodGet, "/api/posture-checks/"+tc.id, tc.requestBody)
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/posture-checks/{postureCheckId}", p.GetPostureCheck).Methods("GET") router.HandleFunc("/api/posture-checks/{postureCheckId}", p.getPostureCheck).Methods("GET")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()
@ -231,7 +231,7 @@ func TestPostureCheckUpdate(t *testing.T) {
requestType string requestType string
requestPath string requestPath string
requestBody io.Reader requestBody io.Reader
setupHandlerFunc func(handler *PostureChecksHandler) setupHandlerFunc func(handler *postureChecksHandler)
}{ }{
{ {
name: "Create Posture Checks NB version", name: "Create Posture Checks NB version",
@ -286,7 +286,7 @@ func TestPostureCheckUpdate(t *testing.T) {
}, },
}, },
}, },
setupHandlerFunc: func(handler *PostureChecksHandler) { setupHandlerFunc: func(handler *postureChecksHandler) {
handler.geolocationManager = nil handler.geolocationManager = nil
}, },
}, },
@ -427,7 +427,7 @@ func TestPostureCheckUpdate(t *testing.T) {
}`)), }`)),
expectedStatus: http.StatusPreconditionFailed, expectedStatus: http.StatusPreconditionFailed,
expectedBody: false, expectedBody: false,
setupHandlerFunc: func(handler *PostureChecksHandler) { setupHandlerFunc: func(handler *postureChecksHandler) {
handler.geolocationManager = nil handler.geolocationManager = nil
}, },
}, },
@ -614,7 +614,7 @@ func TestPostureCheckUpdate(t *testing.T) {
}, },
}, },
}, },
setupHandlerFunc: func(handler *PostureChecksHandler) { setupHandlerFunc: func(handler *postureChecksHandler) {
handler.geolocationManager = nil handler.geolocationManager = nil
}, },
}, },
@ -677,7 +677,7 @@ func TestPostureCheckUpdate(t *testing.T) {
}`)), }`)),
expectedStatus: http.StatusPreconditionFailed, expectedStatus: http.StatusPreconditionFailed,
expectedBody: false, expectedBody: false,
setupHandlerFunc: func(handler *PostureChecksHandler) { setupHandlerFunc: func(handler *postureChecksHandler) {
handler.geolocationManager = nil handler.geolocationManager = nil
}, },
}, },
@ -842,8 +842,8 @@ func TestPostureCheckUpdate(t *testing.T) {
} }
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/posture-checks", defaultHandler.CreatePostureCheck).Methods("POST") router.HandleFunc("/api/posture-checks", defaultHandler.createPostureCheck).Methods("POST")
router.HandleFunc("/api/posture-checks/{postureCheckId}", defaultHandler.UpdatePostureCheck).Methods("PUT") router.HandleFunc("/api/posture-checks/{postureCheckId}", defaultHandler.updatePostureCheck).Methods("PUT")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()

View File

@ -1,4 +1,4 @@
package http package routes
import ( import (
"encoding/json" "encoding/json"
@ -14,6 +14,7 @@ import (
"github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
@ -23,15 +24,24 @@ import (
const maxDomains = 32 const maxDomains = 32
const failedToConvertRoute = "failed to convert route to response: %v" const failedToConvertRoute = "failed to convert route to response: %v"
// RoutesHandler is the routes handler of the account // handler is the routes handler of the account
type RoutesHandler struct { type handler struct {
accountManager server.AccountManager accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor claimsExtractor *jwtclaims.ClaimsExtractor
} }
// NewRoutesHandler returns a new instance of RoutesHandler handler func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
func NewRoutesHandler(accountManager server.AccountManager, authCfg AuthCfg) *RoutesHandler { routesHandler := newHandler(accountManager, authCfg)
return &RoutesHandler{ router.HandleFunc("/routes", routesHandler.getAllRoutes).Methods("GET", "OPTIONS")
router.HandleFunc("/routes", routesHandler.createRoute).Methods("POST", "OPTIONS")
router.HandleFunc("/routes/{routeId}", routesHandler.updateRoute).Methods("PUT", "OPTIONS")
router.HandleFunc("/routes/{routeId}", routesHandler.getRoute).Methods("GET", "OPTIONS")
router.HandleFunc("/routes/{routeId}", routesHandler.deleteRoute).Methods("DELETE", "OPTIONS")
}
// newHandler returns a new instance of routes handler
func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler {
return &handler{
accountManager: accountManager, accountManager: accountManager,
claimsExtractor: jwtclaims.NewClaimsExtractor( claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience), jwtclaims.WithAudience(authCfg.Audience),
@ -40,8 +50,8 @@ func NewRoutesHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ro
} }
} }
// GetAllRoutes returns the list of routes for the account // getAllRoutes returns the list of routes for the account
func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) { func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@ -67,8 +77,8 @@ func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, apiRoutes) util.WriteJSONObject(r.Context(), w, apiRoutes)
} }
// CreateRoute handles route creation request // createRoute handles route creation request
func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) { func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@ -139,7 +149,7 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, routes) util.WriteJSONObject(r.Context(), w, routes)
} }
func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) error { func (h *handler) validateRoute(req api.PostApiRoutesJSONRequestBody) error {
if req.Network != nil && req.Domains != nil { if req.Network != nil && req.Domains != nil {
return status.Errorf(status.InvalidArgument, "only one of 'network' or 'domains' should be provided") return status.Errorf(status.InvalidArgument, "only one of 'network' or 'domains' should be provided")
} }
@ -164,8 +174,8 @@ func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) erro
return nil return nil
} }
// UpdateRoute handles update to a route identified by a given ID // updateRoute handles update to a route identified by a given ID
func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@ -257,8 +267,8 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, routes) util.WriteJSONObject(r.Context(), w, routes)
} }
// DeleteRoute handles route deletion request // deleteRoute handles route deletion request
func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) { func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@ -278,11 +288,11 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
return return
} }
util.WriteJSONObject(r.Context(), w, emptyObject{}) util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
} }
// GetRoute handles a route Get request identified by ID // getRoute handles a route Get request identified by ID
func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) { func (h *handler) getRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {

View File

@ -1,4 +1,4 @@
package http package routes
import ( import (
"bytes" "bytes"
@ -87,8 +87,8 @@ var testingAccount = &server.Account{
}, },
} }
func initRoutesTestData() *RoutesHandler { func initRoutesTestData() *handler {
return &RoutesHandler{ return &handler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) (*route.Route, error) { GetRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) (*route.Route, error) {
if routeID == existingRouteID { if routeID == existingRouteID {
@ -152,7 +152,7 @@ func initRoutesTestData() *RoutesHandler {
return nil return nil
}, },
GetAccountIDFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) { GetAccountIDFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) {
//return testingAccount, testingAccount.Users["test_user"], nil // return testingAccount, testingAccount.Users["test_user"], nil
return testingAccount.Id, testingAccount.Users["test_user"].Id, nil return testingAccount.Id, testingAccount.Users["test_user"].Id, nil
}, },
}, },
@ -521,10 +521,10 @@ func TestRoutesHandlers(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/routes/{routeId}", p.GetRoute).Methods("GET") router.HandleFunc("/api/routes/{routeId}", p.getRoute).Methods("GET")
router.HandleFunc("/api/routes/{routeId}", p.DeleteRoute).Methods("DELETE") router.HandleFunc("/api/routes/{routeId}", p.deleteRoute).Methods("DELETE")
router.HandleFunc("/api/routes", p.CreateRoute).Methods("POST") router.HandleFunc("/api/routes", p.createRoute).Methods("POST")
router.HandleFunc("/api/routes/{routeId}", p.UpdateRoute).Methods("PUT") router.HandleFunc("/api/routes/{routeId}", p.updateRoute).Methods("PUT")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()

View File

@ -1,4 +1,4 @@
package http package setup_keys
import ( import (
"context" "context"
@ -10,20 +10,30 @@ import (
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
) )
// SetupKeysHandler is a handler that returns a list of setup keys of the account // handler is a handler that returns a list of setup keys of the account
type SetupKeysHandler struct { type handler struct {
accountManager server.AccountManager accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor claimsExtractor *jwtclaims.ClaimsExtractor
} }
// NewSetupKeysHandler creates a new SetupKeysHandler HTTP handler func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
func NewSetupKeysHandler(accountManager server.AccountManager, authCfg AuthCfg) *SetupKeysHandler { keysHandler := newHandler(accountManager, authCfg)
return &SetupKeysHandler{ router.HandleFunc("/setup-keys", keysHandler.getAllSetupKeys).Methods("GET", "OPTIONS")
router.HandleFunc("/setup-keys", keysHandler.createSetupKey).Methods("POST", "OPTIONS")
router.HandleFunc("/setup-keys/{keyId}", keysHandler.getSetupKey).Methods("GET", "OPTIONS")
router.HandleFunc("/setup-keys/{keyId}", keysHandler.updateSetupKey).Methods("PUT", "OPTIONS")
router.HandleFunc("/setup-keys/{keyId}", keysHandler.deleteSetupKey).Methods("DELETE", "OPTIONS")
}
// newHandler creates a new setup key handler
func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler {
return &handler{
accountManager: accountManager, accountManager: accountManager,
claimsExtractor: jwtclaims.NewClaimsExtractor( claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience), jwtclaims.WithAudience(authCfg.Audience),
@ -32,8 +42,8 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authCfg AuthCfg)
} }
} }
// CreateSetupKey is a POST requests that creates a new SetupKey // createSetupKey is a POST requests that creates a new SetupKey
func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request) { func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@ -89,8 +99,8 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
util.WriteJSONObject(r.Context(), w, apiSetupKeys) util.WriteJSONObject(r.Context(), w, apiSetupKeys)
} }
// GetSetupKey is a GET request to get a SetupKey by ID // getSetupKey is a GET request to get a SetupKey by ID
func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) { func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@ -114,8 +124,8 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
writeSuccess(r.Context(), w, key) writeSuccess(r.Context(), w, key)
} }
// UpdateSetupKey is a PUT request to update server.SetupKey // updateSetupKey is a PUT request to update server.SetupKey
func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request) { func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@ -155,8 +165,8 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
writeSuccess(r.Context(), w, newKey) writeSuccess(r.Context(), w, newKey)
} }
// GetAllSetupKeys is a GET request that returns a list of SetupKey // getAllSetupKeys is a GET request that returns a list of SetupKey
func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Request) { func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@ -178,7 +188,7 @@ func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Reques
util.WriteJSONObject(r.Context(), w, apiSetupKeys) util.WriteJSONObject(r.Context(), w, apiSetupKeys)
} }
func (h *SetupKeysHandler) DeleteSetupKey(w http.ResponseWriter, r *http.Request) { func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@ -199,7 +209,7 @@ func (h *SetupKeysHandler) DeleteSetupKey(w http.ResponseWriter, r *http.Request
return return
} }
util.WriteJSONObject(r.Context(), w, emptyObject{}) util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
} }
func writeSuccess(ctx context.Context, w http.ResponseWriter, key *server.SetupKey) { func writeSuccess(ctx context.Context, w http.ResponseWriter, key *server.SetupKey) {

View File

@ -1,4 +1,4 @@
package http package setup_keys
import ( import (
"bytes" "bytes"
@ -26,12 +26,13 @@ const (
newSetupKeyName = "New Setup Key" newSetupKeyName = "New Setup Key"
updatedSetupKeyName = "KKKey" updatedSetupKeyName = "KKKey"
notFoundSetupKeyID = "notFoundSetupKeyID" notFoundSetupKeyID = "notFoundSetupKeyID"
testAccountID = "test_id"
) )
func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.SetupKey, updatedSetupKey *server.SetupKey, func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.SetupKey, updatedSetupKey *server.SetupKey,
user *server.User, user *server.User,
) *SetupKeysHandler { ) *handler {
return &SetupKeysHandler{ return &handler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil return claims.AccountId, claims.UserId, nil
@ -178,11 +179,11 @@ func TestSetupKeysHandlers(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/setup-keys", handler.GetAllSetupKeys).Methods("GET", "OPTIONS") router.HandleFunc("/api/setup-keys", handler.getAllSetupKeys).Methods("GET", "OPTIONS")
router.HandleFunc("/api/setup-keys", handler.CreateSetupKey).Methods("POST", "OPTIONS") router.HandleFunc("/api/setup-keys", handler.createSetupKey).Methods("POST", "OPTIONS")
router.HandleFunc("/api/setup-keys/{keyId}", handler.GetSetupKey).Methods("GET", "OPTIONS") router.HandleFunc("/api/setup-keys/{keyId}", handler.getSetupKey).Methods("GET", "OPTIONS")
router.HandleFunc("/api/setup-keys/{keyId}", handler.UpdateSetupKey).Methods("PUT", "OPTIONS") router.HandleFunc("/api/setup-keys/{keyId}", handler.updateSetupKey).Methods("PUT", "OPTIONS")
router.HandleFunc("/api/setup-keys/{keyId}", handler.DeleteSetupKey).Methods("DELETE", "OPTIONS") router.HandleFunc("/api/setup-keys/{keyId}", handler.deleteSetupKey).Methods("DELETE", "OPTIONS")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()

View File

@ -1,4 +1,4 @@
package http package users
import ( import (
"encoding/json" "encoding/json"
@ -9,20 +9,29 @@ import (
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
) )
// PATHandler is the nameserver group handler of the account // patHandler is the nameserver group handler of the account
type PATHandler struct { type patHandler struct {
accountManager server.AccountManager accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor claimsExtractor *jwtclaims.ClaimsExtractor
} }
// NewPATsHandler creates a new PATHandler HTTP handler func addUsersTokensEndpoint(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATHandler { tokenHandler := newPATsHandler(accountManager, authCfg)
return &PATHandler{ router.HandleFunc("/users/{userId}/tokens", tokenHandler.getAllTokens).Methods("GET", "OPTIONS")
router.HandleFunc("/users/{userId}/tokens", tokenHandler.createToken).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.getToken).Methods("GET", "OPTIONS")
router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.deleteToken).Methods("DELETE", "OPTIONS")
}
// newPATsHandler creates a new patHandler HTTP handler
func newPATsHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *patHandler {
return &patHandler{
accountManager: accountManager, accountManager: accountManager,
claimsExtractor: jwtclaims.NewClaimsExtractor( claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience), jwtclaims.WithAudience(authCfg.Audience),
@ -31,8 +40,8 @@ func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATH
} }
} }
// GetAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user // getAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user
func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) { func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@ -61,8 +70,8 @@ func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, patResponse) util.WriteJSONObject(r.Context(), w, patResponse)
} }
// GetToken is HTTP GET handler that returns a personal access token for the given user // getToken is HTTP GET handler that returns a personal access token for the given user
func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) { func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@ -92,8 +101,8 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, toPATResponse(pat)) util.WriteJSONObject(r.Context(), w, toPATResponse(pat))
} }
// CreateToken is HTTP POST handler that creates a personal access token for the given user // createToken is HTTP POST handler that creates a personal access token for the given user
func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@ -124,8 +133,8 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, toPATGeneratedResponse(pat)) util.WriteJSONObject(r.Context(), w, toPATGeneratedResponse(pat))
} }
// DeleteToken is HTTP DELETE handler that deletes a personal access token for the given user // deleteToken is HTTP DELETE handler that deletes a personal access token for the given user
func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) { func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil { if err != nil {
@ -152,7 +161,7 @@ func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) {
return return
} }
util.WriteJSONObject(r.Context(), w, emptyObject{}) util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
} }
func toPATResponse(pat *server.PersonalAccessToken) *api.PersonalAccessToken { func toPATResponse(pat *server.PersonalAccessToken) *api.PersonalAccessToken {

View File

@ -1,4 +1,4 @@
package http package users
import ( import (
"bytes" "bytes"
@ -61,8 +61,8 @@ var testAccount = &server.Account{
}, },
} }
func initPATTestData() *PATHandler { func initPATTestData() *patHandler {
return &PATHandler{ return &patHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
CreatePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) { CreatePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
if accountID != existingAccountID { if accountID != existingAccountID {
@ -186,10 +186,10 @@ func TestTokenHandlers(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/users/{userId}/tokens", p.GetAllTokens).Methods("GET") router.HandleFunc("/api/users/{userId}/tokens", p.getAllTokens).Methods("GET")
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.GetToken).Methods("GET") router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.getToken).Methods("GET")
router.HandleFunc("/api/users/{userId}/tokens", p.CreateToken).Methods("POST") router.HandleFunc("/api/users/{userId}/tokens", p.createToken).Methods("POST")
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.DeleteToken).Methods("DELETE") router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.deleteToken).Methods("DELETE")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()

View File

@ -1,4 +1,4 @@
package http package users
import ( import (
"encoding/json" "encoding/json"
@ -9,6 +9,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
@ -16,15 +17,25 @@ import (
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
) )
// UsersHandler is a handler that returns users of the account // handler is a handler that returns users of the account
type UsersHandler struct { type handler struct {
accountManager server.AccountManager accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor claimsExtractor *jwtclaims.ClaimsExtractor
} }
// NewUsersHandler creates a new UsersHandler HTTP handler func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
func NewUsersHandler(accountManager server.AccountManager, authCfg AuthCfg) *UsersHandler { userHandler := newHandler(accountManager, authCfg)
return &UsersHandler{ router.HandleFunc("/users", userHandler.getAllUsers).Methods("GET", "OPTIONS")
router.HandleFunc("/users/{userId}", userHandler.updateUser).Methods("PUT", "OPTIONS")
router.HandleFunc("/users/{userId}", userHandler.deleteUser).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users", userHandler.createUser).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/invite", userHandler.inviteUser).Methods("POST", "OPTIONS")
addUsersTokensEndpoint(accountManager, authCfg, router)
}
// newHandler creates a new UsersHandler HTTP handler
func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler {
return &handler{
accountManager: accountManager, accountManager: accountManager,
claimsExtractor: jwtclaims.NewClaimsExtractor( claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience), jwtclaims.WithAudience(authCfg.Audience),
@ -33,8 +44,8 @@ func NewUsersHandler(accountManager server.AccountManager, authCfg AuthCfg) *Use
} }
} }
// UpdateUser is a PUT requests to update User data // updateUser is a PUT requests to update User data
func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPut { if r.Method != http.MethodPut {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return return
@ -94,8 +105,8 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId)) util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId))
} }
// DeleteUser is a DELETE request to delete a user // deleteUser is a DELETE request to delete a user
func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) { func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodDelete { if r.Method != http.MethodDelete {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return return
@ -121,11 +132,11 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) {
return return
} }
util.WriteJSONObject(r.Context(), w, emptyObject{}) util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
} }
// CreateUser creates a User in the system with a status "invited" (effectively this is a user invite). // createUser creates a User in the system with a status "invited" (effectively this is a user invite).
func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) { func (h *handler) createUser(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return return
@ -175,9 +186,9 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId)) util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId))
} }
// GetAllUsers returns a list of users of the account this user belongs to. // getAllUsers returns a list of users of the account this user belongs to.
// It also gathers additional user data (like email and name) from the IDP manager. // It also gathers additional user data (like email and name) from the IDP manager.
func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) { func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet { if r.Method != http.MethodGet {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return return
@ -222,9 +233,9 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, users) util.WriteJSONObject(r.Context(), w, users)
} }
// InviteUser resend invitations to users who haven't activated their accounts, // inviteUser resend invitations to users who haven't activated their accounts,
// prior to the expiration period. // prior to the expiration period.
func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) { func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return return
@ -250,7 +261,7 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) {
return return
} }
util.WriteJSONObject(r.Context(), w, emptyObject{}) util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
} }
func toUserResponse(user *server.UserInfo, currenUserID string) *api.User { func toUserResponse(user *server.UserInfo, currenUserID string) *api.User {

View File

@ -1,4 +1,4 @@
package http package users
import ( import (
"bytes" "bytes"
@ -61,8 +61,8 @@ var usersTestAccount = &server.Account{
}, },
} }
func initUsersTestData() *UsersHandler { func initUsersTestData() *handler {
return &UsersHandler{ return &handler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return usersTestAccount.Id, claims.UserId, nil return usersTestAccount.Id, claims.UserId, nil
@ -147,7 +147,7 @@ func TestGetUsers(t *testing.T) {
requestPath string requestPath string
expectedUserIDs []string expectedUserIDs []string
}{ }{
{name: "GetAllUsers", requestType: http.MethodGet, requestPath: "/api/users", expectedStatus: http.StatusOK, expectedUserIDs: []string{existingUserID, regularUserID, serviceUserID}}, {name: "getAllUsers", requestType: http.MethodGet, requestPath: "/api/users", expectedStatus: http.StatusOK, expectedUserIDs: []string{existingUserID, regularUserID, serviceUserID}},
{name: "GetOnlyServiceUsers", requestType: http.MethodGet, requestPath: "/api/users?service_user=true", expectedStatus: http.StatusOK, expectedUserIDs: []string{serviceUserID}}, {name: "GetOnlyServiceUsers", requestType: http.MethodGet, requestPath: "/api/users?service_user=true", expectedStatus: http.StatusOK, expectedUserIDs: []string{serviceUserID}},
{name: "GetOnlyRegularUsers", requestType: http.MethodGet, requestPath: "/api/users?service_user=false", expectedStatus: http.StatusOK, expectedUserIDs: []string{existingUserID, regularUserID}}, {name: "GetOnlyRegularUsers", requestType: http.MethodGet, requestPath: "/api/users?service_user=false", expectedStatus: http.StatusOK, expectedUserIDs: []string{existingUserID, regularUserID}},
} }
@ -159,7 +159,7 @@ func TestGetUsers(t *testing.T) {
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
userHandler.GetAllUsers(recorder, req) userHandler.getAllUsers(recorder, req)
res := recorder.Result() res := recorder.Result()
defer res.Body.Close() defer res.Body.Close()
@ -265,7 +265,7 @@ func TestUpdateUser(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/users/{userId}", userHandler.UpdateUser).Methods("PUT") router.HandleFunc("/api/users/{userId}", userHandler.updateUser).Methods("PUT")
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
res := recorder.Result() res := recorder.Result()
@ -356,7 +356,7 @@ func TestCreateUser(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
userHandler.CreateUser(rr, req) userHandler.createUser(rr, req)
res := rr.Result() res := rr.Result()
defer res.Body.Close() defer res.Body.Close()
@ -401,7 +401,7 @@ func TestInviteUser(t *testing.T) {
req = mux.SetURLVars(req, tc.requestVars) req = mux.SetURLVars(req, tc.requestVars)
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
userHandler.InviteUser(rr, req) userHandler.inviteUser(rr, req)
res := rr.Result() res := rr.Result()
defer res.Body.Close() defer res.Body.Close()
@ -454,7 +454,7 @@ func TestDeleteUser(t *testing.T) {
req = mux.SetURLVars(req, tc.requestVars) req = mux.SetURLVars(req, tc.requestVars)
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
userHandler.DeleteUser(rr, req) userHandler.deleteUser(rr, req)
res := rr.Result() res := rr.Result()
defer res.Body.Close() defer res.Body.Close()

View File

@ -14,6 +14,10 @@ import (
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
) )
// EmptyObject is an empty struct used to return empty JSON object
type EmptyObject struct {
}
type ErrorResponse struct { type ErrorResponse struct {
Message string `json:"message"` Message string `json:"message"`
Code int `json:"code"` Code int `json:"code"`

View File

@ -740,7 +740,6 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
// it means that the client has already checked if it needs login and had been through the SSO flow // it means that the client has already checked if it needs login and had been through the SSO flow
// so, we can skip this check and directly proceed with the login // so, we can skip this check and directly proceed with the login
if login.UserID == "" { if login.UserID == "" {
log.Info("Peer needs login")
err = am.checkIFPeerNeedsLoginWithoutLock(ctx, accountID, login) err = am.checkIFPeerNeedsLoginWithoutLock(ctx, accountID, login)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err

View File

@ -0,0 +1,48 @@
package semaphoregroup
import (
"context"
"sync"
)
// SemaphoreGroup is a custom type that combines sync.WaitGroup and a semaphore.
type SemaphoreGroup struct {
waitGroup sync.WaitGroup
semaphore chan struct{}
}
// NewSemaphoreGroup creates a new SemaphoreGroup with the specified semaphore limit.
func NewSemaphoreGroup(limit int) *SemaphoreGroup {
return &SemaphoreGroup{
semaphore: make(chan struct{}, limit),
}
}
// Add increments the internal WaitGroup counter and acquires a semaphore slot.
func (sg *SemaphoreGroup) Add(ctx context.Context) {
sg.waitGroup.Add(1)
// Acquire semaphore slot
select {
case <-ctx.Done():
return
case sg.semaphore <- struct{}{}:
}
}
// Done decrements the internal WaitGroup counter and releases a semaphore slot.
func (sg *SemaphoreGroup) Done(ctx context.Context) {
sg.waitGroup.Done()
// Release semaphore slot
select {
case <-ctx.Done():
return
case <-sg.semaphore:
}
}
// Wait waits until the internal WaitGroup counter is zero.
func (sg *SemaphoreGroup) Wait() {
sg.waitGroup.Wait()
}

View File

@ -0,0 +1,66 @@
package semaphoregroup
import (
"context"
"testing"
"time"
)
func TestSemaphoreGroup(t *testing.T) {
semGroup := NewSemaphoreGroup(2)
for i := 0; i < 5; i++ {
semGroup.Add(context.Background())
go func(id int) {
defer semGroup.Done(context.Background())
got := len(semGroup.semaphore)
if got == 0 {
t.Errorf("Expected semaphore length > 0 , got 0")
}
time.Sleep(time.Millisecond)
t.Logf("Goroutine %d is running\n", id)
}(i)
}
semGroup.Wait()
want := 0
got := len(semGroup.semaphore)
if got != want {
t.Errorf("Expected semaphore length %d, got %d", want, got)
}
}
func TestSemaphoreGroupContext(t *testing.T) {
semGroup := NewSemaphoreGroup(1)
semGroup.Add(context.Background())
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
t.Cleanup(cancel)
rChan := make(chan struct{})
go func() {
semGroup.Add(ctx)
rChan <- struct{}{}
}()
select {
case <-rChan:
case <-time.NewTimer(2 * time.Second).C:
t.Error("Adding to semaphore group should not block when context is not done")
}
semGroup.Done(context.Background())
ctxDone, cancelDone := context.WithTimeout(context.Background(), 1*time.Second)
t.Cleanup(cancelDone)
go func() {
semGroup.Done(ctxDone)
rChan <- struct{}{}
}()
select {
case <-rChan:
case <-time.NewTimer(2 * time.Second).C:
t.Error("Releasing from semaphore group should not block when context is not done")
}
}