mirror of
https://github.com/netbirdio/netbird.git
synced 2024-12-12 18:00:49 +01:00
Merge branch 'main' into fix/remove-ids-from-policy-creation
This commit is contained in:
commit
3a95966ccc
@ -46,6 +46,7 @@ type ConfigInput struct {
|
|||||||
ManagementURL string
|
ManagementURL string
|
||||||
AdminURL string
|
AdminURL string
|
||||||
ConfigPath string
|
ConfigPath string
|
||||||
|
StateFilePath string
|
||||||
PreSharedKey *string
|
PreSharedKey *string
|
||||||
ServerSSHAllowed *bool
|
ServerSSHAllowed *bool
|
||||||
NATExternalIPs []string
|
NATExternalIPs []string
|
||||||
@ -105,10 +106,10 @@ type Config struct {
|
|||||||
|
|
||||||
// DNSRouteInterval is the interval in which the DNS routes are updated
|
// DNSRouteInterval is the interval in which the DNS routes are updated
|
||||||
DNSRouteInterval time.Duration
|
DNSRouteInterval time.Duration
|
||||||
//Path to a certificate used for mTLS authentication
|
// Path to a certificate used for mTLS authentication
|
||||||
ClientCertPath string
|
ClientCertPath string
|
||||||
|
|
||||||
//Path to corresponding private key of ClientCertPath
|
// Path to corresponding private key of ClientCertPath
|
||||||
ClientCertKeyPath string
|
ClientCertKeyPath string
|
||||||
|
|
||||||
ClientCertKeyPair *tls.Certificate `json:"-"`
|
ClientCertKeyPair *tls.Certificate `json:"-"`
|
||||||
@ -116,7 +117,7 @@ type Config struct {
|
|||||||
|
|
||||||
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
||||||
func ReadConfig(configPath string) (*Config, error) {
|
func ReadConfig(configPath string) (*Config, error) {
|
||||||
if configFileIsExists(configPath) {
|
if fileExists(configPath) {
|
||||||
err := util.EnforcePermission(configPath)
|
err := util.EnforcePermission(configPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to enforce permission on config dir: %v", err)
|
log.Errorf("failed to enforce permission on config dir: %v", err)
|
||||||
@ -149,7 +150,7 @@ func ReadConfig(configPath string) (*Config, error) {
|
|||||||
|
|
||||||
// UpdateConfig update existing configuration according to input configuration and return with the configuration
|
// UpdateConfig update existing configuration according to input configuration and return with the configuration
|
||||||
func UpdateConfig(input ConfigInput) (*Config, error) {
|
func UpdateConfig(input ConfigInput) (*Config, error) {
|
||||||
if !configFileIsExists(input.ConfigPath) {
|
if !fileExists(input.ConfigPath) {
|
||||||
return nil, status.Errorf(codes.NotFound, "config file doesn't exist")
|
return nil, status.Errorf(codes.NotFound, "config file doesn't exist")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -158,7 +159,7 @@ func UpdateConfig(input ConfigInput) (*Config, error) {
|
|||||||
|
|
||||||
// UpdateOrCreateConfig reads existing config or generates a new one
|
// UpdateOrCreateConfig reads existing config or generates a new one
|
||||||
func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
||||||
if !configFileIsExists(input.ConfigPath) {
|
if !fileExists(input.ConfigPath) {
|
||||||
log.Infof("generating new config %s", input.ConfigPath)
|
log.Infof("generating new config %s", input.ConfigPath)
|
||||||
cfg, err := createNewConfig(input)
|
cfg, err := createNewConfig(input)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -472,11 +473,19 @@ func isPreSharedKeyHidden(preSharedKey *string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func configFileIsExists(path string) bool {
|
func fileExists(path string) bool {
|
||||||
_, err := os.Stat(path)
|
_, err := os.Stat(path)
|
||||||
return !os.IsNotExist(err)
|
return !os.IsNotExist(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func createFile(path string) error {
|
||||||
|
file, err := os.Create(path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return file.Close()
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateOldManagementURL checks whether client can switch to the new Management URL with port 443 and the management domain.
|
// UpdateOldManagementURL checks whether client can switch to the new Management URL with port 443 and the management domain.
|
||||||
// If it can switch, then it updates the config and returns a new one. Otherwise, it returns the provided config.
|
// If it can switch, then it updates the config and returns a new one. Otherwise, it returns the provided config.
|
||||||
// The check is performed only for the NetBird's managed version.
|
// The check is performed only for the NetBird's managed version.
|
||||||
|
@ -91,6 +91,7 @@ func (c *ConnectClient) RunOniOS(
|
|||||||
fileDescriptor int32,
|
fileDescriptor int32,
|
||||||
networkChangeListener listener.NetworkChangeListener,
|
networkChangeListener listener.NetworkChangeListener,
|
||||||
dnsManager dns.IosDnsManager,
|
dnsManager dns.IosDnsManager,
|
||||||
|
stateFilePath string,
|
||||||
) error {
|
) error {
|
||||||
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
|
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
|
||||||
debug.SetGCPercent(5)
|
debug.SetGCPercent(5)
|
||||||
@ -99,6 +100,7 @@ func (c *ConnectClient) RunOniOS(
|
|||||||
FileDescriptor: fileDescriptor,
|
FileDescriptor: fileDescriptor,
|
||||||
NetworkChangeListener: networkChangeListener,
|
NetworkChangeListener: networkChangeListener,
|
||||||
DnsManager: dnsManager,
|
DnsManager: dnsManager,
|
||||||
|
StateFilePath: stateFilePath,
|
||||||
}
|
}
|
||||||
return c.run(mobileDependency, nil, nil)
|
return c.run(mobileDependency, nil, nil)
|
||||||
}
|
}
|
||||||
|
@ -39,6 +39,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||||
|
|
||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
@ -62,6 +63,7 @@ import (
|
|||||||
const (
|
const (
|
||||||
PeerConnectionTimeoutMax = 45000 // ms
|
PeerConnectionTimeoutMax = 45000 // ms
|
||||||
PeerConnectionTimeoutMin = 30000 // ms
|
PeerConnectionTimeoutMin = 30000 // ms
|
||||||
|
connInitLimit = 200
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrResetConnection = fmt.Errorf("reset connection")
|
var ErrResetConnection = fmt.Errorf("reset connection")
|
||||||
@ -177,6 +179,7 @@ type Engine struct {
|
|||||||
// Network map persistence
|
// Network map persistence
|
||||||
persistNetworkMap bool
|
persistNetworkMap bool
|
||||||
latestNetworkMap *mgmProto.NetworkMap
|
latestNetworkMap *mgmProto.NetworkMap
|
||||||
|
connSemaphore *semaphoregroup.SemaphoreGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
// Peer is an instance of the Connection Peer
|
// Peer is an instance of the Connection Peer
|
||||||
@ -242,6 +245,18 @@ func NewEngineWithProbes(
|
|||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
probes: probes,
|
probes: probes,
|
||||||
checks: checks,
|
checks: checks,
|
||||||
|
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
|
||||||
|
}
|
||||||
|
if runtime.GOOS == "ios" {
|
||||||
|
if !fileExists(mobileDep.StateFilePath) {
|
||||||
|
err := createFile(mobileDep.StateFilePath)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to create state file: %v", err)
|
||||||
|
// we are not exiting as we can run without the state manager
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
engine.stateManager = statemanager.New(mobileDep.StateFilePath)
|
||||||
}
|
}
|
||||||
if path := statemanager.GetDefaultStatePath(); path != "" {
|
if path := statemanager.GetDefaultStatePath(); path != "" {
|
||||||
engine.stateManager = statemanager.New(path)
|
engine.stateManager = statemanager.New(path)
|
||||||
@ -1040,7 +1055,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager, e.srWatcher)
|
peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager, e.srWatcher, e.connSemaphore)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -19,4 +19,5 @@ type MobileDependency struct {
|
|||||||
// iOS only
|
// iOS only
|
||||||
DnsManager dns.IosDnsManager
|
DnsManager dns.IosDnsManager
|
||||||
FileDescriptor int32
|
FileDescriptor int32
|
||||||
|
StateFilePath string
|
||||||
}
|
}
|
||||||
|
@ -23,6 +23,7 @@ import (
|
|||||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
|
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ConnPriority int
|
type ConnPriority int
|
||||||
@ -104,12 +105,13 @@ type Conn struct {
|
|||||||
wgProxyICE wgproxy.Proxy
|
wgProxyICE wgproxy.Proxy
|
||||||
wgProxyRelay wgproxy.Proxy
|
wgProxyRelay wgproxy.Proxy
|
||||||
|
|
||||||
guard *guard.Guard
|
guard *guard.Guard
|
||||||
|
semaphore *semaphoregroup.SemaphoreGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewConn creates a new not opened Conn to the remote peer.
|
// NewConn creates a new not opened Conn to the remote peer.
|
||||||
// To establish a connection run Conn.Open
|
// To establish a connection run Conn.Open
|
||||||
func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher) (*Conn, error) {
|
func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher, semaphore *semaphoregroup.SemaphoreGroup) (*Conn, error) {
|
||||||
allowedIP, _, err := net.ParseCIDR(config.WgConfig.AllowedIps)
|
allowedIP, _, err := net.ParseCIDR(config.WgConfig.AllowedIps)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to parse allowedIPS: %v", err)
|
log.Errorf("failed to parse allowedIPS: %v", err)
|
||||||
@ -130,6 +132,7 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
|
|||||||
allowedIP: allowedIP,
|
allowedIP: allowedIP,
|
||||||
statusRelay: NewAtomicConnStatus(),
|
statusRelay: NewAtomicConnStatus(),
|
||||||
statusICE: NewAtomicConnStatus(),
|
statusICE: NewAtomicConnStatus(),
|
||||||
|
semaphore: semaphore,
|
||||||
}
|
}
|
||||||
|
|
||||||
rFns := WorkerRelayCallbacks{
|
rFns := WorkerRelayCallbacks{
|
||||||
@ -169,6 +172,7 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
|
|||||||
// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
|
// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
|
||||||
// be used.
|
// be used.
|
||||||
func (conn *Conn) Open() {
|
func (conn *Conn) Open() {
|
||||||
|
conn.semaphore.Add(conn.ctx)
|
||||||
conn.log.Debugf("open connection to peer")
|
conn.log.Debugf("open connection to peer")
|
||||||
|
|
||||||
conn.mu.Lock()
|
conn.mu.Lock()
|
||||||
@ -191,6 +195,7 @@ func (conn *Conn) Open() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) {
|
func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) {
|
||||||
|
defer conn.semaphore.Done(conn.ctx)
|
||||||
conn.waitInitialRandomSleepTime(ctx)
|
conn.waitInitialRandomSleepTime(ctx)
|
||||||
|
|
||||||
err := conn.handshaker.sendOffer()
|
err := conn.handshaker.sendOffer()
|
||||||
|
@ -14,6 +14,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/peer/ice"
|
"github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
|
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||||
)
|
)
|
||||||
|
|
||||||
var connConf = ConnConfig{
|
var connConf = ConnConfig{
|
||||||
@ -46,7 +47,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {
|
|||||||
|
|
||||||
func TestConn_GetKey(t *testing.T) {
|
func TestConn_GetKey(t *testing.T) {
|
||||||
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
|
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
|
||||||
conn, err := NewConn(context.Background(), connConf, nil, nil, nil, nil, swWatcher)
|
conn, err := NewConn(context.Background(), connConf, nil, nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -58,7 +59,7 @@ func TestConn_GetKey(t *testing.T) {
|
|||||||
|
|
||||||
func TestConn_OnRemoteOffer(t *testing.T) {
|
func TestConn_OnRemoteOffer(t *testing.T) {
|
||||||
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
|
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
|
||||||
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher)
|
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -92,7 +93,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
|
|||||||
|
|
||||||
func TestConn_OnRemoteAnswer(t *testing.T) {
|
func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||||
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
|
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
|
||||||
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher)
|
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -125,7 +126,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
func TestConn_Status(t *testing.T) {
|
func TestConn_Status(t *testing.T) {
|
||||||
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
|
swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig)
|
||||||
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher)
|
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -273,3 +273,88 @@ func TestRouteSelector_FilterSelected(t *testing.T) {
|
|||||||
"route2|192.168.0.0/16": {},
|
"route2|192.168.0.0/16": {},
|
||||||
}, filtered)
|
}, filtered)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRouteSelector_NewRoutesBehavior(t *testing.T) {
|
||||||
|
initialRoutes := []route.NetID{"route1", "route2", "route3"}
|
||||||
|
newRoutes := []route.NetID{"route1", "route2", "route3", "route4", "route5"}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
initialState func(rs *routeselector.RouteSelector) error // Setup initial state
|
||||||
|
wantNewSelected []route.NetID // Expected selected routes after new routes appear
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "New routes with initial selectAll state",
|
||||||
|
initialState: func(rs *routeselector.RouteSelector) error {
|
||||||
|
rs.SelectAllRoutes()
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
// When selectAll is true, all routes including new ones should be selected
|
||||||
|
wantNewSelected: []route.NetID{"route1", "route2", "route3", "route4", "route5"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "New routes after specific selection",
|
||||||
|
initialState: func(rs *routeselector.RouteSelector) error {
|
||||||
|
return rs.SelectRoutes([]route.NetID{"route1", "route2"}, false, initialRoutes)
|
||||||
|
},
|
||||||
|
// When specific routes were selected, new routes should remain unselected
|
||||||
|
wantNewSelected: []route.NetID{"route1", "route2"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "New routes after deselect all",
|
||||||
|
initialState: func(rs *routeselector.RouteSelector) error {
|
||||||
|
rs.DeselectAllRoutes()
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
// After deselect all, new routes should remain unselected
|
||||||
|
wantNewSelected: []route.NetID{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "New routes after deselecting specific routes",
|
||||||
|
initialState: func(rs *routeselector.RouteSelector) error {
|
||||||
|
rs.SelectAllRoutes()
|
||||||
|
return rs.DeselectRoutes([]route.NetID{"route1"}, initialRoutes)
|
||||||
|
},
|
||||||
|
// After deselecting specific routes, new routes should remain unselected
|
||||||
|
wantNewSelected: []route.NetID{"route2", "route3"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "New routes after selecting with append",
|
||||||
|
initialState: func(rs *routeselector.RouteSelector) error {
|
||||||
|
return rs.SelectRoutes([]route.NetID{"route1"}, true, initialRoutes)
|
||||||
|
},
|
||||||
|
// When routes were appended, new routes should remain unselected
|
||||||
|
wantNewSelected: []route.NetID{"route1"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
rs := routeselector.NewRouteSelector()
|
||||||
|
|
||||||
|
// Setup initial state
|
||||||
|
err := tt.initialState(rs)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify selection state with new routes
|
||||||
|
for _, id := range newRoutes {
|
||||||
|
assert.Equal(t, rs.IsSelected(id), slices.Contains(tt.wantNewSelected, id),
|
||||||
|
"Route %s selection state incorrect", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Additional verification using FilterSelected
|
||||||
|
routes := route.HAMap{
|
||||||
|
"route1|10.0.0.0/8": {},
|
||||||
|
"route2|192.168.0.0/16": {},
|
||||||
|
"route3|172.16.0.0/12": {},
|
||||||
|
"route4|10.10.0.0/16": {},
|
||||||
|
"route5|192.168.1.0/24": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
filtered := rs.FilterSelected(routes)
|
||||||
|
expectedLen := len(tt.wantNewSelected)
|
||||||
|
assert.Equal(t, expectedLen, len(filtered),
|
||||||
|
"FilterSelected returned wrong number of routes, got %d want %d", len(filtered), expectedLen)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -59,6 +59,7 @@ func init() {
|
|||||||
// Client struct manage the life circle of background service
|
// Client struct manage the life circle of background service
|
||||||
type Client struct {
|
type Client struct {
|
||||||
cfgFile string
|
cfgFile string
|
||||||
|
stateFile string
|
||||||
recorder *peer.Status
|
recorder *peer.Status
|
||||||
ctxCancel context.CancelFunc
|
ctxCancel context.CancelFunc
|
||||||
ctxCancelLock *sync.Mutex
|
ctxCancelLock *sync.Mutex
|
||||||
@ -73,9 +74,10 @@ type Client struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewClient instantiate a new Client
|
// NewClient instantiate a new Client
|
||||||
func NewClient(cfgFile, deviceName string, osVersion string, osName string, networkChangeListener NetworkChangeListener, dnsManager DnsManager) *Client {
|
func NewClient(cfgFile, stateFile, deviceName string, osVersion string, osName string, networkChangeListener NetworkChangeListener, dnsManager DnsManager) *Client {
|
||||||
return &Client{
|
return &Client{
|
||||||
cfgFile: cfgFile,
|
cfgFile: cfgFile,
|
||||||
|
stateFile: stateFile,
|
||||||
deviceName: deviceName,
|
deviceName: deviceName,
|
||||||
osName: osName,
|
osName: osName,
|
||||||
osVersion: osVersion,
|
osVersion: osVersion,
|
||||||
@ -91,7 +93,8 @@ func (c *Client) Run(fd int32, interfaceName string) error {
|
|||||||
log.Infof("Starting NetBird client")
|
log.Infof("Starting NetBird client")
|
||||||
log.Debugf("Tunnel uses interface: %s", interfaceName)
|
log.Debugf("Tunnel uses interface: %s", interfaceName)
|
||||||
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
|
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
|
||||||
ConfigPath: c.cfgFile,
|
ConfigPath: c.cfgFile,
|
||||||
|
StateFilePath: c.stateFile,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -124,7 +127,7 @@ func (c *Client) Run(fd int32, interfaceName string) error {
|
|||||||
cfg.WgIface = interfaceName
|
cfg.WgIface = interfaceName
|
||||||
|
|
||||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||||
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager)
|
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop the internal client and free the resources
|
// Stop the internal client and free the resources
|
||||||
|
@ -10,9 +10,10 @@ type Preferences struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewPreferences create new Preferences instance
|
// NewPreferences create new Preferences instance
|
||||||
func NewPreferences(configPath string) *Preferences {
|
func NewPreferences(configPath string, stateFilePath string) *Preferences {
|
||||||
ci := internal.ConfigInput{
|
ci := internal.ConfigInput{
|
||||||
ConfigPath: configPath,
|
ConfigPath: configPath,
|
||||||
|
StateFilePath: stateFilePath,
|
||||||
}
|
}
|
||||||
return &Preferences{ci}
|
return &Preferences{ci}
|
||||||
}
|
}
|
||||||
|
@ -9,7 +9,8 @@ import (
|
|||||||
|
|
||||||
func TestPreferences_DefaultValues(t *testing.T) {
|
func TestPreferences_DefaultValues(t *testing.T) {
|
||||||
cfgFile := filepath.Join(t.TempDir(), "netbird.json")
|
cfgFile := filepath.Join(t.TempDir(), "netbird.json")
|
||||||
p := NewPreferences(cfgFile)
|
stateFile := filepath.Join(t.TempDir(), "state.json")
|
||||||
|
p := NewPreferences(cfgFile, stateFile)
|
||||||
defaultVar, err := p.GetAdminURL()
|
defaultVar, err := p.GetAdminURL()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to read default value: %s", err)
|
t.Fatalf("failed to read default value: %s", err)
|
||||||
@ -42,7 +43,8 @@ func TestPreferences_DefaultValues(t *testing.T) {
|
|||||||
func TestPreferences_ReadUncommitedValues(t *testing.T) {
|
func TestPreferences_ReadUncommitedValues(t *testing.T) {
|
||||||
exampleString := "exampleString"
|
exampleString := "exampleString"
|
||||||
cfgFile := filepath.Join(t.TempDir(), "netbird.json")
|
cfgFile := filepath.Join(t.TempDir(), "netbird.json")
|
||||||
p := NewPreferences(cfgFile)
|
stateFile := filepath.Join(t.TempDir(), "state.json")
|
||||||
|
p := NewPreferences(cfgFile, stateFile)
|
||||||
|
|
||||||
p.SetAdminURL(exampleString)
|
p.SetAdminURL(exampleString)
|
||||||
resp, err := p.GetAdminURL()
|
resp, err := p.GetAdminURL()
|
||||||
@ -79,7 +81,8 @@ func TestPreferences_Commit(t *testing.T) {
|
|||||||
exampleURL := "https://myurl.com:443"
|
exampleURL := "https://myurl.com:443"
|
||||||
examplePresharedKey := "topsecret"
|
examplePresharedKey := "topsecret"
|
||||||
cfgFile := filepath.Join(t.TempDir(), "netbird.json")
|
cfgFile := filepath.Join(t.TempDir(), "netbird.json")
|
||||||
p := NewPreferences(cfgFile)
|
stateFile := filepath.Join(t.TempDir(), "state.json")
|
||||||
|
p := NewPreferences(cfgFile, stateFile)
|
||||||
|
|
||||||
p.SetAdminURL(exampleURL)
|
p.SetAdminURL(exampleURL)
|
||||||
p.SetManagementURL(exampleURL)
|
p.SetManagementURL(exampleURL)
|
||||||
@ -90,7 +93,7 @@ func TestPreferences_Commit(t *testing.T) {
|
|||||||
t.Fatalf("failed to save changes: %s", err)
|
t.Fatalf("failed to save changes: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
p = NewPreferences(cfgFile)
|
p = NewPreferences(cfgFile, stateFile)
|
||||||
resp, err := p.GetAdminURL()
|
resp, err := p.GetAdminURL()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to read admin url: %s", err)
|
t.Fatalf("failed to read admin url: %s", err)
|
||||||
|
@ -42,6 +42,7 @@ import (
|
|||||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
httpapi "github.com/netbirdio/netbird/management/server/http"
|
httpapi "github.com/netbirdio/netbird/management/server/http"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||||
"github.com/netbirdio/netbird/management/server/idp"
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
"github.com/netbirdio/netbird/management/server/metrics"
|
"github.com/netbirdio/netbird/management/server/metrics"
|
||||||
@ -257,7 +258,7 @@ var (
|
|||||||
return fmt.Errorf("failed creating JWT validator: %v", err)
|
return fmt.Errorf("failed creating JWT validator: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
httpAPIAuthCfg := httpapi.AuthCfg{
|
httpAPIAuthCfg := configs.AuthCfg{
|
||||||
Issuer: config.HttpConfig.AuthIssuer,
|
Issuer: config.HttpConfig.AuthIssuer,
|
||||||
Audience: config.HttpConfig.AuthAudience,
|
Audience: config.HttpConfig.AuthAudience,
|
||||||
UserIDClaim: config.HttpConfig.AuthUserIDClaim,
|
UserIDClaim: config.HttpConfig.AuthUserIDClaim,
|
||||||
|
9
management/server/http/configs/auth.go
Normal file
9
management/server/http/configs/auth.go
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
package configs
|
||||||
|
|
||||||
|
// AuthCfg contains parameters for authentication middleware
|
||||||
|
type AuthCfg struct {
|
||||||
|
Issuer string
|
||||||
|
Audience string
|
||||||
|
UserIDClaim string
|
||||||
|
KeysLocation string
|
||||||
|
}
|
@ -12,6 +12,16 @@ import (
|
|||||||
|
|
||||||
s "github.com/netbirdio/netbird/management/server"
|
s "github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/handlers/accounts"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/handlers/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/handlers/events"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/handlers/groups"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/handlers/peers"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/handlers/policies"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/handlers/routes"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/handlers/setup_keys"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/handlers/users"
|
||||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||||
"github.com/netbirdio/netbird/management/server/integrated_validator"
|
"github.com/netbirdio/netbird/management/server/integrated_validator"
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
@ -20,27 +30,15 @@ import (
|
|||||||
|
|
||||||
const apiPrefix = "/api"
|
const apiPrefix = "/api"
|
||||||
|
|
||||||
// AuthCfg contains parameters for authentication middleware
|
|
||||||
type AuthCfg struct {
|
|
||||||
Issuer string
|
|
||||||
Audience string
|
|
||||||
UserIDClaim string
|
|
||||||
KeysLocation string
|
|
||||||
}
|
|
||||||
|
|
||||||
type apiHandler struct {
|
type apiHandler struct {
|
||||||
Router *mux.Router
|
Router *mux.Router
|
||||||
AccountManager s.AccountManager
|
AccountManager s.AccountManager
|
||||||
geolocationManager *geolocation.Geolocation
|
geolocationManager *geolocation.Geolocation
|
||||||
AuthCfg AuthCfg
|
AuthCfg configs.AuthCfg
|
||||||
}
|
|
||||||
|
|
||||||
// EmptyObject is an empty struct used to return empty JSON object
|
|
||||||
type emptyObject struct {
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// APIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
// APIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||||
func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) {
|
func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg configs.AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) {
|
||||||
claimsExtractor := jwtclaims.NewClaimsExtractor(
|
claimsExtractor := jwtclaims.NewClaimsExtractor(
|
||||||
jwtclaims.WithAudience(authCfg.Audience),
|
jwtclaims.WithAudience(authCfg.Audience),
|
||||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||||
@ -86,122 +84,15 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
|
|||||||
return nil, fmt.Errorf("register integrations endpoints: %w", err)
|
return nil, fmt.Errorf("register integrations endpoints: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
api.addAccountsEndpoint()
|
accounts.AddEndpoints(api.AccountManager, authCfg, router)
|
||||||
api.addPeersEndpoint()
|
peers.AddEndpoints(api.AccountManager, authCfg, router)
|
||||||
api.addUsersEndpoint()
|
users.AddEndpoints(api.AccountManager, authCfg, router)
|
||||||
api.addUsersTokensEndpoint()
|
setup_keys.AddEndpoints(api.AccountManager, authCfg, router)
|
||||||
api.addSetupKeysEndpoint()
|
policies.AddEndpoints(api.AccountManager, api.geolocationManager, authCfg, router)
|
||||||
api.addPoliciesEndpoint()
|
groups.AddEndpoints(api.AccountManager, authCfg, router)
|
||||||
api.addGroupsEndpoint()
|
routes.AddEndpoints(api.AccountManager, authCfg, router)
|
||||||
api.addRoutesEndpoint()
|
dns.AddEndpoints(api.AccountManager, authCfg, router)
|
||||||
api.addDNSNameserversEndpoint()
|
events.AddEndpoints(api.AccountManager, authCfg, router)
|
||||||
api.addDNSSettingEndpoint()
|
|
||||||
api.addEventsEndpoint()
|
|
||||||
api.addPostureCheckEndpoint()
|
|
||||||
api.addLocationsEndpoint()
|
|
||||||
|
|
||||||
return rootRouter, nil
|
return rootRouter, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (apiHandler *apiHandler) addAccountsEndpoint() {
|
|
||||||
accountsHandler := NewAccountsHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
|
|
||||||
apiHandler.Router.HandleFunc("/accounts/{accountId}", accountsHandler.UpdateAccount).Methods("PUT", "OPTIONS")
|
|
||||||
apiHandler.Router.HandleFunc("/accounts/{accountId}", accountsHandler.DeleteAccount).Methods("DELETE", "OPTIONS")
|
|
||||||
apiHandler.Router.HandleFunc("/accounts", accountsHandler.GetAllAccounts).Methods("GET", "OPTIONS")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (apiHandler *apiHandler) addPeersEndpoint() {
|
|
||||||
peersHandler := NewPeersHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
|
|
||||||
apiHandler.Router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS")
|
|
||||||
apiHandler.Router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
|
|
||||||
Methods("GET", "PUT", "DELETE", "OPTIONS")
|
|
||||||
apiHandler.Router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (apiHandler *apiHandler) addUsersEndpoint() {
|
|
||||||
userHandler := NewUsersHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
|
|
||||||
apiHandler.Router.HandleFunc("/users", userHandler.GetAllUsers).Methods("GET", "OPTIONS")
|
|
||||||
apiHandler.Router.HandleFunc("/users/{userId}", userHandler.UpdateUser).Methods("PUT", "OPTIONS")
|
|
||||||
apiHandler.Router.HandleFunc("/users/{userId}", userHandler.DeleteUser).Methods("DELETE", "OPTIONS")
|
|
||||||
apiHandler.Router.HandleFunc("/users", userHandler.CreateUser).Methods("POST", "OPTIONS")
|
|
||||||
apiHandler.Router.HandleFunc("/users/{userId}/invite", userHandler.InviteUser).Methods("POST", "OPTIONS")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (apiHandler *apiHandler) addUsersTokensEndpoint() {
|
|
||||||
tokenHandler := NewPATsHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
|
|
||||||
apiHandler.Router.HandleFunc("/users/{userId}/tokens", tokenHandler.GetAllTokens).Methods("GET", "OPTIONS")
|
|
||||||
apiHandler.Router.HandleFunc("/users/{userId}/tokens", tokenHandler.CreateToken).Methods("POST", "OPTIONS")
|
|
||||||
apiHandler.Router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.GetToken).Methods("GET", "OPTIONS")
|
|
||||||
apiHandler.Router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.DeleteToken).Methods("DELETE", "OPTIONS")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (apiHandler *apiHandler) addSetupKeysEndpoint() {
|
|
||||||
keysHandler := NewSetupKeysHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
|
|
||||||
apiHandler.Router.HandleFunc("/setup-keys", keysHandler.GetAllSetupKeys).Methods("GET", "OPTIONS")
|
|
||||||
apiHandler.Router.HandleFunc("/setup-keys", keysHandler.CreateSetupKey).Methods("POST", "OPTIONS")
|
|
||||||
apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.GetSetupKey).Methods("GET", "OPTIONS")
|
|
||||||
apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.UpdateSetupKey).Methods("PUT", "OPTIONS")
|
|
||||||
apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.DeleteSetupKey).Methods("DELETE", "OPTIONS")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (apiHandler *apiHandler) addPoliciesEndpoint() {
|
|
||||||
policiesHandler := NewPoliciesHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
|
|
||||||
apiHandler.Router.HandleFunc("/policies", policiesHandler.GetAllPolicies).Methods("GET", "OPTIONS")
|
|
||||||
apiHandler.Router.HandleFunc("/policies", policiesHandler.CreatePolicy).Methods("POST", "OPTIONS")
|
|
||||||
apiHandler.Router.HandleFunc("/policies/{policyId}", policiesHandler.UpdatePolicy).Methods("PUT", "OPTIONS")
|
|
||||||
apiHandler.Router.HandleFunc("/policies/{policyId}", policiesHandler.GetPolicy).Methods("GET", "OPTIONS")
|
|
||||||
apiHandler.Router.HandleFunc("/policies/{policyId}", policiesHandler.DeletePolicy).Methods("DELETE", "OPTIONS")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (apiHandler *apiHandler) addGroupsEndpoint() {
|
|
||||||
groupsHandler := NewGroupsHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
|
|
||||||
apiHandler.Router.HandleFunc("/groups", groupsHandler.GetAllGroups).Methods("GET", "OPTIONS")
|
|
||||||
apiHandler.Router.HandleFunc("/groups", groupsHandler.CreateGroup).Methods("POST", "OPTIONS")
|
|
||||||
apiHandler.Router.HandleFunc("/groups/{groupId}", groupsHandler.UpdateGroup).Methods("PUT", "OPTIONS")
|
|
||||||
apiHandler.Router.HandleFunc("/groups/{groupId}", groupsHandler.GetGroup).Methods("GET", "OPTIONS")
|
|
||||||
apiHandler.Router.HandleFunc("/groups/{groupId}", groupsHandler.DeleteGroup).Methods("DELETE", "OPTIONS")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (apiHandler *apiHandler) addRoutesEndpoint() {
|
|
||||||
routesHandler := NewRoutesHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
|
|
||||||
apiHandler.Router.HandleFunc("/routes", routesHandler.GetAllRoutes).Methods("GET", "OPTIONS")
|
|
||||||
apiHandler.Router.HandleFunc("/routes", routesHandler.CreateRoute).Methods("POST", "OPTIONS")
|
|
||||||
apiHandler.Router.HandleFunc("/routes/{routeId}", routesHandler.UpdateRoute).Methods("PUT", "OPTIONS")
|
|
||||||
apiHandler.Router.HandleFunc("/routes/{routeId}", routesHandler.GetRoute).Methods("GET", "OPTIONS")
|
|
||||||
apiHandler.Router.HandleFunc("/routes/{routeId}", routesHandler.DeleteRoute).Methods("DELETE", "OPTIONS")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (apiHandler *apiHandler) addDNSNameserversEndpoint() {
|
|
||||||
nameserversHandler := NewNameserversHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
|
|
||||||
apiHandler.Router.HandleFunc("/dns/nameservers", nameserversHandler.GetAllNameservers).Methods("GET", "OPTIONS")
|
|
||||||
apiHandler.Router.HandleFunc("/dns/nameservers", nameserversHandler.CreateNameserverGroup).Methods("POST", "OPTIONS")
|
|
||||||
apiHandler.Router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.UpdateNameserverGroup).Methods("PUT", "OPTIONS")
|
|
||||||
apiHandler.Router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.GetNameserverGroup).Methods("GET", "OPTIONS")
|
|
||||||
apiHandler.Router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.DeleteNameserverGroup).Methods("DELETE", "OPTIONS")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (apiHandler *apiHandler) addDNSSettingEndpoint() {
|
|
||||||
dnsSettingsHandler := NewDNSSettingsHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
|
|
||||||
apiHandler.Router.HandleFunc("/dns/settings", dnsSettingsHandler.GetDNSSettings).Methods("GET", "OPTIONS")
|
|
||||||
apiHandler.Router.HandleFunc("/dns/settings", dnsSettingsHandler.UpdateDNSSettings).Methods("PUT", "OPTIONS")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (apiHandler *apiHandler) addEventsEndpoint() {
|
|
||||||
eventsHandler := NewEventsHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
|
|
||||||
apiHandler.Router.HandleFunc("/events", eventsHandler.GetAllEvents).Methods("GET", "OPTIONS")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (apiHandler *apiHandler) addPostureCheckEndpoint() {
|
|
||||||
postureCheckHandler := NewPostureChecksHandler(apiHandler.AccountManager, apiHandler.geolocationManager, apiHandler.AuthCfg)
|
|
||||||
apiHandler.Router.HandleFunc("/posture-checks", postureCheckHandler.GetAllPostureChecks).Methods("GET", "OPTIONS")
|
|
||||||
apiHandler.Router.HandleFunc("/posture-checks", postureCheckHandler.CreatePostureCheck).Methods("POST", "OPTIONS")
|
|
||||||
apiHandler.Router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.UpdatePostureCheck).Methods("PUT", "OPTIONS")
|
|
||||||
apiHandler.Router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.GetPostureCheck).Methods("GET", "OPTIONS")
|
|
||||||
apiHandler.Router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.DeletePostureCheck).Methods("DELETE", "OPTIONS")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (apiHandler *apiHandler) addLocationsEndpoint() {
|
|
||||||
locationHandler := NewGeolocationsHandlerHandler(apiHandler.AccountManager, apiHandler.geolocationManager, apiHandler.AuthCfg)
|
|
||||||
apiHandler.Router.HandleFunc("/locations/countries", locationHandler.GetAllCountries).Methods("GET", "OPTIONS")
|
|
||||||
apiHandler.Router.HandleFunc("/locations/countries/{country}/cities", locationHandler.GetCitiesByCountry).Methods("GET", "OPTIONS")
|
|
||||||
}
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
package http
|
package accounts
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@ -10,20 +10,28 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/account"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||||
"github.com/netbirdio/netbird/management/server/http/util"
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AccountsHandler is a handler that handles the server.Account HTTP endpoints
|
// handler is a handler that handles the server.Account HTTP endpoints
|
||||||
type AccountsHandler struct {
|
type handler struct {
|
||||||
accountManager server.AccountManager
|
accountManager server.AccountManager
|
||||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAccountsHandler creates a new AccountsHandler HTTP handler
|
func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
|
||||||
func NewAccountsHandler(accountManager server.AccountManager, authCfg AuthCfg) *AccountsHandler {
|
accountsHandler := newHandler(accountManager, authCfg)
|
||||||
return &AccountsHandler{
|
router.HandleFunc("/accounts/{accountId}", accountsHandler.updateAccount).Methods("PUT", "OPTIONS")
|
||||||
|
router.HandleFunc("/accounts/{accountId}", accountsHandler.deleteAccount).Methods("DELETE", "OPTIONS")
|
||||||
|
router.HandleFunc("/accounts", accountsHandler.getAllAccounts).Methods("GET", "OPTIONS")
|
||||||
|
}
|
||||||
|
|
||||||
|
// newHandler creates a new handler HTTP handler
|
||||||
|
func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler {
|
||||||
|
return &handler{
|
||||||
accountManager: accountManager,
|
accountManager: accountManager,
|
||||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||||
jwtclaims.WithAudience(authCfg.Audience),
|
jwtclaims.WithAudience(authCfg.Audience),
|
||||||
@ -32,8 +40,8 @@ func NewAccountsHandler(accountManager server.AccountManager, authCfg AuthCfg) *
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account.
|
// getAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account.
|
||||||
func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) {
|
func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -51,8 +59,8 @@ func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request)
|
|||||||
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
|
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings)
|
// updateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings)
|
||||||
func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) {
|
func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
_, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
_, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -111,8 +119,8 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
|
|||||||
util.WriteJSONObject(r.Context(), w, &resp)
|
util.WriteJSONObject(r.Context(), w, &resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteAccount is a HTTP DELETE handler to delete an account
|
// deleteAccount is a HTTP DELETE handler to delete an account
|
||||||
func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
|
func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
vars := mux.Vars(r)
|
vars := mux.Vars(r)
|
||||||
targetAccountID := vars["accountId"]
|
targetAccountID := vars["accountId"]
|
||||||
@ -127,7 +135,7 @@ func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request)
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, emptyObject{})
|
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||||
}
|
}
|
||||||
|
|
||||||
func toAccountResponse(accountID string, settings *server.Settings) *api.Account {
|
func toAccountResponse(accountID string, settings *server.Settings) *api.Account {
|
@ -1,4 +1,4 @@
|
|||||||
package http
|
package accounts
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
@ -20,8 +20,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
func initAccountsTestData(account *server.Account, admin *server.User) *AccountsHandler {
|
func initAccountsTestData(account *server.Account, admin *server.User) *handler {
|
||||||
return &AccountsHandler{
|
return &handler{
|
||||||
accountManager: &mock_server.MockAccountManager{
|
accountManager: &mock_server.MockAccountManager{
|
||||||
GetAccountIDFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
GetAccountIDFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||||
return account.Id, admin.Id, nil
|
return account.Id, admin.Id, nil
|
||||||
@ -89,7 +89,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
|||||||
requestBody io.Reader
|
requestBody io.Reader
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "GetAllAccounts OK",
|
name: "getAllAccounts OK",
|
||||||
expectedBody: true,
|
expectedBody: true,
|
||||||
requestType: http.MethodGet,
|
requestType: http.MethodGet,
|
||||||
requestPath: "/api/accounts",
|
requestPath: "/api/accounts",
|
||||||
@ -189,8 +189,8 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
|||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||||
|
|
||||||
router := mux.NewRouter()
|
router := mux.NewRouter()
|
||||||
router.HandleFunc("/api/accounts", handler.GetAllAccounts).Methods("GET")
|
router.HandleFunc("/api/accounts", handler.getAllAccounts).Methods("GET")
|
||||||
router.HandleFunc("/api/accounts/{accountId}", handler.UpdateAccount).Methods("PUT")
|
router.HandleFunc("/api/accounts/{accountId}", handler.updateAccount).Methods("PUT")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
res := recorder.Result()
|
res := recorder.Result()
|
@ -1,26 +1,39 @@
|
|||||||
package http
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||||
"github.com/netbirdio/netbird/management/server/http/util"
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DNSSettingsHandler is a handler that returns the DNS settings of the account
|
// dnsSettingsHandler is a handler that returns the DNS settings of the account
|
||||||
type DNSSettingsHandler struct {
|
type dnsSettingsHandler struct {
|
||||||
accountManager server.AccountManager
|
accountManager server.AccountManager
|
||||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDNSSettingsHandler returns a new instance of DNSSettingsHandler handler
|
func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
|
||||||
func NewDNSSettingsHandler(accountManager server.AccountManager, authCfg AuthCfg) *DNSSettingsHandler {
|
addDNSSettingEndpoint(accountManager, authCfg, router)
|
||||||
return &DNSSettingsHandler{
|
addDNSNameserversEndpoint(accountManager, authCfg, router)
|
||||||
|
}
|
||||||
|
|
||||||
|
func addDNSSettingEndpoint(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
|
||||||
|
dnsSettingsHandler := newDNSSettingsHandler(accountManager, authCfg)
|
||||||
|
router.HandleFunc("/dns/settings", dnsSettingsHandler.getDNSSettings).Methods("GET", "OPTIONS")
|
||||||
|
router.HandleFunc("/dns/settings", dnsSettingsHandler.updateDNSSettings).Methods("PUT", "OPTIONS")
|
||||||
|
}
|
||||||
|
|
||||||
|
// newDNSSettingsHandler returns a new instance of dnsSettingsHandler handler
|
||||||
|
func newDNSSettingsHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *dnsSettingsHandler {
|
||||||
|
return &dnsSettingsHandler{
|
||||||
accountManager: accountManager,
|
accountManager: accountManager,
|
||||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||||
jwtclaims.WithAudience(authCfg.Audience),
|
jwtclaims.WithAudience(authCfg.Audience),
|
||||||
@ -29,8 +42,8 @@ func NewDNSSettingsHandler(accountManager server.AccountManager, authCfg AuthCfg
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDNSSettings returns the DNS settings for the account
|
// getDNSSettings returns the DNS settings for the account
|
||||||
func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Request) {
|
func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -52,8 +65,8 @@ func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Reque
|
|||||||
util.WriteJSONObject(r.Context(), w, apiDNSSettings)
|
util.WriteJSONObject(r.Context(), w, apiDNSSettings)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateDNSSettings handles update to DNS settings of an account
|
// updateDNSSettings handles update to DNS settings of an account
|
||||||
func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) {
|
func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
@ -1,4 +1,4 @@
|
|||||||
package http
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
@ -40,8 +40,8 @@ var testingDNSSettingsAccount = &server.Account{
|
|||||||
DNSSettings: baseExistingDNSSettings,
|
DNSSettings: baseExistingDNSSettings,
|
||||||
}
|
}
|
||||||
|
|
||||||
func initDNSSettingsTestData() *DNSSettingsHandler {
|
func initDNSSettingsTestData() *dnsSettingsHandler {
|
||||||
return &DNSSettingsHandler{
|
return &dnsSettingsHandler{
|
||||||
accountManager: &mock_server.MockAccountManager{
|
accountManager: &mock_server.MockAccountManager{
|
||||||
GetDNSSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.DNSSettings, error) {
|
GetDNSSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.DNSSettings, error) {
|
||||||
return &testingDNSSettingsAccount.DNSSettings, nil
|
return &testingDNSSettingsAccount.DNSSettings, nil
|
||||||
@ -120,8 +120,8 @@ func TestDNSSettingsHandlers(t *testing.T) {
|
|||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||||
|
|
||||||
router := mux.NewRouter()
|
router := mux.NewRouter()
|
||||||
router.HandleFunc("/api/dns/settings", p.GetDNSSettings).Methods("GET")
|
router.HandleFunc("/api/dns/settings", p.getDNSSettings).Methods("GET")
|
||||||
router.HandleFunc("/api/dns/settings", p.UpdateDNSSettings).Methods("PUT")
|
router.HandleFunc("/api/dns/settings", p.updateDNSSettings).Methods("PUT")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
res := recorder.Result()
|
res := recorder.Result()
|
@ -1,4 +1,4 @@
|
|||||||
package http
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@ -11,20 +11,30 @@ import (
|
|||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||||
"github.com/netbirdio/netbird/management/server/http/util"
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NameserversHandler is the nameserver group handler of the account
|
// nameserversHandler is the nameserver group handler of the account
|
||||||
type NameserversHandler struct {
|
type nameserversHandler struct {
|
||||||
accountManager server.AccountManager
|
accountManager server.AccountManager
|
||||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewNameserversHandler returns a new instance of NameserversHandler handler
|
func addDNSNameserversEndpoint(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
|
||||||
func NewNameserversHandler(accountManager server.AccountManager, authCfg AuthCfg) *NameserversHandler {
|
nameserversHandler := newNameserversHandler(accountManager, authCfg)
|
||||||
return &NameserversHandler{
|
router.HandleFunc("/dns/nameservers", nameserversHandler.getAllNameservers).Methods("GET", "OPTIONS")
|
||||||
|
router.HandleFunc("/dns/nameservers", nameserversHandler.createNameserverGroup).Methods("POST", "OPTIONS")
|
||||||
|
router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.updateNameserverGroup).Methods("PUT", "OPTIONS")
|
||||||
|
router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.getNameserverGroup).Methods("GET", "OPTIONS")
|
||||||
|
router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.deleteNameserverGroup).Methods("DELETE", "OPTIONS")
|
||||||
|
}
|
||||||
|
|
||||||
|
// newNameserversHandler returns a new instance of nameserversHandler handler
|
||||||
|
func newNameserversHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *nameserversHandler {
|
||||||
|
return &nameserversHandler{
|
||||||
accountManager: accountManager,
|
accountManager: accountManager,
|
||||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||||
jwtclaims.WithAudience(authCfg.Audience),
|
jwtclaims.WithAudience(authCfg.Audience),
|
||||||
@ -33,8 +43,8 @@ func NewNameserversHandler(accountManager server.AccountManager, authCfg AuthCfg
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllNameservers returns the list of nameserver groups for the account
|
// getAllNameservers returns the list of nameserver groups for the account
|
||||||
func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Request) {
|
func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -57,8 +67,8 @@ func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Re
|
|||||||
util.WriteJSONObject(r.Context(), w, apiNameservers)
|
util.WriteJSONObject(r.Context(), w, apiNameservers)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateNameserverGroup handles nameserver group creation request
|
// createNameserverGroup handles nameserver group creation request
|
||||||
func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -90,8 +100,8 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt
|
|||||||
util.WriteJSONObject(r.Context(), w, &resp)
|
util.WriteJSONObject(r.Context(), w, &resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateNameserverGroup handles update to a nameserver group identified by a given ID
|
// updateNameserverGroup handles update to a nameserver group identified by a given ID
|
||||||
func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -141,8 +151,8 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt
|
|||||||
util.WriteJSONObject(r.Context(), w, &resp)
|
util.WriteJSONObject(r.Context(), w, &resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteNameserverGroup handles nameserver group deletion request
|
// deleteNameserverGroup handles nameserver group deletion request
|
||||||
func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -162,11 +172,11 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, emptyObject{})
|
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetNameserverGroup handles a nameserver group Get request identified by ID
|
// getNameserverGroup handles a nameserver group Get request identified by ID
|
||||||
func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
func (h *nameserversHandler) getNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
@ -1,4 +1,4 @@
|
|||||||
package http
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
@ -50,8 +50,8 @@ var baseExistingNSGroup = &nbdns.NameServerGroup{
|
|||||||
Enabled: true,
|
Enabled: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
func initNameserversTestData() *NameserversHandler {
|
func initNameserversTestData() *nameserversHandler {
|
||||||
return &NameserversHandler{
|
return &nameserversHandler{
|
||||||
accountManager: &mock_server.MockAccountManager{
|
accountManager: &mock_server.MockAccountManager{
|
||||||
GetNameServerGroupFunc: func(_ context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
GetNameServerGroupFunc: func(_ context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
||||||
if nsGroupID == existingNSGroupID {
|
if nsGroupID == existingNSGroupID {
|
||||||
@ -206,10 +206,10 @@ func TestNameserversHandlers(t *testing.T) {
|
|||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||||
|
|
||||||
router := mux.NewRouter()
|
router := mux.NewRouter()
|
||||||
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.GetNameserverGroup).Methods("GET")
|
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.getNameserverGroup).Methods("GET")
|
||||||
router.HandleFunc("/api/dns/nameservers", p.CreateNameserverGroup).Methods("POST")
|
router.HandleFunc("/api/dns/nameservers", p.createNameserverGroup).Methods("POST")
|
||||||
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.DeleteNameserverGroup).Methods("DELETE")
|
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.deleteNameserverGroup).Methods("DELETE")
|
||||||
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.UpdateNameserverGroup).Methods("PUT")
|
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.updateNameserverGroup).Methods("PUT")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
res := recorder.Result()
|
res := recorder.Result()
|
@ -1,28 +1,35 @@
|
|||||||
package http
|
package events
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||||
"github.com/netbirdio/netbird/management/server/http/util"
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
)
|
)
|
||||||
|
|
||||||
// EventsHandler HTTP handler
|
// handler HTTP handler
|
||||||
type EventsHandler struct {
|
type handler struct {
|
||||||
accountManager server.AccountManager
|
accountManager server.AccountManager
|
||||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewEventsHandler creates a new EventsHandler HTTP handler
|
func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
|
||||||
func NewEventsHandler(accountManager server.AccountManager, authCfg AuthCfg) *EventsHandler {
|
eventsHandler := newHandler(accountManager, authCfg)
|
||||||
return &EventsHandler{
|
router.HandleFunc("/events", eventsHandler.getAllEvents).Methods("GET", "OPTIONS")
|
||||||
|
}
|
||||||
|
|
||||||
|
// newHandler creates a new events handler
|
||||||
|
func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler {
|
||||||
|
return &handler{
|
||||||
accountManager: accountManager,
|
accountManager: accountManager,
|
||||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||||
jwtclaims.WithAudience(authCfg.Audience),
|
jwtclaims.WithAudience(authCfg.Audience),
|
||||||
@ -31,8 +38,8 @@ func NewEventsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ev
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllEvents list of the given account
|
// getAllEvents list of the given account
|
||||||
func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
|
func (h *handler) getAllEvents(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -60,7 +67,7 @@ func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
|
|||||||
util.WriteJSONObject(r.Context(), w, events)
|
util.WriteJSONObject(r.Context(), w, events)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *EventsHandler) fillEventsWithUserInfo(ctx context.Context, events []*api.Event, accountId, userId string) error {
|
func (h *handler) fillEventsWithUserInfo(ctx context.Context, events []*api.Event, accountId, userId string) error {
|
||||||
// build email, name maps based on users
|
// build email, name maps based on users
|
||||||
userInfos, err := h.accountManager.GetUsersFromAccount(ctx, accountId, userId)
|
userInfos, err := h.accountManager.GetUsersFromAccount(ctx, accountId, userId)
|
||||||
if err != nil {
|
if err != nil {
|
@ -1,4 +1,4 @@
|
|||||||
package http
|
package events
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@ -20,8 +20,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||||
)
|
)
|
||||||
|
|
||||||
func initEventsTestData(account string, events ...*activity.Event) *EventsHandler {
|
func initEventsTestData(account string, events ...*activity.Event) *handler {
|
||||||
return &EventsHandler{
|
return &handler{
|
||||||
accountManager: &mock_server.MockAccountManager{
|
accountManager: &mock_server.MockAccountManager{
|
||||||
GetEventsFunc: func(_ context.Context, accountID, userID string) ([]*activity.Event, error) {
|
GetEventsFunc: func(_ context.Context, accountID, userID string) ([]*activity.Event, error) {
|
||||||
if accountID == account {
|
if accountID == account {
|
||||||
@ -183,7 +183,7 @@ func TestEvents_GetEvents(t *testing.T) {
|
|||||||
requestBody io.Reader
|
requestBody io.Reader
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "GetAllEvents OK",
|
name: "getAllEvents OK",
|
||||||
expectedBody: true,
|
expectedBody: true,
|
||||||
requestType: http.MethodGet,
|
requestType: http.MethodGet,
|
||||||
requestPath: "/api/events/",
|
requestPath: "/api/events/",
|
||||||
@ -201,7 +201,7 @@ func TestEvents_GetEvents(t *testing.T) {
|
|||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||||
|
|
||||||
router := mux.NewRouter()
|
router := mux.NewRouter()
|
||||||
router.HandleFunc("/api/events/", handler.GetAllEvents).Methods("GET")
|
router.HandleFunc("/api/events/", handler.getAllEvents).Methods("GET")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
res := recorder.Result()
|
res := recorder.Result()
|
@ -1,13 +1,15 @@
|
|||||||
package http
|
package groups
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||||
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
@ -16,15 +18,24 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GroupsHandler is a handler that returns groups of the account
|
// handler is a handler that returns groups of the account
|
||||||
type GroupsHandler struct {
|
type handler struct {
|
||||||
accountManager server.AccountManager
|
accountManager server.AccountManager
|
||||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewGroupsHandler creates a new GroupsHandler HTTP handler
|
func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
|
||||||
func NewGroupsHandler(accountManager server.AccountManager, authCfg AuthCfg) *GroupsHandler {
|
groupsHandler := newHandler(accountManager, authCfg)
|
||||||
return &GroupsHandler{
|
router.HandleFunc("/groups", groupsHandler.getAllGroups).Methods("GET", "OPTIONS")
|
||||||
|
router.HandleFunc("/groups", groupsHandler.createGroup).Methods("POST", "OPTIONS")
|
||||||
|
router.HandleFunc("/groups/{groupId}", groupsHandler.updateGroup).Methods("PUT", "OPTIONS")
|
||||||
|
router.HandleFunc("/groups/{groupId}", groupsHandler.getGroup).Methods("GET", "OPTIONS")
|
||||||
|
router.HandleFunc("/groups/{groupId}", groupsHandler.deleteGroup).Methods("DELETE", "OPTIONS")
|
||||||
|
}
|
||||||
|
|
||||||
|
// newHandler creates a new groups handler
|
||||||
|
func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler {
|
||||||
|
return &handler{
|
||||||
accountManager: accountManager,
|
accountManager: accountManager,
|
||||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||||
jwtclaims.WithAudience(authCfg.Audience),
|
jwtclaims.WithAudience(authCfg.Audience),
|
||||||
@ -33,8 +44,8 @@ func NewGroupsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Gr
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllGroups list for the account
|
// getAllGroups list for the account
|
||||||
func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
|
func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -63,8 +74,8 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
|
|||||||
util.WriteJSONObject(r.Context(), w, groupsResponse)
|
util.WriteJSONObject(r.Context(), w, groupsResponse)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateGroup handles update to a group identified by a given ID
|
// updateGroup handles update to a group identified by a given ID
|
||||||
func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
|
func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -141,8 +152,8 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
|
|||||||
util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group))
|
util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group))
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateGroup handles group creation request
|
// createGroup handles group creation request
|
||||||
func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
|
func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -189,8 +200,8 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
|
|||||||
util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group))
|
util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group))
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteGroup handles group deletion request
|
// deleteGroup handles group deletion request
|
||||||
func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
|
func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -215,11 +226,11 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, emptyObject{})
|
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetGroup returns a group
|
// getGroup returns a group
|
||||||
func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) {
|
func (h *handler) getGroup(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
@ -1,4 +1,4 @@
|
|||||||
package http
|
package groups
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
@ -31,8 +31,8 @@ var TestPeers = map[string]*nbpeer.Peer{
|
|||||||
"B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")},
|
"B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")},
|
||||||
}
|
}
|
||||||
|
|
||||||
func initGroupTestData(initGroups ...*nbgroup.Group) *GroupsHandler {
|
func initGroupTestData(initGroups ...*nbgroup.Group) *handler {
|
||||||
return &GroupsHandler{
|
return &handler{
|
||||||
accountManager: &mock_server.MockAccountManager{
|
accountManager: &mock_server.MockAccountManager{
|
||||||
SaveGroupFunc: func(_ context.Context, accountID, userID string, group *nbgroup.Group) error {
|
SaveGroupFunc: func(_ context.Context, accountID, userID string, group *nbgroup.Group) error {
|
||||||
if !strings.HasPrefix(group.ID, "id-") {
|
if !strings.HasPrefix(group.ID, "id-") {
|
||||||
@ -106,14 +106,14 @@ func TestGetGroup(t *testing.T) {
|
|||||||
requestBody io.Reader
|
requestBody io.Reader
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "GetGroup OK",
|
name: "getGroup OK",
|
||||||
expectedBody: true,
|
expectedBody: true,
|
||||||
requestType: http.MethodGet,
|
requestType: http.MethodGet,
|
||||||
requestPath: "/api/groups/idofthegroup",
|
requestPath: "/api/groups/idofthegroup",
|
||||||
expectedStatus: http.StatusOK,
|
expectedStatus: http.StatusOK,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "GetGroup not found",
|
name: "getGroup not found",
|
||||||
requestType: http.MethodGet,
|
requestType: http.MethodGet,
|
||||||
requestPath: "/api/groups/notexists",
|
requestPath: "/api/groups/notexists",
|
||||||
expectedStatus: http.StatusNotFound,
|
expectedStatus: http.StatusNotFound,
|
||||||
@ -133,7 +133,7 @@ func TestGetGroup(t *testing.T) {
|
|||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||||
|
|
||||||
router := mux.NewRouter()
|
router := mux.NewRouter()
|
||||||
router.HandleFunc("/api/groups/{groupId}", p.GetGroup).Methods("GET")
|
router.HandleFunc("/api/groups/{groupId}", p.getGroup).Methods("GET")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
res := recorder.Result()
|
res := recorder.Result()
|
||||||
@ -254,8 +254,8 @@ func TestWriteGroup(t *testing.T) {
|
|||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||||
|
|
||||||
router := mux.NewRouter()
|
router := mux.NewRouter()
|
||||||
router.HandleFunc("/api/groups", p.CreateGroup).Methods("POST")
|
router.HandleFunc("/api/groups", p.createGroup).Methods("POST")
|
||||||
router.HandleFunc("/api/groups/{groupId}", p.UpdateGroup).Methods("PUT")
|
router.HandleFunc("/api/groups/{groupId}", p.updateGroup).Methods("PUT")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
res := recorder.Result()
|
res := recorder.Result()
|
||||||
@ -331,7 +331,7 @@ func TestDeleteGroup(t *testing.T) {
|
|||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||||
|
|
||||||
router := mux.NewRouter()
|
router := mux.NewRouter()
|
||||||
router.HandleFunc("/api/groups/{groupId}", p.DeleteGroup).Methods("DELETE")
|
router.HandleFunc("/api/groups/{groupId}", p.deleteGroup).Methods("DELETE")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
res := recorder.Result()
|
res := recorder.Result()
|
@ -1,4 +1,4 @@
|
|||||||
package http
|
package peers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@ -12,21 +12,30 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||||
"github.com/netbirdio/netbird/management/server/http/util"
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
// PeersHandler is a handler that returns peers of the account
|
// Handler is a handler that returns peers of the account
|
||||||
type PeersHandler struct {
|
type Handler struct {
|
||||||
accountManager server.AccountManager
|
accountManager server.AccountManager
|
||||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPeersHandler creates a new PeersHandler HTTP handler
|
func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
|
||||||
func NewPeersHandler(accountManager server.AccountManager, authCfg AuthCfg) *PeersHandler {
|
peersHandler := NewHandler(accountManager, authCfg)
|
||||||
return &PeersHandler{
|
router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS")
|
||||||
|
router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
|
||||||
|
Methods("GET", "PUT", "DELETE", "OPTIONS")
|
||||||
|
router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS")
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHandler creates a new peers Handler
|
||||||
|
func NewHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *Handler {
|
||||||
|
return &Handler{
|
||||||
accountManager: accountManager,
|
accountManager: accountManager,
|
||||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||||
jwtclaims.WithAudience(authCfg.Audience),
|
jwtclaims.WithAudience(authCfg.Audience),
|
||||||
@ -35,7 +44,7 @@ func NewPeersHandler(accountManager server.AccountManager, authCfg AuthCfg) *Pee
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *PeersHandler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) {
|
func (h *Handler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) {
|
||||||
peerToReturn := peer.Copy()
|
peerToReturn := peer.Copy()
|
||||||
if peer.Status.Connected {
|
if peer.Status.Connected {
|
||||||
// Although we have online status in store we do not yet have an updated channel so have to show it as disconnected
|
// Although we have online status in store we do not yet have an updated channel so have to show it as disconnected
|
||||||
@ -48,7 +57,7 @@ func (h *PeersHandler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
|||||||
return peerToReturn, nil
|
return peerToReturn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, peerID, userID string, w http.ResponseWriter) {
|
func (h *Handler) getPeer(ctx context.Context, account *server.Account, peerID, userID string, w http.ResponseWriter) {
|
||||||
peer, err := h.accountManager.GetPeer(ctx, account.Id, peerID, userID)
|
peer, err := h.accountManager.GetPeer(ctx, account.Id, peerID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(ctx, err, w)
|
util.WriteError(ctx, err, w)
|
||||||
@ -75,7 +84,7 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
|
|||||||
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid))
|
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) {
|
func (h *Handler) updatePeer(ctx context.Context, account *server.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) {
|
||||||
req := &api.PeerRequest{}
|
req := &api.PeerRequest{}
|
||||||
err := json.NewDecoder(r.Body).Decode(&req)
|
err := json.NewDecoder(r.Body).Decode(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -120,18 +129,18 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
|
|||||||
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, valid))
|
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, valid))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *PeersHandler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) {
|
func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) {
|
||||||
err := h.accountManager.DeletePeer(ctx, accountID, peerID, userID)
|
err := h.accountManager.DeletePeer(ctx, accountID, peerID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to delete peer: %v", err)
|
log.WithContext(ctx).Errorf("failed to delete peer: %v", err)
|
||||||
util.WriteError(ctx, err, w)
|
util.WriteError(ctx, err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
util.WriteJSONObject(ctx, w, emptyObject{})
|
util.WriteJSONObject(ctx, w, util.EmptyObject{})
|
||||||
}
|
}
|
||||||
|
|
||||||
// HandlePeer handles all peer requests for GET, PUT and DELETE operations
|
// HandlePeer handles all peer requests for GET, PUT and DELETE operations
|
||||||
func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
func (h *Handler) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -168,7 +177,7 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetAllPeers returns a list of all peers associated with a provided account
|
// GetAllPeers returns a list of all peers associated with a provided account
|
||||||
func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -219,7 +228,7 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
|||||||
util.WriteJSONObject(r.Context(), w, respBody)
|
util.WriteJSONObject(r.Context(), w, respBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) {
|
func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) {
|
||||||
for _, peer := range respBody {
|
for _, peer := range respBody {
|
||||||
_, ok := approvedPeersMap[peer.Id]
|
_, ok := approvedPeersMap[peer.Id]
|
||||||
if !ok {
|
if !ok {
|
||||||
@ -229,7 +238,7 @@ func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approv
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network.
|
// GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network.
|
||||||
func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
@ -1,4 +1,4 @@
|
|||||||
package http
|
package peers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
@ -38,8 +38,8 @@ const (
|
|||||||
userIDKey ctxKey = "user_id"
|
userIDKey ctxKey = "user_id"
|
||||||
)
|
)
|
||||||
|
|
||||||
func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
|
func initTestMetaData(peers ...*nbpeer.Peer) *Handler {
|
||||||
return &PeersHandler{
|
return &Handler{
|
||||||
accountManager: &mock_server.MockAccountManager{
|
accountManager: &mock_server.MockAccountManager{
|
||||||
UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
|
UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
|
||||||
var p *nbpeer.Peer
|
var p *nbpeer.Peer
|
@ -1,4 +1,4 @@
|
|||||||
package http
|
package policies
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@ -11,9 +11,9 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/netbirdio/netbird/management/server"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
@ -21,12 +21,12 @@ import (
|
|||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
func initGeolocationTestData(t *testing.T) *GeolocationsHandler {
|
func initGeolocationTestData(t *testing.T) *geolocationsHandler {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
var (
|
var (
|
||||||
mmdbPath = "../testdata/GeoLite2-City_20240305.mmdb"
|
mmdbPath = "../../../testdata/GeoLite2-City_20240305.mmdb"
|
||||||
geonamesdbPath = "../testdata/geonames_20240305.db"
|
geonamesdbPath = "../../../testdata/geonames_20240305.db"
|
||||||
)
|
)
|
||||||
|
|
||||||
tempDir := t.TempDir()
|
tempDir := t.TempDir()
|
||||||
@ -41,7 +41,7 @@ func initGeolocationTestData(t *testing.T) *GeolocationsHandler {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
t.Cleanup(func() { _ = geo.Stop() })
|
t.Cleanup(func() { _ = geo.Stop() })
|
||||||
|
|
||||||
return &GeolocationsHandler{
|
return &geolocationsHandler{
|
||||||
accountManager: &mock_server.MockAccountManager{
|
accountManager: &mock_server.MockAccountManager{
|
||||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||||
return claims.AccountId, claims.UserId, nil
|
return claims.AccountId, claims.UserId, nil
|
||||||
@ -114,7 +114,7 @@ func TestGetCitiesByCountry(t *testing.T) {
|
|||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||||
|
|
||||||
router := mux.NewRouter()
|
router := mux.NewRouter()
|
||||||
router.HandleFunc("/api/locations/countries/{country}/cities", geolocationHandler.GetCitiesByCountry).Methods("GET")
|
router.HandleFunc("/api/locations/countries/{country}/cities", geolocationHandler.getCitiesByCountry).Methods("GET")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
res := recorder.Result()
|
res := recorder.Result()
|
||||||
@ -202,7 +202,7 @@ func TestGetAllCountries(t *testing.T) {
|
|||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||||
|
|
||||||
router := mux.NewRouter()
|
router := mux.NewRouter()
|
||||||
router.HandleFunc("/api/locations/countries", geolocationHandler.GetAllCountries).Methods("GET")
|
router.HandleFunc("/api/locations/countries", geolocationHandler.getAllCountries).Methods("GET")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
res := recorder.Result()
|
res := recorder.Result()
|
@ -1,4 +1,4 @@
|
|||||||
package http
|
package policies
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -9,6 +9,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||||
"github.com/netbirdio/netbird/management/server/http/util"
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
@ -18,16 +19,22 @@ var (
|
|||||||
countryCodeRegex = regexp.MustCompile("^[a-zA-Z]{2}$")
|
countryCodeRegex = regexp.MustCompile("^[a-zA-Z]{2}$")
|
||||||
)
|
)
|
||||||
|
|
||||||
// GeolocationsHandler is a handler that returns locations.
|
// geolocationsHandler is a handler that returns locations.
|
||||||
type GeolocationsHandler struct {
|
type geolocationsHandler struct {
|
||||||
accountManager server.AccountManager
|
accountManager server.AccountManager
|
||||||
geolocationManager *geolocation.Geolocation
|
geolocationManager *geolocation.Geolocation
|
||||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewGeolocationsHandlerHandler creates a new Geolocations handler
|
func addLocationsEndpoint(accountManager server.AccountManager, locationManager *geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) {
|
||||||
func NewGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg AuthCfg) *GeolocationsHandler {
|
locationHandler := newGeolocationsHandlerHandler(accountManager, locationManager, authCfg)
|
||||||
return &GeolocationsHandler{
|
router.HandleFunc("/locations/countries", locationHandler.getAllCountries).Methods("GET", "OPTIONS")
|
||||||
|
router.HandleFunc("/locations/countries/{country}/cities", locationHandler.getCitiesByCountry).Methods("GET", "OPTIONS")
|
||||||
|
}
|
||||||
|
|
||||||
|
// newGeolocationsHandlerHandler creates a new Geolocations handler
|
||||||
|
func newGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg configs.AuthCfg) *geolocationsHandler {
|
||||||
|
return &geolocationsHandler{
|
||||||
accountManager: accountManager,
|
accountManager: accountManager,
|
||||||
geolocationManager: geolocationManager,
|
geolocationManager: geolocationManager,
|
||||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||||
@ -37,8 +44,8 @@ func NewGeolocationsHandlerHandler(accountManager server.AccountManager, geoloca
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllCountries retrieves a list of all countries
|
// getAllCountries retrieves a list of all countries
|
||||||
func (l *GeolocationsHandler) GetAllCountries(w http.ResponseWriter, r *http.Request) {
|
func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Request) {
|
||||||
if err := l.authenticateUser(r); err != nil {
|
if err := l.authenticateUser(r); err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -63,8 +70,8 @@ func (l *GeolocationsHandler) GetAllCountries(w http.ResponseWriter, r *http.Req
|
|||||||
util.WriteJSONObject(r.Context(), w, countries)
|
util.WriteJSONObject(r.Context(), w, countries)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetCitiesByCountry retrieves a list of cities based on the given country code
|
// getCitiesByCountry retrieves a list of cities based on the given country code
|
||||||
func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http.Request) {
|
func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.Request) {
|
||||||
if err := l.authenticateUser(r); err != nil {
|
if err := l.authenticateUser(r); err != nil {
|
||||||
util.WriteError(r.Context(), err, w)
|
util.WriteError(r.Context(), err, w)
|
||||||
return
|
return
|
||||||
@ -96,7 +103,7 @@ func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http.
|
|||||||
util.WriteJSONObject(r.Context(), w, cities)
|
util.WriteJSONObject(r.Context(), w, cities)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *GeolocationsHandler) authenticateUser(r *http.Request) error {
|
func (l *geolocationsHandler) authenticateUser(r *http.Request) error {
|
||||||
claims := l.claimsExtractor.FromRequestContext(r)
|
claims := l.claimsExtractor.FromRequestContext(r)
|
||||||
_, userID, err := l.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
_, userID, err := l.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
@ -1,4 +1,4 @@
|
|||||||
package http
|
package policies
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@ -8,22 +8,34 @@ import (
|
|||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||||
"github.com/netbirdio/netbird/management/server/http/util"
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Policies is a handler that returns policy of the account
|
// handler is a handler that returns policy of the account
|
||||||
type Policies struct {
|
type handler struct {
|
||||||
accountManager server.AccountManager
|
accountManager server.AccountManager
|
||||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPoliciesHandler creates a new Policies handler
|
func AddEndpoints(accountManager server.AccountManager, locationManager *geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) {
|
||||||
func NewPoliciesHandler(accountManager server.AccountManager, authCfg AuthCfg) *Policies {
|
policiesHandler := newHandler(accountManager, authCfg)
|
||||||
return &Policies{
|
router.HandleFunc("/policies", policiesHandler.getAllPolicies).Methods("GET", "OPTIONS")
|
||||||
|
router.HandleFunc("/policies", policiesHandler.createPolicy).Methods("POST", "OPTIONS")
|
||||||
|
router.HandleFunc("/policies/{policyId}", policiesHandler.updatePolicy).Methods("PUT", "OPTIONS")
|
||||||
|
router.HandleFunc("/policies/{policyId}", policiesHandler.getPolicy).Methods("GET", "OPTIONS")
|
||||||
|
router.HandleFunc("/policies/{policyId}", policiesHandler.deletePolicy).Methods("DELETE", "OPTIONS")
|
||||||
|
addPostureCheckEndpoint(accountManager, locationManager, authCfg, router)
|
||||||
|
}
|
||||||
|
|
||||||
|
// newHandler creates a new policies handler
|
||||||
|
func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler {
|
||||||
|
return &handler{
|
||||||
accountManager: accountManager,
|
accountManager: accountManager,
|
||||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||||
jwtclaims.WithAudience(authCfg.Audience),
|
jwtclaims.WithAudience(authCfg.Audience),
|
||||||
@ -32,8 +44,8 @@ func NewPoliciesHandler(accountManager server.AccountManager, authCfg AuthCfg) *
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllPolicies list for the account
|
// getAllPolicies list for the account
|
||||||
func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) {
|
func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -66,8 +78,8 @@ func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) {
|
|||||||
util.WriteJSONObject(r.Context(), w, policies)
|
util.WriteJSONObject(r.Context(), w, policies)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdatePolicy handles update to a policy identified by a given ID
|
// updatePolicy handles update to a policy identified by a given ID
|
||||||
func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
|
func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -91,8 +103,8 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
|
|||||||
h.savePolicy(w, r, accountID, userID, policyID)
|
h.savePolicy(w, r, accountID, userID, policyID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreatePolicy handles policy creation request
|
// createPolicy handles policy creation request
|
||||||
func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) {
|
func (h *handler) createPolicy(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -104,7 +116,7 @@ func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// savePolicy handles policy creation and update
|
// savePolicy handles policy creation and update
|
||||||
func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID string, userID string, policyID string) {
|
func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID string, userID string, policyID string) {
|
||||||
var req api.PutApiPoliciesPolicyIdJSONRequestBody
|
var req api.PutApiPoliciesPolicyIdJSONRequestBody
|
||||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||||
@ -251,8 +263,8 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
|
|||||||
util.WriteJSONObject(r.Context(), w, resp)
|
util.WriteJSONObject(r.Context(), w, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeletePolicy handles policy deletion request
|
// deletePolicy handles policy deletion request
|
||||||
func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
|
func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -272,11 +284,11 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, emptyObject{})
|
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPolicy handles a group Get request identified by ID
|
// getPolicy handles a group Get request identified by ID
|
||||||
func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) {
|
func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
@ -1,4 +1,4 @@
|
|||||||
package http
|
package policies
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
@ -24,12 +24,12 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||||
)
|
)
|
||||||
|
|
||||||
func initPoliciesTestData(policies ...*server.Policy) *Policies {
|
func initPoliciesTestData(policies ...*server.Policy) *handler {
|
||||||
testPolicies := make(map[string]*server.Policy, len(policies))
|
testPolicies := make(map[string]*server.Policy, len(policies))
|
||||||
for _, policy := range policies {
|
for _, policy := range policies {
|
||||||
testPolicies[policy.ID] = policy
|
testPolicies[policy.ID] = policy
|
||||||
}
|
}
|
||||||
return &Policies{
|
return &handler{
|
||||||
accountManager: &mock_server.MockAccountManager{
|
accountManager: &mock_server.MockAccountManager{
|
||||||
GetPolicyFunc: func(_ context.Context, _, policyID, _ string) (*server.Policy, error) {
|
GetPolicyFunc: func(_ context.Context, _, policyID, _ string) (*server.Policy, error) {
|
||||||
policy, ok := testPolicies[policyID]
|
policy, ok := testPolicies[policyID]
|
||||||
@ -91,14 +91,14 @@ func TestPoliciesGetPolicy(t *testing.T) {
|
|||||||
requestBody io.Reader
|
requestBody io.Reader
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "GetPolicy OK",
|
name: "getPolicy OK",
|
||||||
expectedBody: true,
|
expectedBody: true,
|
||||||
requestType: http.MethodGet,
|
requestType: http.MethodGet,
|
||||||
requestPath: "/api/policies/idofthepolicy",
|
requestPath: "/api/policies/idofthepolicy",
|
||||||
expectedStatus: http.StatusOK,
|
expectedStatus: http.StatusOK,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "GetPolicy not found",
|
name: "getPolicy not found",
|
||||||
requestType: http.MethodGet,
|
requestType: http.MethodGet,
|
||||||
requestPath: "/api/policies/notexists",
|
requestPath: "/api/policies/notexists",
|
||||||
expectedStatus: http.StatusNotFound,
|
expectedStatus: http.StatusNotFound,
|
||||||
@ -121,7 +121,7 @@ func TestPoliciesGetPolicy(t *testing.T) {
|
|||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||||
|
|
||||||
router := mux.NewRouter()
|
router := mux.NewRouter()
|
||||||
router.HandleFunc("/api/policies/{policyId}", p.GetPolicy).Methods("GET")
|
router.HandleFunc("/api/policies/{policyId}", p.getPolicy).Methods("GET")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
res := recorder.Result()
|
res := recorder.Result()
|
||||||
@ -272,8 +272,8 @@ func TestPoliciesWritePolicy(t *testing.T) {
|
|||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||||
|
|
||||||
router := mux.NewRouter()
|
router := mux.NewRouter()
|
||||||
router.HandleFunc("/api/policies", p.CreatePolicy).Methods("POST")
|
router.HandleFunc("/api/policies", p.createPolicy).Methods("POST")
|
||||||
router.HandleFunc("/api/policies/{policyId}", p.UpdatePolicy).Methods("PUT")
|
router.HandleFunc("/api/policies/{policyId}", p.updatePolicy).Methods("PUT")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
res := recorder.Result()
|
res := recorder.Result()
|
@ -1,4 +1,4 @@
|
|||||||
package http
|
package policies
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@ -9,22 +9,33 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||||
"github.com/netbirdio/netbird/management/server/http/util"
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
"github.com/netbirdio/netbird/management/server/posture"
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
// PostureChecksHandler is a handler that returns posture checks of the account.
|
// postureChecksHandler is a handler that returns posture checks of the account.
|
||||||
type PostureChecksHandler struct {
|
type postureChecksHandler struct {
|
||||||
accountManager server.AccountManager
|
accountManager server.AccountManager
|
||||||
geolocationManager *geolocation.Geolocation
|
geolocationManager *geolocation.Geolocation
|
||||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPostureChecksHandler creates a new PostureChecks handler
|
func addPostureCheckEndpoint(accountManager server.AccountManager, locationManager *geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) {
|
||||||
func NewPostureChecksHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg AuthCfg) *PostureChecksHandler {
|
postureCheckHandler := newPostureChecksHandler(accountManager, locationManager, authCfg)
|
||||||
return &PostureChecksHandler{
|
router.HandleFunc("/posture-checks", postureCheckHandler.getAllPostureChecks).Methods("GET", "OPTIONS")
|
||||||
|
router.HandleFunc("/posture-checks", postureCheckHandler.createPostureCheck).Methods("POST", "OPTIONS")
|
||||||
|
router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.updatePostureCheck).Methods("PUT", "OPTIONS")
|
||||||
|
router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.getPostureCheck).Methods("GET", "OPTIONS")
|
||||||
|
router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.deletePostureCheck).Methods("DELETE", "OPTIONS")
|
||||||
|
addLocationsEndpoint(accountManager, locationManager, authCfg, router)
|
||||||
|
}
|
||||||
|
|
||||||
|
// newPostureChecksHandler creates a new PostureChecks handler
|
||||||
|
func newPostureChecksHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg configs.AuthCfg) *postureChecksHandler {
|
||||||
|
return &postureChecksHandler{
|
||||||
accountManager: accountManager,
|
accountManager: accountManager,
|
||||||
geolocationManager: geolocationManager,
|
geolocationManager: geolocationManager,
|
||||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||||
@ -34,8 +45,8 @@ func NewPostureChecksHandler(accountManager server.AccountManager, geolocationMa
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllPostureChecks list for the account
|
// getAllPostureChecks list for the account
|
||||||
func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *http.Request) {
|
func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := p.claimsExtractor.FromRequestContext(r)
|
claims := p.claimsExtractor.FromRequestContext(r)
|
||||||
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -57,8 +68,8 @@ func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *htt
|
|||||||
util.WriteJSONObject(r.Context(), w, postureChecks)
|
util.WriteJSONObject(r.Context(), w, postureChecks)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdatePostureCheck handles update to a posture check identified by a given ID
|
// updatePostureCheck handles update to a posture check identified by a given ID
|
||||||
func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http.Request) {
|
func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := p.claimsExtractor.FromRequestContext(r)
|
claims := p.claimsExtractor.FromRequestContext(r)
|
||||||
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -82,8 +93,8 @@ func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http
|
|||||||
p.savePostureChecks(w, r, accountID, userID, postureChecksID)
|
p.savePostureChecks(w, r, accountID, userID, postureChecksID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreatePostureCheck handles posture check creation request
|
// createPostureCheck handles posture check creation request
|
||||||
func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http.Request) {
|
func (p *postureChecksHandler) createPostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := p.claimsExtractor.FromRequestContext(r)
|
claims := p.claimsExtractor.FromRequestContext(r)
|
||||||
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -94,8 +105,8 @@ func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http
|
|||||||
p.savePostureChecks(w, r, accountID, userID, "")
|
p.savePostureChecks(w, r, accountID, userID, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPostureCheck handles a posture check Get request identified by ID
|
// getPostureCheck handles a posture check Get request identified by ID
|
||||||
func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Request) {
|
func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := p.claimsExtractor.FromRequestContext(r)
|
claims := p.claimsExtractor.FromRequestContext(r)
|
||||||
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -119,8 +130,8 @@ func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Re
|
|||||||
util.WriteJSONObject(r.Context(), w, postureChecks.ToAPIResponse())
|
util.WriteJSONObject(r.Context(), w, postureChecks.ToAPIResponse())
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeletePostureCheck handles posture check deletion request
|
// deletePostureCheck handles posture check deletion request
|
||||||
func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http.Request) {
|
func (p *postureChecksHandler) deletePostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := p.claimsExtractor.FromRequestContext(r)
|
claims := p.claimsExtractor.FromRequestContext(r)
|
||||||
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -140,11 +151,11 @@ func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, emptyObject{})
|
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||||
}
|
}
|
||||||
|
|
||||||
// savePostureChecks handles posture checks create and update
|
// savePostureChecks handles posture checks create and update
|
||||||
func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.Request, accountID, userID, postureChecksID string) {
|
func (p *postureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.Request, accountID, userID, postureChecksID string) {
|
||||||
var (
|
var (
|
||||||
err error
|
err error
|
||||||
req api.PostureCheckUpdate
|
req api.PostureCheckUpdate
|
@ -1,4 +1,4 @@
|
|||||||
package http
|
package policies
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
@ -25,13 +25,13 @@ import (
|
|||||||
var berlin = "Berlin"
|
var berlin = "Berlin"
|
||||||
var losAngeles = "Los Angeles"
|
var losAngeles = "Los Angeles"
|
||||||
|
|
||||||
func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksHandler {
|
func initPostureChecksTestData(postureChecks ...*posture.Checks) *postureChecksHandler {
|
||||||
testPostureChecks := make(map[string]*posture.Checks, len(postureChecks))
|
testPostureChecks := make(map[string]*posture.Checks, len(postureChecks))
|
||||||
for _, postureCheck := range postureChecks {
|
for _, postureCheck := range postureChecks {
|
||||||
testPostureChecks[postureCheck.ID] = postureCheck
|
testPostureChecks[postureCheck.ID] = postureCheck
|
||||||
}
|
}
|
||||||
|
|
||||||
return &PostureChecksHandler{
|
return &postureChecksHandler{
|
||||||
accountManager: &mock_server.MockAccountManager{
|
accountManager: &mock_server.MockAccountManager{
|
||||||
GetPostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
|
GetPostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
|
||||||
p, ok := testPostureChecks[postureChecksID]
|
p, ok := testPostureChecks[postureChecksID]
|
||||||
@ -147,35 +147,35 @@ func TestGetPostureCheck(t *testing.T) {
|
|||||||
requestBody io.Reader
|
requestBody io.Reader
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "GetPostureCheck NBVersion OK",
|
name: "getPostureCheck NBVersion OK",
|
||||||
expectedBody: true,
|
expectedBody: true,
|
||||||
id: postureCheck.ID,
|
id: postureCheck.ID,
|
||||||
checkName: postureCheck.Name,
|
checkName: postureCheck.Name,
|
||||||
expectedStatus: http.StatusOK,
|
expectedStatus: http.StatusOK,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "GetPostureCheck OSVersion OK",
|
name: "getPostureCheck OSVersion OK",
|
||||||
expectedBody: true,
|
expectedBody: true,
|
||||||
id: osPostureCheck.ID,
|
id: osPostureCheck.ID,
|
||||||
checkName: osPostureCheck.Name,
|
checkName: osPostureCheck.Name,
|
||||||
expectedStatus: http.StatusOK,
|
expectedStatus: http.StatusOK,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "GetPostureCheck GeoLocation OK",
|
name: "getPostureCheck GeoLocation OK",
|
||||||
expectedBody: true,
|
expectedBody: true,
|
||||||
id: geoPostureCheck.ID,
|
id: geoPostureCheck.ID,
|
||||||
checkName: geoPostureCheck.Name,
|
checkName: geoPostureCheck.Name,
|
||||||
expectedStatus: http.StatusOK,
|
expectedStatus: http.StatusOK,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "GetPostureCheck PrivateNetwork OK",
|
name: "getPostureCheck PrivateNetwork OK",
|
||||||
expectedBody: true,
|
expectedBody: true,
|
||||||
id: privateNetworkCheck.ID,
|
id: privateNetworkCheck.ID,
|
||||||
checkName: privateNetworkCheck.Name,
|
checkName: privateNetworkCheck.Name,
|
||||||
expectedStatus: http.StatusOK,
|
expectedStatus: http.StatusOK,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "GetPostureCheck Not Found",
|
name: "getPostureCheck Not Found",
|
||||||
id: "not-exists",
|
id: "not-exists",
|
||||||
expectedStatus: http.StatusNotFound,
|
expectedStatus: http.StatusNotFound,
|
||||||
},
|
},
|
||||||
@ -189,7 +189,7 @@ func TestGetPostureCheck(t *testing.T) {
|
|||||||
req := httptest.NewRequest(http.MethodGet, "/api/posture-checks/"+tc.id, tc.requestBody)
|
req := httptest.NewRequest(http.MethodGet, "/api/posture-checks/"+tc.id, tc.requestBody)
|
||||||
|
|
||||||
router := mux.NewRouter()
|
router := mux.NewRouter()
|
||||||
router.HandleFunc("/api/posture-checks/{postureCheckId}", p.GetPostureCheck).Methods("GET")
|
router.HandleFunc("/api/posture-checks/{postureCheckId}", p.getPostureCheck).Methods("GET")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
res := recorder.Result()
|
res := recorder.Result()
|
||||||
@ -231,7 +231,7 @@ func TestPostureCheckUpdate(t *testing.T) {
|
|||||||
requestType string
|
requestType string
|
||||||
requestPath string
|
requestPath string
|
||||||
requestBody io.Reader
|
requestBody io.Reader
|
||||||
setupHandlerFunc func(handler *PostureChecksHandler)
|
setupHandlerFunc func(handler *postureChecksHandler)
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Create Posture Checks NB version",
|
name: "Create Posture Checks NB version",
|
||||||
@ -286,7 +286,7 @@ func TestPostureCheckUpdate(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
setupHandlerFunc: func(handler *PostureChecksHandler) {
|
setupHandlerFunc: func(handler *postureChecksHandler) {
|
||||||
handler.geolocationManager = nil
|
handler.geolocationManager = nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -427,7 +427,7 @@ func TestPostureCheckUpdate(t *testing.T) {
|
|||||||
}`)),
|
}`)),
|
||||||
expectedStatus: http.StatusPreconditionFailed,
|
expectedStatus: http.StatusPreconditionFailed,
|
||||||
expectedBody: false,
|
expectedBody: false,
|
||||||
setupHandlerFunc: func(handler *PostureChecksHandler) {
|
setupHandlerFunc: func(handler *postureChecksHandler) {
|
||||||
handler.geolocationManager = nil
|
handler.geolocationManager = nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -614,7 +614,7 @@ func TestPostureCheckUpdate(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
setupHandlerFunc: func(handler *PostureChecksHandler) {
|
setupHandlerFunc: func(handler *postureChecksHandler) {
|
||||||
handler.geolocationManager = nil
|
handler.geolocationManager = nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -677,7 +677,7 @@ func TestPostureCheckUpdate(t *testing.T) {
|
|||||||
}`)),
|
}`)),
|
||||||
expectedStatus: http.StatusPreconditionFailed,
|
expectedStatus: http.StatusPreconditionFailed,
|
||||||
expectedBody: false,
|
expectedBody: false,
|
||||||
setupHandlerFunc: func(handler *PostureChecksHandler) {
|
setupHandlerFunc: func(handler *postureChecksHandler) {
|
||||||
handler.geolocationManager = nil
|
handler.geolocationManager = nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -842,8 +842,8 @@ func TestPostureCheckUpdate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
router := mux.NewRouter()
|
router := mux.NewRouter()
|
||||||
router.HandleFunc("/api/posture-checks", defaultHandler.CreatePostureCheck).Methods("POST")
|
router.HandleFunc("/api/posture-checks", defaultHandler.createPostureCheck).Methods("POST")
|
||||||
router.HandleFunc("/api/posture-checks/{postureCheckId}", defaultHandler.UpdatePostureCheck).Methods("PUT")
|
router.HandleFunc("/api/posture-checks/{postureCheckId}", defaultHandler.updatePostureCheck).Methods("PUT")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
res := recorder.Result()
|
res := recorder.Result()
|
@ -1,4 +1,4 @@
|
|||||||
package http
|
package routes
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@ -14,6 +14,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/domain"
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||||
"github.com/netbirdio/netbird/management/server/http/util"
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
@ -23,15 +24,24 @@ import (
|
|||||||
const maxDomains = 32
|
const maxDomains = 32
|
||||||
const failedToConvertRoute = "failed to convert route to response: %v"
|
const failedToConvertRoute = "failed to convert route to response: %v"
|
||||||
|
|
||||||
// RoutesHandler is the routes handler of the account
|
// handler is the routes handler of the account
|
||||||
type RoutesHandler struct {
|
type handler struct {
|
||||||
accountManager server.AccountManager
|
accountManager server.AccountManager
|
||||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRoutesHandler returns a new instance of RoutesHandler handler
|
func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
|
||||||
func NewRoutesHandler(accountManager server.AccountManager, authCfg AuthCfg) *RoutesHandler {
|
routesHandler := newHandler(accountManager, authCfg)
|
||||||
return &RoutesHandler{
|
router.HandleFunc("/routes", routesHandler.getAllRoutes).Methods("GET", "OPTIONS")
|
||||||
|
router.HandleFunc("/routes", routesHandler.createRoute).Methods("POST", "OPTIONS")
|
||||||
|
router.HandleFunc("/routes/{routeId}", routesHandler.updateRoute).Methods("PUT", "OPTIONS")
|
||||||
|
router.HandleFunc("/routes/{routeId}", routesHandler.getRoute).Methods("GET", "OPTIONS")
|
||||||
|
router.HandleFunc("/routes/{routeId}", routesHandler.deleteRoute).Methods("DELETE", "OPTIONS")
|
||||||
|
}
|
||||||
|
|
||||||
|
// newHandler returns a new instance of routes handler
|
||||||
|
func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler {
|
||||||
|
return &handler{
|
||||||
accountManager: accountManager,
|
accountManager: accountManager,
|
||||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||||
jwtclaims.WithAudience(authCfg.Audience),
|
jwtclaims.WithAudience(authCfg.Audience),
|
||||||
@ -40,8 +50,8 @@ func NewRoutesHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ro
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllRoutes returns the list of routes for the account
|
// getAllRoutes returns the list of routes for the account
|
||||||
func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) {
|
func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -67,8 +77,8 @@ func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) {
|
|||||||
util.WriteJSONObject(r.Context(), w, apiRoutes)
|
util.WriteJSONObject(r.Context(), w, apiRoutes)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateRoute handles route creation request
|
// createRoute handles route creation request
|
||||||
func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
|
func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -139,7 +149,7 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
util.WriteJSONObject(r.Context(), w, routes)
|
util.WriteJSONObject(r.Context(), w, routes)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) error {
|
func (h *handler) validateRoute(req api.PostApiRoutesJSONRequestBody) error {
|
||||||
if req.Network != nil && req.Domains != nil {
|
if req.Network != nil && req.Domains != nil {
|
||||||
return status.Errorf(status.InvalidArgument, "only one of 'network' or 'domains' should be provided")
|
return status.Errorf(status.InvalidArgument, "only one of 'network' or 'domains' should be provided")
|
||||||
}
|
}
|
||||||
@ -164,8 +174,8 @@ func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) erro
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateRoute handles update to a route identified by a given ID
|
// updateRoute handles update to a route identified by a given ID
|
||||||
func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -257,8 +267,8 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
util.WriteJSONObject(r.Context(), w, routes)
|
util.WriteJSONObject(r.Context(), w, routes)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteRoute handles route deletion request
|
// deleteRoute handles route deletion request
|
||||||
func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
|
func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -278,11 +288,11 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, emptyObject{})
|
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRoute handles a route Get request identified by ID
|
// getRoute handles a route Get request identified by ID
|
||||||
func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
|
func (h *handler) getRoute(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
@ -1,4 +1,4 @@
|
|||||||
package http
|
package routes
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
@ -87,8 +87,8 @@ var testingAccount = &server.Account{
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func initRoutesTestData() *RoutesHandler {
|
func initRoutesTestData() *handler {
|
||||||
return &RoutesHandler{
|
return &handler{
|
||||||
accountManager: &mock_server.MockAccountManager{
|
accountManager: &mock_server.MockAccountManager{
|
||||||
GetRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) (*route.Route, error) {
|
GetRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) (*route.Route, error) {
|
||||||
if routeID == existingRouteID {
|
if routeID == existingRouteID {
|
||||||
@ -152,7 +152,7 @@ func initRoutesTestData() *RoutesHandler {
|
|||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
GetAccountIDFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) {
|
GetAccountIDFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||||
//return testingAccount, testingAccount.Users["test_user"], nil
|
// return testingAccount, testingAccount.Users["test_user"], nil
|
||||||
return testingAccount.Id, testingAccount.Users["test_user"].Id, nil
|
return testingAccount.Id, testingAccount.Users["test_user"].Id, nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -521,10 +521,10 @@ func TestRoutesHandlers(t *testing.T) {
|
|||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||||
|
|
||||||
router := mux.NewRouter()
|
router := mux.NewRouter()
|
||||||
router.HandleFunc("/api/routes/{routeId}", p.GetRoute).Methods("GET")
|
router.HandleFunc("/api/routes/{routeId}", p.getRoute).Methods("GET")
|
||||||
router.HandleFunc("/api/routes/{routeId}", p.DeleteRoute).Methods("DELETE")
|
router.HandleFunc("/api/routes/{routeId}", p.deleteRoute).Methods("DELETE")
|
||||||
router.HandleFunc("/api/routes", p.CreateRoute).Methods("POST")
|
router.HandleFunc("/api/routes", p.createRoute).Methods("POST")
|
||||||
router.HandleFunc("/api/routes/{routeId}", p.UpdateRoute).Methods("PUT")
|
router.HandleFunc("/api/routes/{routeId}", p.updateRoute).Methods("PUT")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
res := recorder.Result()
|
res := recorder.Result()
|
@ -1,4 +1,4 @@
|
|||||||
package http
|
package setup_keys
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@ -10,20 +10,30 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||||
"github.com/netbirdio/netbird/management/server/http/util"
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SetupKeysHandler is a handler that returns a list of setup keys of the account
|
// handler is a handler that returns a list of setup keys of the account
|
||||||
type SetupKeysHandler struct {
|
type handler struct {
|
||||||
accountManager server.AccountManager
|
accountManager server.AccountManager
|
||||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSetupKeysHandler creates a new SetupKeysHandler HTTP handler
|
func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
|
||||||
func NewSetupKeysHandler(accountManager server.AccountManager, authCfg AuthCfg) *SetupKeysHandler {
|
keysHandler := newHandler(accountManager, authCfg)
|
||||||
return &SetupKeysHandler{
|
router.HandleFunc("/setup-keys", keysHandler.getAllSetupKeys).Methods("GET", "OPTIONS")
|
||||||
|
router.HandleFunc("/setup-keys", keysHandler.createSetupKey).Methods("POST", "OPTIONS")
|
||||||
|
router.HandleFunc("/setup-keys/{keyId}", keysHandler.getSetupKey).Methods("GET", "OPTIONS")
|
||||||
|
router.HandleFunc("/setup-keys/{keyId}", keysHandler.updateSetupKey).Methods("PUT", "OPTIONS")
|
||||||
|
router.HandleFunc("/setup-keys/{keyId}", keysHandler.deleteSetupKey).Methods("DELETE", "OPTIONS")
|
||||||
|
}
|
||||||
|
|
||||||
|
// newHandler creates a new setup key handler
|
||||||
|
func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler {
|
||||||
|
return &handler{
|
||||||
accountManager: accountManager,
|
accountManager: accountManager,
|
||||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||||
jwtclaims.WithAudience(authCfg.Audience),
|
jwtclaims.WithAudience(authCfg.Audience),
|
||||||
@ -32,8 +42,8 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authCfg AuthCfg)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateSetupKey is a POST requests that creates a new SetupKey
|
// createSetupKey is a POST requests that creates a new SetupKey
|
||||||
func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request) {
|
func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -89,8 +99,8 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
|
|||||||
util.WriteJSONObject(r.Context(), w, apiSetupKeys)
|
util.WriteJSONObject(r.Context(), w, apiSetupKeys)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSetupKey is a GET request to get a SetupKey by ID
|
// getSetupKey is a GET request to get a SetupKey by ID
|
||||||
func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
|
func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -114,8 +124,8 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
|
|||||||
writeSuccess(r.Context(), w, key)
|
writeSuccess(r.Context(), w, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateSetupKey is a PUT request to update server.SetupKey
|
// updateSetupKey is a PUT request to update server.SetupKey
|
||||||
func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request) {
|
func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -155,8 +165,8 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
|
|||||||
writeSuccess(r.Context(), w, newKey)
|
writeSuccess(r.Context(), w, newKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllSetupKeys is a GET request that returns a list of SetupKey
|
// getAllSetupKeys is a GET request that returns a list of SetupKey
|
||||||
func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Request) {
|
func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -178,7 +188,7 @@ func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Reques
|
|||||||
util.WriteJSONObject(r.Context(), w, apiSetupKeys)
|
util.WriteJSONObject(r.Context(), w, apiSetupKeys)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *SetupKeysHandler) DeleteSetupKey(w http.ResponseWriter, r *http.Request) {
|
func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -199,7 +209,7 @@ func (h *SetupKeysHandler) DeleteSetupKey(w http.ResponseWriter, r *http.Request
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, emptyObject{})
|
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeSuccess(ctx context.Context, w http.ResponseWriter, key *server.SetupKey) {
|
func writeSuccess(ctx context.Context, w http.ResponseWriter, key *server.SetupKey) {
|
@ -1,4 +1,4 @@
|
|||||||
package http
|
package setup_keys
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
@ -26,12 +26,13 @@ const (
|
|||||||
newSetupKeyName = "New Setup Key"
|
newSetupKeyName = "New Setup Key"
|
||||||
updatedSetupKeyName = "KKKey"
|
updatedSetupKeyName = "KKKey"
|
||||||
notFoundSetupKeyID = "notFoundSetupKeyID"
|
notFoundSetupKeyID = "notFoundSetupKeyID"
|
||||||
|
testAccountID = "test_id"
|
||||||
)
|
)
|
||||||
|
|
||||||
func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.SetupKey, updatedSetupKey *server.SetupKey,
|
func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.SetupKey, updatedSetupKey *server.SetupKey,
|
||||||
user *server.User,
|
user *server.User,
|
||||||
) *SetupKeysHandler {
|
) *handler {
|
||||||
return &SetupKeysHandler{
|
return &handler{
|
||||||
accountManager: &mock_server.MockAccountManager{
|
accountManager: &mock_server.MockAccountManager{
|
||||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||||
return claims.AccountId, claims.UserId, nil
|
return claims.AccountId, claims.UserId, nil
|
||||||
@ -178,11 +179,11 @@ func TestSetupKeysHandlers(t *testing.T) {
|
|||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||||
|
|
||||||
router := mux.NewRouter()
|
router := mux.NewRouter()
|
||||||
router.HandleFunc("/api/setup-keys", handler.GetAllSetupKeys).Methods("GET", "OPTIONS")
|
router.HandleFunc("/api/setup-keys", handler.getAllSetupKeys).Methods("GET", "OPTIONS")
|
||||||
router.HandleFunc("/api/setup-keys", handler.CreateSetupKey).Methods("POST", "OPTIONS")
|
router.HandleFunc("/api/setup-keys", handler.createSetupKey).Methods("POST", "OPTIONS")
|
||||||
router.HandleFunc("/api/setup-keys/{keyId}", handler.GetSetupKey).Methods("GET", "OPTIONS")
|
router.HandleFunc("/api/setup-keys/{keyId}", handler.getSetupKey).Methods("GET", "OPTIONS")
|
||||||
router.HandleFunc("/api/setup-keys/{keyId}", handler.UpdateSetupKey).Methods("PUT", "OPTIONS")
|
router.HandleFunc("/api/setup-keys/{keyId}", handler.updateSetupKey).Methods("PUT", "OPTIONS")
|
||||||
router.HandleFunc("/api/setup-keys/{keyId}", handler.DeleteSetupKey).Methods("DELETE", "OPTIONS")
|
router.HandleFunc("/api/setup-keys/{keyId}", handler.deleteSetupKey).Methods("DELETE", "OPTIONS")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
res := recorder.Result()
|
res := recorder.Result()
|
@ -1,4 +1,4 @@
|
|||||||
package http
|
package users
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@ -9,20 +9,29 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||||
"github.com/netbirdio/netbird/management/server/http/util"
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
// PATHandler is the nameserver group handler of the account
|
// patHandler is the nameserver group handler of the account
|
||||||
type PATHandler struct {
|
type patHandler struct {
|
||||||
accountManager server.AccountManager
|
accountManager server.AccountManager
|
||||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPATsHandler creates a new PATHandler HTTP handler
|
func addUsersTokensEndpoint(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
|
||||||
func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATHandler {
|
tokenHandler := newPATsHandler(accountManager, authCfg)
|
||||||
return &PATHandler{
|
router.HandleFunc("/users/{userId}/tokens", tokenHandler.getAllTokens).Methods("GET", "OPTIONS")
|
||||||
|
router.HandleFunc("/users/{userId}/tokens", tokenHandler.createToken).Methods("POST", "OPTIONS")
|
||||||
|
router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.getToken).Methods("GET", "OPTIONS")
|
||||||
|
router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.deleteToken).Methods("DELETE", "OPTIONS")
|
||||||
|
}
|
||||||
|
|
||||||
|
// newPATsHandler creates a new patHandler HTTP handler
|
||||||
|
func newPATsHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *patHandler {
|
||||||
|
return &patHandler{
|
||||||
accountManager: accountManager,
|
accountManager: accountManager,
|
||||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||||
jwtclaims.WithAudience(authCfg.Audience),
|
jwtclaims.WithAudience(authCfg.Audience),
|
||||||
@ -31,8 +40,8 @@ func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATH
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user
|
// getAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user
|
||||||
func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) {
|
func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -61,8 +70,8 @@ func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) {
|
|||||||
util.WriteJSONObject(r.Context(), w, patResponse)
|
util.WriteJSONObject(r.Context(), w, patResponse)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetToken is HTTP GET handler that returns a personal access token for the given user
|
// getToken is HTTP GET handler that returns a personal access token for the given user
|
||||||
func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) {
|
func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -92,8 +101,8 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) {
|
|||||||
util.WriteJSONObject(r.Context(), w, toPATResponse(pat))
|
util.WriteJSONObject(r.Context(), w, toPATResponse(pat))
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateToken is HTTP POST handler that creates a personal access token for the given user
|
// createToken is HTTP POST handler that creates a personal access token for the given user
|
||||||
func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
|
func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -124,8 +133,8 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
|
|||||||
util.WriteJSONObject(r.Context(), w, toPATGeneratedResponse(pat))
|
util.WriteJSONObject(r.Context(), w, toPATGeneratedResponse(pat))
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteToken is HTTP DELETE handler that deletes a personal access token for the given user
|
// deleteToken is HTTP DELETE handler that deletes a personal access token for the given user
|
||||||
func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) {
|
func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request) {
|
||||||
claims := h.claimsExtractor.FromRequestContext(r)
|
claims := h.claimsExtractor.FromRequestContext(r)
|
||||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -152,7 +161,7 @@ func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, emptyObject{})
|
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||||
}
|
}
|
||||||
|
|
||||||
func toPATResponse(pat *server.PersonalAccessToken) *api.PersonalAccessToken {
|
func toPATResponse(pat *server.PersonalAccessToken) *api.PersonalAccessToken {
|
@ -1,4 +1,4 @@
|
|||||||
package http
|
package users
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
@ -61,8 +61,8 @@ var testAccount = &server.Account{
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func initPATTestData() *PATHandler {
|
func initPATTestData() *patHandler {
|
||||||
return &PATHandler{
|
return &patHandler{
|
||||||
accountManager: &mock_server.MockAccountManager{
|
accountManager: &mock_server.MockAccountManager{
|
||||||
CreatePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
|
CreatePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
|
||||||
if accountID != existingAccountID {
|
if accountID != existingAccountID {
|
||||||
@ -186,10 +186,10 @@ func TestTokenHandlers(t *testing.T) {
|
|||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||||
|
|
||||||
router := mux.NewRouter()
|
router := mux.NewRouter()
|
||||||
router.HandleFunc("/api/users/{userId}/tokens", p.GetAllTokens).Methods("GET")
|
router.HandleFunc("/api/users/{userId}/tokens", p.getAllTokens).Methods("GET")
|
||||||
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.GetToken).Methods("GET")
|
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.getToken).Methods("GET")
|
||||||
router.HandleFunc("/api/users/{userId}/tokens", p.CreateToken).Methods("POST")
|
router.HandleFunc("/api/users/{userId}/tokens", p.createToken).Methods("POST")
|
||||||
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.DeleteToken).Methods("DELETE")
|
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.deleteToken).Methods("DELETE")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
res := recorder.Result()
|
res := recorder.Result()
|
@ -1,4 +1,4 @@
|
|||||||
package http
|
package users
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@ -9,6 +9,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/api"
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||||
"github.com/netbirdio/netbird/management/server/http/util"
|
"github.com/netbirdio/netbird/management/server/http/util"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
|
|
||||||
@ -16,15 +17,25 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
)
|
)
|
||||||
|
|
||||||
// UsersHandler is a handler that returns users of the account
|
// handler is a handler that returns users of the account
|
||||||
type UsersHandler struct {
|
type handler struct {
|
||||||
accountManager server.AccountManager
|
accountManager server.AccountManager
|
||||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewUsersHandler creates a new UsersHandler HTTP handler
|
func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
|
||||||
func NewUsersHandler(accountManager server.AccountManager, authCfg AuthCfg) *UsersHandler {
|
userHandler := newHandler(accountManager, authCfg)
|
||||||
return &UsersHandler{
|
router.HandleFunc("/users", userHandler.getAllUsers).Methods("GET", "OPTIONS")
|
||||||
|
router.HandleFunc("/users/{userId}", userHandler.updateUser).Methods("PUT", "OPTIONS")
|
||||||
|
router.HandleFunc("/users/{userId}", userHandler.deleteUser).Methods("DELETE", "OPTIONS")
|
||||||
|
router.HandleFunc("/users", userHandler.createUser).Methods("POST", "OPTIONS")
|
||||||
|
router.HandleFunc("/users/{userId}/invite", userHandler.inviteUser).Methods("POST", "OPTIONS")
|
||||||
|
addUsersTokensEndpoint(accountManager, authCfg, router)
|
||||||
|
}
|
||||||
|
|
||||||
|
// newHandler creates a new UsersHandler HTTP handler
|
||||||
|
func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler {
|
||||||
|
return &handler{
|
||||||
accountManager: accountManager,
|
accountManager: accountManager,
|
||||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||||
jwtclaims.WithAudience(authCfg.Audience),
|
jwtclaims.WithAudience(authCfg.Audience),
|
||||||
@ -33,8 +44,8 @@ func NewUsersHandler(accountManager server.AccountManager, authCfg AuthCfg) *Use
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateUser is a PUT requests to update User data
|
// updateUser is a PUT requests to update User data
|
||||||
func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
|
func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method != http.MethodPut {
|
if r.Method != http.MethodPut {
|
||||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||||
return
|
return
|
||||||
@ -94,8 +105,8 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
|
|||||||
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId))
|
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId))
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteUser is a DELETE request to delete a user
|
// deleteUser is a DELETE request to delete a user
|
||||||
func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) {
|
func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method != http.MethodDelete {
|
if r.Method != http.MethodDelete {
|
||||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||||
return
|
return
|
||||||
@ -121,11 +132,11 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, emptyObject{})
|
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateUser creates a User in the system with a status "invited" (effectively this is a user invite).
|
// createUser creates a User in the system with a status "invited" (effectively this is a user invite).
|
||||||
func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
|
func (h *handler) createUser(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method != http.MethodPost {
|
if r.Method != http.MethodPost {
|
||||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||||
return
|
return
|
||||||
@ -175,9 +186,9 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
|
|||||||
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId))
|
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId))
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllUsers returns a list of users of the account this user belongs to.
|
// getAllUsers returns a list of users of the account this user belongs to.
|
||||||
// It also gathers additional user data (like email and name) from the IDP manager.
|
// It also gathers additional user data (like email and name) from the IDP manager.
|
||||||
func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) {
|
func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method != http.MethodGet {
|
if r.Method != http.MethodGet {
|
||||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||||
return
|
return
|
||||||
@ -222,9 +233,9 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) {
|
|||||||
util.WriteJSONObject(r.Context(), w, users)
|
util.WriteJSONObject(r.Context(), w, users)
|
||||||
}
|
}
|
||||||
|
|
||||||
// InviteUser resend invitations to users who haven't activated their accounts,
|
// inviteUser resend invitations to users who haven't activated their accounts,
|
||||||
// prior to the expiration period.
|
// prior to the expiration period.
|
||||||
func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) {
|
func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Method != http.MethodPost {
|
if r.Method != http.MethodPost {
|
||||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||||
return
|
return
|
||||||
@ -250,7 +261,7 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, emptyObject{})
|
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||||
}
|
}
|
||||||
|
|
||||||
func toUserResponse(user *server.UserInfo, currenUserID string) *api.User {
|
func toUserResponse(user *server.UserInfo, currenUserID string) *api.User {
|
@ -1,4 +1,4 @@
|
|||||||
package http
|
package users
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
@ -61,8 +61,8 @@ var usersTestAccount = &server.Account{
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func initUsersTestData() *UsersHandler {
|
func initUsersTestData() *handler {
|
||||||
return &UsersHandler{
|
return &handler{
|
||||||
accountManager: &mock_server.MockAccountManager{
|
accountManager: &mock_server.MockAccountManager{
|
||||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||||
return usersTestAccount.Id, claims.UserId, nil
|
return usersTestAccount.Id, claims.UserId, nil
|
||||||
@ -147,7 +147,7 @@ func TestGetUsers(t *testing.T) {
|
|||||||
requestPath string
|
requestPath string
|
||||||
expectedUserIDs []string
|
expectedUserIDs []string
|
||||||
}{
|
}{
|
||||||
{name: "GetAllUsers", requestType: http.MethodGet, requestPath: "/api/users", expectedStatus: http.StatusOK, expectedUserIDs: []string{existingUserID, regularUserID, serviceUserID}},
|
{name: "getAllUsers", requestType: http.MethodGet, requestPath: "/api/users", expectedStatus: http.StatusOK, expectedUserIDs: []string{existingUserID, regularUserID, serviceUserID}},
|
||||||
{name: "GetOnlyServiceUsers", requestType: http.MethodGet, requestPath: "/api/users?service_user=true", expectedStatus: http.StatusOK, expectedUserIDs: []string{serviceUserID}},
|
{name: "GetOnlyServiceUsers", requestType: http.MethodGet, requestPath: "/api/users?service_user=true", expectedStatus: http.StatusOK, expectedUserIDs: []string{serviceUserID}},
|
||||||
{name: "GetOnlyRegularUsers", requestType: http.MethodGet, requestPath: "/api/users?service_user=false", expectedStatus: http.StatusOK, expectedUserIDs: []string{existingUserID, regularUserID}},
|
{name: "GetOnlyRegularUsers", requestType: http.MethodGet, requestPath: "/api/users?service_user=false", expectedStatus: http.StatusOK, expectedUserIDs: []string{existingUserID, regularUserID}},
|
||||||
}
|
}
|
||||||
@ -159,7 +159,7 @@ func TestGetUsers(t *testing.T) {
|
|||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||||
|
|
||||||
userHandler.GetAllUsers(recorder, req)
|
userHandler.getAllUsers(recorder, req)
|
||||||
|
|
||||||
res := recorder.Result()
|
res := recorder.Result()
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
@ -265,7 +265,7 @@ func TestUpdateUser(t *testing.T) {
|
|||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||||
|
|
||||||
router := mux.NewRouter()
|
router := mux.NewRouter()
|
||||||
router.HandleFunc("/api/users/{userId}", userHandler.UpdateUser).Methods("PUT")
|
router.HandleFunc("/api/users/{userId}", userHandler.updateUser).Methods("PUT")
|
||||||
router.ServeHTTP(recorder, req)
|
router.ServeHTTP(recorder, req)
|
||||||
|
|
||||||
res := recorder.Result()
|
res := recorder.Result()
|
||||||
@ -356,7 +356,7 @@ func TestCreateUser(t *testing.T) {
|
|||||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
userHandler.CreateUser(rr, req)
|
userHandler.createUser(rr, req)
|
||||||
|
|
||||||
res := rr.Result()
|
res := rr.Result()
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
@ -401,7 +401,7 @@ func TestInviteUser(t *testing.T) {
|
|||||||
req = mux.SetURLVars(req, tc.requestVars)
|
req = mux.SetURLVars(req, tc.requestVars)
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
userHandler.InviteUser(rr, req)
|
userHandler.inviteUser(rr, req)
|
||||||
|
|
||||||
res := rr.Result()
|
res := rr.Result()
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
@ -454,7 +454,7 @@ func TestDeleteUser(t *testing.T) {
|
|||||||
req = mux.SetURLVars(req, tc.requestVars)
|
req = mux.SetURLVars(req, tc.requestVars)
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
userHandler.DeleteUser(rr, req)
|
userHandler.deleteUser(rr, req)
|
||||||
|
|
||||||
res := rr.Result()
|
res := rr.Result()
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
@ -14,6 +14,10 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// EmptyObject is an empty struct used to return empty JSON object
|
||||||
|
type EmptyObject struct {
|
||||||
|
}
|
||||||
|
|
||||||
type ErrorResponse struct {
|
type ErrorResponse struct {
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
Code int `json:"code"`
|
Code int `json:"code"`
|
||||||
|
@ -740,7 +740,6 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
|
|||||||
// it means that the client has already checked if it needs login and had been through the SSO flow
|
// it means that the client has already checked if it needs login and had been through the SSO flow
|
||||||
// so, we can skip this check and directly proceed with the login
|
// so, we can skip this check and directly proceed with the login
|
||||||
if login.UserID == "" {
|
if login.UserID == "" {
|
||||||
log.Info("Peer needs login")
|
|
||||||
err = am.checkIFPeerNeedsLoginWithoutLock(ctx, accountID, login)
|
err = am.checkIFPeerNeedsLoginWithoutLock(ctx, accountID, login)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
|
48
util/semaphore-group/semaphore_group.go
Normal file
48
util/semaphore-group/semaphore_group.go
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
package semaphoregroup
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SemaphoreGroup is a custom type that combines sync.WaitGroup and a semaphore.
|
||||||
|
type SemaphoreGroup struct {
|
||||||
|
waitGroup sync.WaitGroup
|
||||||
|
semaphore chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSemaphoreGroup creates a new SemaphoreGroup with the specified semaphore limit.
|
||||||
|
func NewSemaphoreGroup(limit int) *SemaphoreGroup {
|
||||||
|
return &SemaphoreGroup{
|
||||||
|
semaphore: make(chan struct{}, limit),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add increments the internal WaitGroup counter and acquires a semaphore slot.
|
||||||
|
func (sg *SemaphoreGroup) Add(ctx context.Context) {
|
||||||
|
sg.waitGroup.Add(1)
|
||||||
|
|
||||||
|
// Acquire semaphore slot
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case sg.semaphore <- struct{}{}:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Done decrements the internal WaitGroup counter and releases a semaphore slot.
|
||||||
|
func (sg *SemaphoreGroup) Done(ctx context.Context) {
|
||||||
|
sg.waitGroup.Done()
|
||||||
|
|
||||||
|
// Release semaphore slot
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-sg.semaphore:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait waits until the internal WaitGroup counter is zero.
|
||||||
|
func (sg *SemaphoreGroup) Wait() {
|
||||||
|
sg.waitGroup.Wait()
|
||||||
|
}
|
66
util/semaphore-group/semaphore_group_test.go
Normal file
66
util/semaphore-group/semaphore_group_test.go
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
package semaphoregroup
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSemaphoreGroup(t *testing.T) {
|
||||||
|
semGroup := NewSemaphoreGroup(2)
|
||||||
|
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
semGroup.Add(context.Background())
|
||||||
|
go func(id int) {
|
||||||
|
defer semGroup.Done(context.Background())
|
||||||
|
|
||||||
|
got := len(semGroup.semaphore)
|
||||||
|
if got == 0 {
|
||||||
|
t.Errorf("Expected semaphore length > 0 , got 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(time.Millisecond)
|
||||||
|
t.Logf("Goroutine %d is running\n", id)
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
semGroup.Wait()
|
||||||
|
|
||||||
|
want := 0
|
||||||
|
got := len(semGroup.semaphore)
|
||||||
|
if got != want {
|
||||||
|
t.Errorf("Expected semaphore length %d, got %d", want, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSemaphoreGroupContext(t *testing.T) {
|
||||||
|
semGroup := NewSemaphoreGroup(1)
|
||||||
|
semGroup.Add(context.Background())
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||||
|
t.Cleanup(cancel)
|
||||||
|
rChan := make(chan struct{})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
semGroup.Add(ctx)
|
||||||
|
rChan <- struct{}{}
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case <-rChan:
|
||||||
|
case <-time.NewTimer(2 * time.Second).C:
|
||||||
|
t.Error("Adding to semaphore group should not block when context is not done")
|
||||||
|
}
|
||||||
|
|
||||||
|
semGroup.Done(context.Background())
|
||||||
|
|
||||||
|
ctxDone, cancelDone := context.WithTimeout(context.Background(), 1*time.Second)
|
||||||
|
t.Cleanup(cancelDone)
|
||||||
|
go func() {
|
||||||
|
semGroup.Done(ctxDone)
|
||||||
|
rChan <- struct{}{}
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case <-rChan:
|
||||||
|
case <-time.NewTimer(2 * time.Second).C:
|
||||||
|
t.Error("Releasing from semaphore group should not block when context is not done")
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user