diff --git a/client/internal/config.go b/client/internal/config.go index ce87835cd..998690ef1 100644 --- a/client/internal/config.go +++ b/client/internal/config.go @@ -46,6 +46,7 @@ type ConfigInput struct { ManagementURL string AdminURL string ConfigPath string + StateFilePath string PreSharedKey *string ServerSSHAllowed *bool NATExternalIPs []string @@ -105,10 +106,10 @@ type Config struct { // DNSRouteInterval is the interval in which the DNS routes are updated DNSRouteInterval time.Duration - //Path to a certificate used for mTLS authentication + // Path to a certificate used for mTLS authentication ClientCertPath string - //Path to corresponding private key of ClientCertPath + // Path to corresponding private key of ClientCertPath ClientCertKeyPath string ClientCertKeyPair *tls.Certificate `json:"-"` @@ -116,7 +117,7 @@ type Config struct { // ReadConfig read config file and return with Config. If it is not exists create a new with default values func ReadConfig(configPath string) (*Config, error) { - if configFileIsExists(configPath) { + if fileExists(configPath) { err := util.EnforcePermission(configPath) if err != nil { log.Errorf("failed to enforce permission on config dir: %v", err) @@ -149,7 +150,7 @@ func ReadConfig(configPath string) (*Config, error) { // UpdateConfig update existing configuration according to input configuration and return with the configuration func UpdateConfig(input ConfigInput) (*Config, error) { - if !configFileIsExists(input.ConfigPath) { + if !fileExists(input.ConfigPath) { return nil, status.Errorf(codes.NotFound, "config file doesn't exist") } @@ -158,7 +159,7 @@ func UpdateConfig(input ConfigInput) (*Config, error) { // UpdateOrCreateConfig reads existing config or generates a new one func UpdateOrCreateConfig(input ConfigInput) (*Config, error) { - if !configFileIsExists(input.ConfigPath) { + if !fileExists(input.ConfigPath) { log.Infof("generating new config %s", input.ConfigPath) cfg, err := createNewConfig(input) if err != nil { @@ -472,11 +473,19 @@ func isPreSharedKeyHidden(preSharedKey *string) bool { return false } -func configFileIsExists(path string) bool { +func fileExists(path string) bool { _, err := os.Stat(path) return !os.IsNotExist(err) } +func createFile(path string) error { + file, err := os.Create(path) + if err != nil { + return err + } + return file.Close() +} + // UpdateOldManagementURL checks whether client can switch to the new Management URL with port 443 and the management domain. // If it can switch, then it updates the config and returns a new one. Otherwise, it returns the provided config. // The check is performed only for the NetBird's managed version. diff --git a/client/internal/connect.go b/client/internal/connect.go index 4848b1c11..782984e27 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -91,6 +91,7 @@ func (c *ConnectClient) RunOniOS( fileDescriptor int32, networkChangeListener listener.NetworkChangeListener, dnsManager dns.IosDnsManager, + stateFilePath string, ) error { // Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension. debug.SetGCPercent(5) @@ -99,6 +100,7 @@ func (c *ConnectClient) RunOniOS( FileDescriptor: fileDescriptor, NetworkChangeListener: networkChangeListener, DnsManager: dnsManager, + StateFilePath: stateFilePath, } return c.run(mobileDependency, nil, nil) } diff --git a/client/internal/engine.go b/client/internal/engine.go index 782bb48bb..34219def1 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -39,6 +39,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" + semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" @@ -62,6 +63,7 @@ import ( const ( PeerConnectionTimeoutMax = 45000 // ms PeerConnectionTimeoutMin = 30000 // ms + connInitLimit = 200 ) var ErrResetConnection = fmt.Errorf("reset connection") @@ -177,6 +179,7 @@ type Engine struct { // Network map persistence persistNetworkMap bool latestNetworkMap *mgmProto.NetworkMap + connSemaphore *semaphoregroup.SemaphoreGroup } // Peer is an instance of the Connection Peer @@ -242,6 +245,18 @@ func NewEngineWithProbes( statusRecorder: statusRecorder, probes: probes, checks: checks, + connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit), + } + if runtime.GOOS == "ios" { + if !fileExists(mobileDep.StateFilePath) { + err := createFile(mobileDep.StateFilePath) + if err != nil { + log.Errorf("failed to create state file: %v", err) + // we are not exiting as we can run without the state manager + } + } + + engine.stateManager = statemanager.New(mobileDep.StateFilePath) } if path := statemanager.GetDefaultStatePath(); path != "" { engine.stateManager = statemanager.New(path) @@ -1040,7 +1055,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e }, } - peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager, e.srWatcher) + peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager, e.srWatcher, e.connSemaphore) if err != nil { return nil, err } diff --git a/client/internal/mobile_dependency.go b/client/internal/mobile_dependency.go index 2b0c92cc6..4ac0fc141 100644 --- a/client/internal/mobile_dependency.go +++ b/client/internal/mobile_dependency.go @@ -19,4 +19,5 @@ type MobileDependency struct { // iOS only DnsManager dns.IosDnsManager FileDescriptor int32 + StateFilePath string } diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 3a698a82a..5c2e2cb60 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -23,6 +23,7 @@ import ( relayClient "github.com/netbirdio/netbird/relay/client" "github.com/netbirdio/netbird/route" nbnet "github.com/netbirdio/netbird/util/net" + semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" ) type ConnPriority int @@ -104,12 +105,13 @@ type Conn struct { wgProxyICE wgproxy.Proxy wgProxyRelay wgproxy.Proxy - guard *guard.Guard + guard *guard.Guard + semaphore *semaphoregroup.SemaphoreGroup } // NewConn creates a new not opened Conn to the remote peer. // To establish a connection run Conn.Open -func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher) (*Conn, error) { +func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher, semaphore *semaphoregroup.SemaphoreGroup) (*Conn, error) { allowedIP, _, err := net.ParseCIDR(config.WgConfig.AllowedIps) if err != nil { log.Errorf("failed to parse allowedIPS: %v", err) @@ -130,6 +132,7 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu allowedIP: allowedIP, statusRelay: NewAtomicConnStatus(), statusICE: NewAtomicConnStatus(), + semaphore: semaphore, } rFns := WorkerRelayCallbacks{ @@ -169,6 +172,7 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu // It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will // be used. func (conn *Conn) Open() { + conn.semaphore.Add(conn.ctx) conn.log.Debugf("open connection to peer") conn.mu.Lock() @@ -191,6 +195,7 @@ func (conn *Conn) Open() { } func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) { + defer conn.semaphore.Done(conn.ctx) conn.waitInitialRandomSleepTime(ctx) err := conn.handshaker.sendOffer() diff --git a/client/internal/peer/conn_test.go b/client/internal/peer/conn_test.go index 039952588..b3e9d5b60 100644 --- a/client/internal/peer/conn_test.go +++ b/client/internal/peer/conn_test.go @@ -14,6 +14,7 @@ import ( "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/util" + semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" ) var connConf = ConnConfig{ @@ -46,7 +47,7 @@ func TestNewConn_interfaceFilter(t *testing.T) { func TestConn_GetKey(t *testing.T) { swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) - conn, err := NewConn(context.Background(), connConf, nil, nil, nil, nil, swWatcher) + conn, err := NewConn(context.Background(), connConf, nil, nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1)) if err != nil { return } @@ -58,7 +59,7 @@ func TestConn_GetKey(t *testing.T) { func TestConn_OnRemoteOffer(t *testing.T) { swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) - conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher) + conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1)) if err != nil { return } @@ -92,7 +93,7 @@ func TestConn_OnRemoteOffer(t *testing.T) { func TestConn_OnRemoteAnswer(t *testing.T) { swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) - conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher) + conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1)) if err != nil { return } @@ -125,7 +126,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) { } func TestConn_Status(t *testing.T) { swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) - conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher) + conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1)) if err != nil { return } diff --git a/client/internal/routeselector/routeselector_test.go b/client/internal/routeselector/routeselector_test.go index 7df433f92..b1671f254 100644 --- a/client/internal/routeselector/routeselector_test.go +++ b/client/internal/routeselector/routeselector_test.go @@ -273,3 +273,88 @@ func TestRouteSelector_FilterSelected(t *testing.T) { "route2|192.168.0.0/16": {}, }, filtered) } + +func TestRouteSelector_NewRoutesBehavior(t *testing.T) { + initialRoutes := []route.NetID{"route1", "route2", "route3"} + newRoutes := []route.NetID{"route1", "route2", "route3", "route4", "route5"} + + tests := []struct { + name string + initialState func(rs *routeselector.RouteSelector) error // Setup initial state + wantNewSelected []route.NetID // Expected selected routes after new routes appear + }{ + { + name: "New routes with initial selectAll state", + initialState: func(rs *routeselector.RouteSelector) error { + rs.SelectAllRoutes() + return nil + }, + // When selectAll is true, all routes including new ones should be selected + wantNewSelected: []route.NetID{"route1", "route2", "route3", "route4", "route5"}, + }, + { + name: "New routes after specific selection", + initialState: func(rs *routeselector.RouteSelector) error { + return rs.SelectRoutes([]route.NetID{"route1", "route2"}, false, initialRoutes) + }, + // When specific routes were selected, new routes should remain unselected + wantNewSelected: []route.NetID{"route1", "route2"}, + }, + { + name: "New routes after deselect all", + initialState: func(rs *routeselector.RouteSelector) error { + rs.DeselectAllRoutes() + return nil + }, + // After deselect all, new routes should remain unselected + wantNewSelected: []route.NetID{}, + }, + { + name: "New routes after deselecting specific routes", + initialState: func(rs *routeselector.RouteSelector) error { + rs.SelectAllRoutes() + return rs.DeselectRoutes([]route.NetID{"route1"}, initialRoutes) + }, + // After deselecting specific routes, new routes should remain unselected + wantNewSelected: []route.NetID{"route2", "route3"}, + }, + { + name: "New routes after selecting with append", + initialState: func(rs *routeselector.RouteSelector) error { + return rs.SelectRoutes([]route.NetID{"route1"}, true, initialRoutes) + }, + // When routes were appended, new routes should remain unselected + wantNewSelected: []route.NetID{"route1"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rs := routeselector.NewRouteSelector() + + // Setup initial state + err := tt.initialState(rs) + require.NoError(t, err) + + // Verify selection state with new routes + for _, id := range newRoutes { + assert.Equal(t, rs.IsSelected(id), slices.Contains(tt.wantNewSelected, id), + "Route %s selection state incorrect", id) + } + + // Additional verification using FilterSelected + routes := route.HAMap{ + "route1|10.0.0.0/8": {}, + "route2|192.168.0.0/16": {}, + "route3|172.16.0.0/12": {}, + "route4|10.10.0.0/16": {}, + "route5|192.168.1.0/24": {}, + } + + filtered := rs.FilterSelected(routes) + expectedLen := len(tt.wantNewSelected) + assert.Equal(t, expectedLen, len(filtered), + "FilterSelected returned wrong number of routes, got %d want %d", len(filtered), expectedLen) + }) + } +} diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index 9d65bdbe0..6f501e0c6 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -59,6 +59,7 @@ func init() { // Client struct manage the life circle of background service type Client struct { cfgFile string + stateFile string recorder *peer.Status ctxCancel context.CancelFunc ctxCancelLock *sync.Mutex @@ -73,9 +74,10 @@ type Client struct { } // NewClient instantiate a new Client -func NewClient(cfgFile, deviceName string, osVersion string, osName string, networkChangeListener NetworkChangeListener, dnsManager DnsManager) *Client { +func NewClient(cfgFile, stateFile, deviceName string, osVersion string, osName string, networkChangeListener NetworkChangeListener, dnsManager DnsManager) *Client { return &Client{ cfgFile: cfgFile, + stateFile: stateFile, deviceName: deviceName, osName: osName, osVersion: osVersion, @@ -91,7 +93,8 @@ func (c *Client) Run(fd int32, interfaceName string) error { log.Infof("Starting NetBird client") log.Debugf("Tunnel uses interface: %s", interfaceName) cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{ - ConfigPath: c.cfgFile, + ConfigPath: c.cfgFile, + StateFilePath: c.stateFile, }) if err != nil { return err @@ -124,7 +127,7 @@ func (c *Client) Run(fd int32, interfaceName string) error { cfg.WgIface = interfaceName c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) - return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager) + return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile) } // Stop the internal client and free the resources diff --git a/client/ios/NetBirdSDK/preferences.go b/client/ios/NetBirdSDK/preferences.go index b78146679..5a0abd9a7 100644 --- a/client/ios/NetBirdSDK/preferences.go +++ b/client/ios/NetBirdSDK/preferences.go @@ -10,9 +10,10 @@ type Preferences struct { } // NewPreferences create new Preferences instance -func NewPreferences(configPath string) *Preferences { +func NewPreferences(configPath string, stateFilePath string) *Preferences { ci := internal.ConfigInput{ - ConfigPath: configPath, + ConfigPath: configPath, + StateFilePath: stateFilePath, } return &Preferences{ci} } diff --git a/client/ios/NetBirdSDK/preferences_test.go b/client/ios/NetBirdSDK/preferences_test.go index aa6a475ae..7e5325a00 100644 --- a/client/ios/NetBirdSDK/preferences_test.go +++ b/client/ios/NetBirdSDK/preferences_test.go @@ -9,7 +9,8 @@ import ( func TestPreferences_DefaultValues(t *testing.T) { cfgFile := filepath.Join(t.TempDir(), "netbird.json") - p := NewPreferences(cfgFile) + stateFile := filepath.Join(t.TempDir(), "state.json") + p := NewPreferences(cfgFile, stateFile) defaultVar, err := p.GetAdminURL() if err != nil { t.Fatalf("failed to read default value: %s", err) @@ -42,7 +43,8 @@ func TestPreferences_DefaultValues(t *testing.T) { func TestPreferences_ReadUncommitedValues(t *testing.T) { exampleString := "exampleString" cfgFile := filepath.Join(t.TempDir(), "netbird.json") - p := NewPreferences(cfgFile) + stateFile := filepath.Join(t.TempDir(), "state.json") + p := NewPreferences(cfgFile, stateFile) p.SetAdminURL(exampleString) resp, err := p.GetAdminURL() @@ -79,7 +81,8 @@ func TestPreferences_Commit(t *testing.T) { exampleURL := "https://myurl.com:443" examplePresharedKey := "topsecret" cfgFile := filepath.Join(t.TempDir(), "netbird.json") - p := NewPreferences(cfgFile) + stateFile := filepath.Join(t.TempDir(), "state.json") + p := NewPreferences(cfgFile, stateFile) p.SetAdminURL(exampleURL) p.SetManagementURL(exampleURL) @@ -90,7 +93,7 @@ func TestPreferences_Commit(t *testing.T) { t.Fatalf("failed to save changes: %s", err) } - p = NewPreferences(cfgFile) + p = NewPreferences(cfgFile, stateFile) resp, err := p.GetAdminURL() if err != nil { t.Fatalf("failed to read admin url: %s", err) diff --git a/management/cmd/management.go b/management/cmd/management.go index 719d1a78c..bfa158c5b 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -42,6 +42,7 @@ import ( nbContext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" httpapi "github.com/netbirdio/netbird/management/server/http" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/metrics" @@ -257,7 +258,7 @@ var ( return fmt.Errorf("failed creating JWT validator: %v", err) } - httpAPIAuthCfg := httpapi.AuthCfg{ + httpAPIAuthCfg := configs.AuthCfg{ Issuer: config.HttpConfig.AuthIssuer, Audience: config.HttpConfig.AuthAudience, UserIDClaim: config.HttpConfig.AuthUserIDClaim, diff --git a/management/server/http/configs/auth.go b/management/server/http/configs/auth.go new file mode 100644 index 000000000..aa91fa55b --- /dev/null +++ b/management/server/http/configs/auth.go @@ -0,0 +1,9 @@ +package configs + +// AuthCfg contains parameters for authentication middleware +type AuthCfg struct { + Issuer string + Audience string + UserIDClaim string + KeysLocation string +} diff --git a/management/server/http/handler.go b/management/server/http/handler.go index c3928bff6..373aa4dd7 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -12,6 +12,16 @@ import ( s "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" + "github.com/netbirdio/netbird/management/server/http/configs" + "github.com/netbirdio/netbird/management/server/http/handlers/accounts" + "github.com/netbirdio/netbird/management/server/http/handlers/dns" + "github.com/netbirdio/netbird/management/server/http/handlers/events" + "github.com/netbirdio/netbird/management/server/http/handlers/groups" + "github.com/netbirdio/netbird/management/server/http/handlers/peers" + "github.com/netbirdio/netbird/management/server/http/handlers/policies" + "github.com/netbirdio/netbird/management/server/http/handlers/routes" + "github.com/netbirdio/netbird/management/server/http/handlers/setup_keys" + "github.com/netbirdio/netbird/management/server/http/handlers/users" "github.com/netbirdio/netbird/management/server/http/middleware" "github.com/netbirdio/netbird/management/server/integrated_validator" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -20,27 +30,15 @@ import ( const apiPrefix = "/api" -// AuthCfg contains parameters for authentication middleware -type AuthCfg struct { - Issuer string - Audience string - UserIDClaim string - KeysLocation string -} - type apiHandler struct { Router *mux.Router AccountManager s.AccountManager geolocationManager *geolocation.Geolocation - AuthCfg AuthCfg -} - -// EmptyObject is an empty struct used to return empty JSON object -type emptyObject struct { + AuthCfg configs.AuthCfg } // APIHandler creates the Management service HTTP API handler registering all the available endpoints. -func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) { +func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg configs.AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) { claimsExtractor := jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), @@ -86,122 +84,15 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa return nil, fmt.Errorf("register integrations endpoints: %w", err) } - api.addAccountsEndpoint() - api.addPeersEndpoint() - api.addUsersEndpoint() - api.addUsersTokensEndpoint() - api.addSetupKeysEndpoint() - api.addPoliciesEndpoint() - api.addGroupsEndpoint() - api.addRoutesEndpoint() - api.addDNSNameserversEndpoint() - api.addDNSSettingEndpoint() - api.addEventsEndpoint() - api.addPostureCheckEndpoint() - api.addLocationsEndpoint() + accounts.AddEndpoints(api.AccountManager, authCfg, router) + peers.AddEndpoints(api.AccountManager, authCfg, router) + users.AddEndpoints(api.AccountManager, authCfg, router) + setup_keys.AddEndpoints(api.AccountManager, authCfg, router) + policies.AddEndpoints(api.AccountManager, api.geolocationManager, authCfg, router) + groups.AddEndpoints(api.AccountManager, authCfg, router) + routes.AddEndpoints(api.AccountManager, authCfg, router) + dns.AddEndpoints(api.AccountManager, authCfg, router) + events.AddEndpoints(api.AccountManager, authCfg, router) return rootRouter, nil } - -func (apiHandler *apiHandler) addAccountsEndpoint() { - accountsHandler := NewAccountsHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/accounts/{accountId}", accountsHandler.UpdateAccount).Methods("PUT", "OPTIONS") - apiHandler.Router.HandleFunc("/accounts/{accountId}", accountsHandler.DeleteAccount).Methods("DELETE", "OPTIONS") - apiHandler.Router.HandleFunc("/accounts", accountsHandler.GetAllAccounts).Methods("GET", "OPTIONS") -} - -func (apiHandler *apiHandler) addPeersEndpoint() { - peersHandler := NewPeersHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer). - Methods("GET", "PUT", "DELETE", "OPTIONS") - apiHandler.Router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS") -} - -func (apiHandler *apiHandler) addUsersEndpoint() { - userHandler := NewUsersHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/users", userHandler.GetAllUsers).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/users/{userId}", userHandler.UpdateUser).Methods("PUT", "OPTIONS") - apiHandler.Router.HandleFunc("/users/{userId}", userHandler.DeleteUser).Methods("DELETE", "OPTIONS") - apiHandler.Router.HandleFunc("/users", userHandler.CreateUser).Methods("POST", "OPTIONS") - apiHandler.Router.HandleFunc("/users/{userId}/invite", userHandler.InviteUser).Methods("POST", "OPTIONS") -} - -func (apiHandler *apiHandler) addUsersTokensEndpoint() { - tokenHandler := NewPATsHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/users/{userId}/tokens", tokenHandler.GetAllTokens).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/users/{userId}/tokens", tokenHandler.CreateToken).Methods("POST", "OPTIONS") - apiHandler.Router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.GetToken).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.DeleteToken).Methods("DELETE", "OPTIONS") -} - -func (apiHandler *apiHandler) addSetupKeysEndpoint() { - keysHandler := NewSetupKeysHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/setup-keys", keysHandler.GetAllSetupKeys).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/setup-keys", keysHandler.CreateSetupKey).Methods("POST", "OPTIONS") - apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.GetSetupKey).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.UpdateSetupKey).Methods("PUT", "OPTIONS") - apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.DeleteSetupKey).Methods("DELETE", "OPTIONS") -} - -func (apiHandler *apiHandler) addPoliciesEndpoint() { - policiesHandler := NewPoliciesHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/policies", policiesHandler.GetAllPolicies).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/policies", policiesHandler.CreatePolicy).Methods("POST", "OPTIONS") - apiHandler.Router.HandleFunc("/policies/{policyId}", policiesHandler.UpdatePolicy).Methods("PUT", "OPTIONS") - apiHandler.Router.HandleFunc("/policies/{policyId}", policiesHandler.GetPolicy).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/policies/{policyId}", policiesHandler.DeletePolicy).Methods("DELETE", "OPTIONS") -} - -func (apiHandler *apiHandler) addGroupsEndpoint() { - groupsHandler := NewGroupsHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/groups", groupsHandler.GetAllGroups).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/groups", groupsHandler.CreateGroup).Methods("POST", "OPTIONS") - apiHandler.Router.HandleFunc("/groups/{groupId}", groupsHandler.UpdateGroup).Methods("PUT", "OPTIONS") - apiHandler.Router.HandleFunc("/groups/{groupId}", groupsHandler.GetGroup).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/groups/{groupId}", groupsHandler.DeleteGroup).Methods("DELETE", "OPTIONS") -} - -func (apiHandler *apiHandler) addRoutesEndpoint() { - routesHandler := NewRoutesHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/routes", routesHandler.GetAllRoutes).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/routes", routesHandler.CreateRoute).Methods("POST", "OPTIONS") - apiHandler.Router.HandleFunc("/routes/{routeId}", routesHandler.UpdateRoute).Methods("PUT", "OPTIONS") - apiHandler.Router.HandleFunc("/routes/{routeId}", routesHandler.GetRoute).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/routes/{routeId}", routesHandler.DeleteRoute).Methods("DELETE", "OPTIONS") -} - -func (apiHandler *apiHandler) addDNSNameserversEndpoint() { - nameserversHandler := NewNameserversHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/dns/nameservers", nameserversHandler.GetAllNameservers).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/dns/nameservers", nameserversHandler.CreateNameserverGroup).Methods("POST", "OPTIONS") - apiHandler.Router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.UpdateNameserverGroup).Methods("PUT", "OPTIONS") - apiHandler.Router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.GetNameserverGroup).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.DeleteNameserverGroup).Methods("DELETE", "OPTIONS") -} - -func (apiHandler *apiHandler) addDNSSettingEndpoint() { - dnsSettingsHandler := NewDNSSettingsHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/dns/settings", dnsSettingsHandler.GetDNSSettings).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/dns/settings", dnsSettingsHandler.UpdateDNSSettings).Methods("PUT", "OPTIONS") -} - -func (apiHandler *apiHandler) addEventsEndpoint() { - eventsHandler := NewEventsHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/events", eventsHandler.GetAllEvents).Methods("GET", "OPTIONS") -} - -func (apiHandler *apiHandler) addPostureCheckEndpoint() { - postureCheckHandler := NewPostureChecksHandler(apiHandler.AccountManager, apiHandler.geolocationManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/posture-checks", postureCheckHandler.GetAllPostureChecks).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/posture-checks", postureCheckHandler.CreatePostureCheck).Methods("POST", "OPTIONS") - apiHandler.Router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.UpdatePostureCheck).Methods("PUT", "OPTIONS") - apiHandler.Router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.GetPostureCheck).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.DeletePostureCheck).Methods("DELETE", "OPTIONS") -} - -func (apiHandler *apiHandler) addLocationsEndpoint() { - locationHandler := NewGeolocationsHandlerHandler(apiHandler.AccountManager, apiHandler.geolocationManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/locations/countries", locationHandler.GetAllCountries).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/locations/countries/{country}/cities", locationHandler.GetCitiesByCountry).Methods("GET", "OPTIONS") -} diff --git a/management/server/http/accounts_handler.go b/management/server/http/handlers/accounts/accounts_handler.go similarity index 79% rename from management/server/http/accounts_handler.go rename to management/server/http/handlers/accounts/accounts_handler.go index 4baf9c692..c95207777 100644 --- a/management/server/http/accounts_handler.go +++ b/management/server/http/handlers/accounts/accounts_handler.go @@ -1,4 +1,4 @@ -package http +package accounts import ( "encoding/json" @@ -10,20 +10,28 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" ) -// AccountsHandler is a handler that handles the server.Account HTTP endpoints -type AccountsHandler struct { +// handler is a handler that handles the server.Account HTTP endpoints +type handler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewAccountsHandler creates a new AccountsHandler HTTP handler -func NewAccountsHandler(accountManager server.AccountManager, authCfg AuthCfg) *AccountsHandler { - return &AccountsHandler{ +func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + accountsHandler := newHandler(accountManager, authCfg) + router.HandleFunc("/accounts/{accountId}", accountsHandler.updateAccount).Methods("PUT", "OPTIONS") + router.HandleFunc("/accounts/{accountId}", accountsHandler.deleteAccount).Methods("DELETE", "OPTIONS") + router.HandleFunc("/accounts", accountsHandler.getAllAccounts).Methods("GET", "OPTIONS") +} + +// newHandler creates a new handler HTTP handler +func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { + return &handler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -32,8 +40,8 @@ func NewAccountsHandler(accountManager server.AccountManager, authCfg AuthCfg) * } } -// GetAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account. -func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) { +// getAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account. +func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -51,8 +59,8 @@ func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) util.WriteJSONObject(r.Context(), w, []*api.Account{resp}) } -// UpdateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings) -func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) { +// updateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings) +func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) _, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -111,8 +119,8 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) util.WriteJSONObject(r.Context(), w, &resp) } -// DeleteAccount is a HTTP DELETE handler to delete an account -func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) { +// deleteAccount is a HTTP DELETE handler to delete an account +func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) vars := mux.Vars(r) targetAccountID := vars["accountId"] @@ -127,7 +135,7 @@ func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } func toAccountResponse(accountID string, settings *server.Settings) *api.Account { diff --git a/management/server/http/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go similarity index 97% rename from management/server/http/accounts_handler_test.go rename to management/server/http/handlers/accounts/accounts_handler_test.go index cacb3d430..9d7e8a84d 100644 --- a/management/server/http/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -1,4 +1,4 @@ -package http +package accounts import ( "bytes" @@ -20,8 +20,8 @@ import ( "github.com/netbirdio/netbird/management/server/status" ) -func initAccountsTestData(account *server.Account, admin *server.User) *AccountsHandler { - return &AccountsHandler{ +func initAccountsTestData(account *server.Account, admin *server.User) *handler { + return &handler{ accountManager: &mock_server.MockAccountManager{ GetAccountIDFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return account.Id, admin.Id, nil @@ -89,7 +89,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { requestBody io.Reader }{ { - name: "GetAllAccounts OK", + name: "getAllAccounts OK", expectedBody: true, requestType: http.MethodGet, requestPath: "/api/accounts", @@ -189,8 +189,8 @@ func TestAccounts_AccountsHandler(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/accounts", handler.GetAllAccounts).Methods("GET") - router.HandleFunc("/api/accounts/{accountId}", handler.UpdateAccount).Methods("PUT") + router.HandleFunc("/api/accounts", handler.getAllAccounts).Methods("GET") + router.HandleFunc("/api/accounts/{accountId}", handler.updateAccount).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/dns_settings_handler.go b/management/server/http/handlers/dns/dns_settings_handler.go similarity index 62% rename from management/server/http/dns_settings_handler.go rename to management/server/http/handlers/dns/dns_settings_handler.go index 13c2101a7..7dd8c1fc1 100644 --- a/management/server/http/dns_settings_handler.go +++ b/management/server/http/handlers/dns/dns_settings_handler.go @@ -1,26 +1,39 @@ -package http +package dns import ( "encoding/json" "net/http" + "github.com/gorilla/mux" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" ) -// DNSSettingsHandler is a handler that returns the DNS settings of the account -type DNSSettingsHandler struct { +// dnsSettingsHandler is a handler that returns the DNS settings of the account +type dnsSettingsHandler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewDNSSettingsHandler returns a new instance of DNSSettingsHandler handler -func NewDNSSettingsHandler(accountManager server.AccountManager, authCfg AuthCfg) *DNSSettingsHandler { - return &DNSSettingsHandler{ +func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + addDNSSettingEndpoint(accountManager, authCfg, router) + addDNSNameserversEndpoint(accountManager, authCfg, router) +} + +func addDNSSettingEndpoint(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + dnsSettingsHandler := newDNSSettingsHandler(accountManager, authCfg) + router.HandleFunc("/dns/settings", dnsSettingsHandler.getDNSSettings).Methods("GET", "OPTIONS") + router.HandleFunc("/dns/settings", dnsSettingsHandler.updateDNSSettings).Methods("PUT", "OPTIONS") +} + +// newDNSSettingsHandler returns a new instance of dnsSettingsHandler handler +func newDNSSettingsHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *dnsSettingsHandler { + return &dnsSettingsHandler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -29,8 +42,8 @@ func NewDNSSettingsHandler(accountManager server.AccountManager, authCfg AuthCfg } } -// GetDNSSettings returns the DNS settings for the account -func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Request) { +// getDNSSettings returns the DNS settings for the account +func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -52,8 +65,8 @@ func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Reque util.WriteJSONObject(r.Context(), w, apiDNSSettings) } -// UpdateDNSSettings handles update to DNS settings of an account -func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) { +// updateDNSSettings handles update to DNS settings of an account +func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { diff --git a/management/server/http/dns_settings_handler_test.go b/management/server/http/handlers/dns/dns_settings_handler_test.go similarity index 94% rename from management/server/http/dns_settings_handler_test.go rename to management/server/http/handlers/dns/dns_settings_handler_test.go index 8baea7b15..a64e3fd83 100644 --- a/management/server/http/dns_settings_handler_test.go +++ b/management/server/http/handlers/dns/dns_settings_handler_test.go @@ -1,4 +1,4 @@ -package http +package dns import ( "bytes" @@ -40,8 +40,8 @@ var testingDNSSettingsAccount = &server.Account{ DNSSettings: baseExistingDNSSettings, } -func initDNSSettingsTestData() *DNSSettingsHandler { - return &DNSSettingsHandler{ +func initDNSSettingsTestData() *dnsSettingsHandler { + return &dnsSettingsHandler{ accountManager: &mock_server.MockAccountManager{ GetDNSSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.DNSSettings, error) { return &testingDNSSettingsAccount.DNSSettings, nil @@ -120,8 +120,8 @@ func TestDNSSettingsHandlers(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/dns/settings", p.GetDNSSettings).Methods("GET") - router.HandleFunc("/api/dns/settings", p.UpdateDNSSettings).Methods("PUT") + router.HandleFunc("/api/dns/settings", p.getDNSSettings).Methods("GET") + router.HandleFunc("/api/dns/settings", p.updateDNSSettings).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/nameservers_handler.go b/management/server/http/handlers/dns/nameservers_handler.go similarity index 77% rename from management/server/http/nameservers_handler.go rename to management/server/http/handlers/dns/nameservers_handler.go index e7a2bc2ae..09047e231 100644 --- a/management/server/http/nameservers_handler.go +++ b/management/server/http/handlers/dns/nameservers_handler.go @@ -1,4 +1,4 @@ -package http +package dns import ( "encoding/json" @@ -11,20 +11,30 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" ) -// NameserversHandler is the nameserver group handler of the account -type NameserversHandler struct { +// nameserversHandler is the nameserver group handler of the account +type nameserversHandler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewNameserversHandler returns a new instance of NameserversHandler handler -func NewNameserversHandler(accountManager server.AccountManager, authCfg AuthCfg) *NameserversHandler { - return &NameserversHandler{ +func addDNSNameserversEndpoint(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + nameserversHandler := newNameserversHandler(accountManager, authCfg) + router.HandleFunc("/dns/nameservers", nameserversHandler.getAllNameservers).Methods("GET", "OPTIONS") + router.HandleFunc("/dns/nameservers", nameserversHandler.createNameserverGroup).Methods("POST", "OPTIONS") + router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.updateNameserverGroup).Methods("PUT", "OPTIONS") + router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.getNameserverGroup).Methods("GET", "OPTIONS") + router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.deleteNameserverGroup).Methods("DELETE", "OPTIONS") +} + +// newNameserversHandler returns a new instance of nameserversHandler handler +func newNameserversHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *nameserversHandler { + return &nameserversHandler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -33,8 +43,8 @@ func NewNameserversHandler(accountManager server.AccountManager, authCfg AuthCfg } } -// GetAllNameservers returns the list of nameserver groups for the account -func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Request) { +// getAllNameservers returns the list of nameserver groups for the account +func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -57,8 +67,8 @@ func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Re util.WriteJSONObject(r.Context(), w, apiNameservers) } -// CreateNameserverGroup handles nameserver group creation request -func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *http.Request) { +// createNameserverGroup handles nameserver group creation request +func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -90,8 +100,8 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt util.WriteJSONObject(r.Context(), w, &resp) } -// UpdateNameserverGroup handles update to a nameserver group identified by a given ID -func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *http.Request) { +// updateNameserverGroup handles update to a nameserver group identified by a given ID +func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -141,8 +151,8 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt util.WriteJSONObject(r.Context(), w, &resp) } -// DeleteNameserverGroup handles nameserver group deletion request -func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *http.Request) { +// deleteNameserverGroup handles nameserver group deletion request +func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -162,11 +172,11 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -// GetNameserverGroup handles a nameserver group Get request identified by ID -func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.Request) { +// getNameserverGroup handles a nameserver group Get request identified by ID +func (h *nameserversHandler) getNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { diff --git a/management/server/http/nameservers_handler_test.go b/management/server/http/handlers/dns/nameservers_handler_test.go similarity index 95% rename from management/server/http/nameservers_handler_test.go rename to management/server/http/handlers/dns/nameservers_handler_test.go index 98c2e402d..c6561e4d8 100644 --- a/management/server/http/nameservers_handler_test.go +++ b/management/server/http/handlers/dns/nameservers_handler_test.go @@ -1,4 +1,4 @@ -package http +package dns import ( "bytes" @@ -50,8 +50,8 @@ var baseExistingNSGroup = &nbdns.NameServerGroup{ Enabled: true, } -func initNameserversTestData() *NameserversHandler { - return &NameserversHandler{ +func initNameserversTestData() *nameserversHandler { + return &nameserversHandler{ accountManager: &mock_server.MockAccountManager{ GetNameServerGroupFunc: func(_ context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { if nsGroupID == existingNSGroupID { @@ -206,10 +206,10 @@ func TestNameserversHandlers(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.GetNameserverGroup).Methods("GET") - router.HandleFunc("/api/dns/nameservers", p.CreateNameserverGroup).Methods("POST") - router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.DeleteNameserverGroup).Methods("DELETE") - router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.UpdateNameserverGroup).Methods("PUT") + router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.getNameserverGroup).Methods("GET") + router.HandleFunc("/api/dns/nameservers", p.createNameserverGroup).Methods("POST") + router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.deleteNameserverGroup).Methods("DELETE") + router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.updateNameserverGroup).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/events_handler.go b/management/server/http/handlers/events/events_handler.go similarity index 79% rename from management/server/http/events_handler.go rename to management/server/http/handlers/events/events_handler.go index ee0c63f28..62da59535 100644 --- a/management/server/http/events_handler.go +++ b/management/server/http/handlers/events/events_handler.go @@ -1,28 +1,35 @@ -package http +package events import ( "context" "fmt" "net/http" + "github.com/gorilla/mux" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" ) -// EventsHandler HTTP handler -type EventsHandler struct { +// handler HTTP handler +type handler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewEventsHandler creates a new EventsHandler HTTP handler -func NewEventsHandler(accountManager server.AccountManager, authCfg AuthCfg) *EventsHandler { - return &EventsHandler{ +func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + eventsHandler := newHandler(accountManager, authCfg) + router.HandleFunc("/events", eventsHandler.getAllEvents).Methods("GET", "OPTIONS") +} + +// newHandler creates a new events handler +func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { + return &handler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -31,8 +38,8 @@ func NewEventsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ev } } -// GetAllEvents list of the given account -func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) { +// getAllEvents list of the given account +func (h *handler) getAllEvents(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -60,7 +67,7 @@ func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, events) } -func (h *EventsHandler) fillEventsWithUserInfo(ctx context.Context, events []*api.Event, accountId, userId string) error { +func (h *handler) fillEventsWithUserInfo(ctx context.Context, events []*api.Event, accountId, userId string) error { // build email, name maps based on users userInfos, err := h.accountManager.GetUsersFromAccount(ctx, accountId, userId) if err != nil { diff --git a/management/server/http/events_handler_test.go b/management/server/http/handlers/events/events_handler_test.go similarity index 97% rename from management/server/http/events_handler_test.go rename to management/server/http/handlers/events/events_handler_test.go index e525cf2ee..6af2e5346 100644 --- a/management/server/http/events_handler_test.go +++ b/management/server/http/handlers/events/events_handler_test.go @@ -1,4 +1,4 @@ -package http +package events import ( "context" @@ -20,8 +20,8 @@ import ( "github.com/netbirdio/netbird/management/server/mock_server" ) -func initEventsTestData(account string, events ...*activity.Event) *EventsHandler { - return &EventsHandler{ +func initEventsTestData(account string, events ...*activity.Event) *handler { + return &handler{ accountManager: &mock_server.MockAccountManager{ GetEventsFunc: func(_ context.Context, accountID, userID string) ([]*activity.Event, error) { if accountID == account { @@ -183,7 +183,7 @@ func TestEvents_GetEvents(t *testing.T) { requestBody io.Reader }{ { - name: "GetAllEvents OK", + name: "getAllEvents OK", expectedBody: true, requestType: http.MethodGet, requestPath: "/api/events/", @@ -201,7 +201,7 @@ func TestEvents_GetEvents(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/events/", handler.GetAllEvents).Methods("GET") + router.HandleFunc("/api/events/", handler.getAllEvents).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/groups_handler.go b/management/server/http/handlers/groups/groups_handler.go similarity index 81% rename from management/server/http/groups_handler.go rename to management/server/http/handlers/groups/groups_handler.go index f369d1a00..e60529cec 100644 --- a/management/server/http/groups_handler.go +++ b/management/server/http/handlers/groups/groups_handler.go @@ -1,13 +1,15 @@ -package http +package groups import ( "encoding/json" "net/http" "github.com/gorilla/mux" - nbpeer "github.com/netbirdio/netbird/management/server/peer" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server/http/configs" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server" nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" @@ -16,15 +18,24 @@ import ( "github.com/netbirdio/netbird/management/server/status" ) -// GroupsHandler is a handler that returns groups of the account -type GroupsHandler struct { +// handler is a handler that returns groups of the account +type handler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewGroupsHandler creates a new GroupsHandler HTTP handler -func NewGroupsHandler(accountManager server.AccountManager, authCfg AuthCfg) *GroupsHandler { - return &GroupsHandler{ +func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + groupsHandler := newHandler(accountManager, authCfg) + router.HandleFunc("/groups", groupsHandler.getAllGroups).Methods("GET", "OPTIONS") + router.HandleFunc("/groups", groupsHandler.createGroup).Methods("POST", "OPTIONS") + router.HandleFunc("/groups/{groupId}", groupsHandler.updateGroup).Methods("PUT", "OPTIONS") + router.HandleFunc("/groups/{groupId}", groupsHandler.getGroup).Methods("GET", "OPTIONS") + router.HandleFunc("/groups/{groupId}", groupsHandler.deleteGroup).Methods("DELETE", "OPTIONS") +} + +// newHandler creates a new groups handler +func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { + return &handler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -33,8 +44,8 @@ func NewGroupsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Gr } } -// GetAllGroups list for the account -func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) { +// getAllGroups list for the account +func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -63,8 +74,8 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, groupsResponse) } -// UpdateGroup handles update to a group identified by a given ID -func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { +// updateGroup handles update to a group identified by a given ID +func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -141,8 +152,8 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group)) } -// CreateGroup handles group creation request -func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { +// createGroup handles group creation request +func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -189,8 +200,8 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group)) } -// DeleteGroup handles group deletion request -func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) { +// deleteGroup handles group deletion request +func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -215,11 +226,11 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -// GetGroup returns a group -func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) { +// getGroup returns a group +func (h *handler) getGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { diff --git a/management/server/http/groups_handler_test.go b/management/server/http/handlers/groups/groups_handler_test.go similarity index 95% rename from management/server/http/groups_handler_test.go rename to management/server/http/handlers/groups/groups_handler_test.go index 7f3c81f18..089c1a40f 100644 --- a/management/server/http/groups_handler_test.go +++ b/management/server/http/handlers/groups/groups_handler_test.go @@ -1,4 +1,4 @@ -package http +package groups import ( "bytes" @@ -31,8 +31,8 @@ var TestPeers = map[string]*nbpeer.Peer{ "B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")}, } -func initGroupTestData(initGroups ...*nbgroup.Group) *GroupsHandler { - return &GroupsHandler{ +func initGroupTestData(initGroups ...*nbgroup.Group) *handler { + return &handler{ accountManager: &mock_server.MockAccountManager{ SaveGroupFunc: func(_ context.Context, accountID, userID string, group *nbgroup.Group) error { if !strings.HasPrefix(group.ID, "id-") { @@ -106,14 +106,14 @@ func TestGetGroup(t *testing.T) { requestBody io.Reader }{ { - name: "GetGroup OK", + name: "getGroup OK", expectedBody: true, requestType: http.MethodGet, requestPath: "/api/groups/idofthegroup", expectedStatus: http.StatusOK, }, { - name: "GetGroup not found", + name: "getGroup not found", requestType: http.MethodGet, requestPath: "/api/groups/notexists", expectedStatus: http.StatusNotFound, @@ -133,7 +133,7 @@ func TestGetGroup(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/groups/{groupId}", p.GetGroup).Methods("GET") + router.HandleFunc("/api/groups/{groupId}", p.getGroup).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -254,8 +254,8 @@ func TestWriteGroup(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/groups", p.CreateGroup).Methods("POST") - router.HandleFunc("/api/groups/{groupId}", p.UpdateGroup).Methods("PUT") + router.HandleFunc("/api/groups", p.createGroup).Methods("POST") + router.HandleFunc("/api/groups/{groupId}", p.updateGroup).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -331,7 +331,7 @@ func TestDeleteGroup(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) router := mux.NewRouter() - router.HandleFunc("/api/groups/{groupId}", p.DeleteGroup).Methods("DELETE") + router.HandleFunc("/api/groups/{groupId}", p.deleteGroup).Methods("DELETE") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go similarity index 88% rename from management/server/http/peers_handler.go rename to management/server/http/handlers/peers/peers_handler.go index f5027cd77..c53cbc038 100644 --- a/management/server/http/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -1,4 +1,4 @@ -package http +package peers import ( "context" @@ -12,21 +12,30 @@ import ( "github.com/netbirdio/netbird/management/server" nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" ) -// PeersHandler is a handler that returns peers of the account -type PeersHandler struct { +// Handler is a handler that returns peers of the account +type Handler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewPeersHandler creates a new PeersHandler HTTP handler -func NewPeersHandler(accountManager server.AccountManager, authCfg AuthCfg) *PeersHandler { - return &PeersHandler{ +func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + peersHandler := NewHandler(accountManager, authCfg) + router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS") + router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer). + Methods("GET", "PUT", "DELETE", "OPTIONS") + router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS") +} + +// NewHandler creates a new peers Handler +func NewHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *Handler { + return &Handler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -35,7 +44,7 @@ func NewPeersHandler(accountManager server.AccountManager, authCfg AuthCfg) *Pee } } -func (h *PeersHandler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) { +func (h *Handler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) { peerToReturn := peer.Copy() if peer.Status.Connected { // Although we have online status in store we do not yet have an updated channel so have to show it as disconnected @@ -48,7 +57,7 @@ func (h *PeersHandler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) return peerToReturn, nil } -func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, peerID, userID string, w http.ResponseWriter) { +func (h *Handler) getPeer(ctx context.Context, account *server.Account, peerID, userID string, w http.ResponseWriter) { peer, err := h.accountManager.GetPeer(ctx, account.Id, peerID, userID) if err != nil { util.WriteError(ctx, err, w) @@ -75,7 +84,7 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid)) } -func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) { +func (h *Handler) updatePeer(ctx context.Context, account *server.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) { req := &api.PeerRequest{} err := json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -120,18 +129,18 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, valid)) } -func (h *PeersHandler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) { +func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) { err := h.accountManager.DeletePeer(ctx, accountID, peerID, userID) if err != nil { log.WithContext(ctx).Errorf("failed to delete peer: %v", err) util.WriteError(ctx, err, w) return } - util.WriteJSONObject(ctx, w, emptyObject{}) + util.WriteJSONObject(ctx, w, util.EmptyObject{}) } // HandlePeer handles all peer requests for GET, PUT and DELETE operations -func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) { +func (h *Handler) HandlePeer(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -168,7 +177,7 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) { } // GetAllPeers returns a list of all peers associated with a provided account -func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { +func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -219,7 +228,7 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, respBody) } -func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) { +func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) { for _, peer := range respBody { _, ok := approvedPeersMap[peer.Id] if !ok { @@ -229,7 +238,7 @@ func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approv } // GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network. -func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { +func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { diff --git a/management/server/http/peers_handler_test.go b/management/server/http/handlers/peers/peers_handler_test.go similarity index 99% rename from management/server/http/peers_handler_test.go rename to management/server/http/handlers/peers/peers_handler_test.go index dd49c03b8..3e3e39deb 100644 --- a/management/server/http/peers_handler_test.go +++ b/management/server/http/handlers/peers/peers_handler_test.go @@ -1,4 +1,4 @@ -package http +package peers import ( "bytes" @@ -38,8 +38,8 @@ const ( userIDKey ctxKey = "user_id" ) -func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { - return &PeersHandler{ +func initTestMetaData(peers ...*nbpeer.Peer) *Handler { + return &Handler{ accountManager: &mock_server.MockAccountManager{ UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { var p *nbpeer.Peer diff --git a/management/server/http/geolocation_handler_test.go b/management/server/http/handlers/policies/geolocation_handler_test.go similarity index 94% rename from management/server/http/geolocation_handler_test.go rename to management/server/http/handlers/policies/geolocation_handler_test.go index 19c916dd2..002b914ef 100644 --- a/management/server/http/geolocation_handler_test.go +++ b/management/server/http/handlers/policies/geolocation_handler_test.go @@ -1,4 +1,4 @@ -package http +package policies import ( "context" @@ -11,9 +11,9 @@ import ( "testing" "github.com/gorilla/mux" - "github.com/netbirdio/netbird/management/server" "github.com/stretchr/testify/assert" + "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -21,12 +21,12 @@ import ( "github.com/netbirdio/netbird/util" ) -func initGeolocationTestData(t *testing.T) *GeolocationsHandler { +func initGeolocationTestData(t *testing.T) *geolocationsHandler { t.Helper() var ( - mmdbPath = "../testdata/GeoLite2-City_20240305.mmdb" - geonamesdbPath = "../testdata/geonames_20240305.db" + mmdbPath = "../../../testdata/GeoLite2-City_20240305.mmdb" + geonamesdbPath = "../../../testdata/geonames_20240305.db" ) tempDir := t.TempDir() @@ -41,7 +41,7 @@ func initGeolocationTestData(t *testing.T) *GeolocationsHandler { assert.NoError(t, err) t.Cleanup(func() { _ = geo.Stop() }) - return &GeolocationsHandler{ + return &geolocationsHandler{ accountManager: &mock_server.MockAccountManager{ GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return claims.AccountId, claims.UserId, nil @@ -114,7 +114,7 @@ func TestGetCitiesByCountry(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) router := mux.NewRouter() - router.HandleFunc("/api/locations/countries/{country}/cities", geolocationHandler.GetCitiesByCountry).Methods("GET") + router.HandleFunc("/api/locations/countries/{country}/cities", geolocationHandler.getCitiesByCountry).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -202,7 +202,7 @@ func TestGetAllCountries(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) router := mux.NewRouter() - router.HandleFunc("/api/locations/countries", geolocationHandler.GetAllCountries).Methods("GET") + router.HandleFunc("/api/locations/countries", geolocationHandler.getAllCountries).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/geolocations_handler.go b/management/server/http/handlers/policies/geolocations_handler.go similarity index 72% rename from management/server/http/geolocations_handler.go rename to management/server/http/handlers/policies/geolocations_handler.go index 418228abf..e5bf3e695 100644 --- a/management/server/http/geolocations_handler.go +++ b/management/server/http/handlers/policies/geolocations_handler.go @@ -1,4 +1,4 @@ -package http +package policies import ( "net/http" @@ -9,6 +9,7 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" @@ -18,16 +19,22 @@ var ( countryCodeRegex = regexp.MustCompile("^[a-zA-Z]{2}$") ) -// GeolocationsHandler is a handler that returns locations. -type GeolocationsHandler struct { +// geolocationsHandler is a handler that returns locations. +type geolocationsHandler struct { accountManager server.AccountManager geolocationManager *geolocation.Geolocation claimsExtractor *jwtclaims.ClaimsExtractor } -// NewGeolocationsHandlerHandler creates a new Geolocations handler -func NewGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg AuthCfg) *GeolocationsHandler { - return &GeolocationsHandler{ +func addLocationsEndpoint(accountManager server.AccountManager, locationManager *geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { + locationHandler := newGeolocationsHandlerHandler(accountManager, locationManager, authCfg) + router.HandleFunc("/locations/countries", locationHandler.getAllCountries).Methods("GET", "OPTIONS") + router.HandleFunc("/locations/countries/{country}/cities", locationHandler.getCitiesByCountry).Methods("GET", "OPTIONS") +} + +// newGeolocationsHandlerHandler creates a new Geolocations handler +func newGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg configs.AuthCfg) *geolocationsHandler { + return &geolocationsHandler{ accountManager: accountManager, geolocationManager: geolocationManager, claimsExtractor: jwtclaims.NewClaimsExtractor( @@ -37,8 +44,8 @@ func NewGeolocationsHandlerHandler(accountManager server.AccountManager, geoloca } } -// GetAllCountries retrieves a list of all countries -func (l *GeolocationsHandler) GetAllCountries(w http.ResponseWriter, r *http.Request) { +// getAllCountries retrieves a list of all countries +func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Request) { if err := l.authenticateUser(r); err != nil { util.WriteError(r.Context(), err, w) return @@ -63,8 +70,8 @@ func (l *GeolocationsHandler) GetAllCountries(w http.ResponseWriter, r *http.Req util.WriteJSONObject(r.Context(), w, countries) } -// GetCitiesByCountry retrieves a list of cities based on the given country code -func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http.Request) { +// getCitiesByCountry retrieves a list of cities based on the given country code +func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.Request) { if err := l.authenticateUser(r); err != nil { util.WriteError(r.Context(), err, w) return @@ -96,7 +103,7 @@ func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http. util.WriteJSONObject(r.Context(), w, cities) } -func (l *GeolocationsHandler) authenticateUser(r *http.Request) error { +func (l *geolocationsHandler) authenticateUser(r *http.Request) error { claims := l.claimsExtractor.FromRequestContext(r) _, userID, err := l.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { diff --git a/management/server/http/policies_handler.go b/management/server/http/handlers/policies/policies_handler.go similarity index 84% rename from management/server/http/policies_handler.go rename to management/server/http/handlers/policies/policies_handler.go index 1497a4fea..89694cd8e 100644 --- a/management/server/http/policies_handler.go +++ b/management/server/http/handlers/policies/policies_handler.go @@ -1,4 +1,4 @@ -package http +package policies import ( "encoding/json" @@ -8,22 +8,34 @@ import ( "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/geolocation" nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" ) -// Policies is a handler that returns policy of the account -type Policies struct { +// handler is a handler that returns policy of the account +type handler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewPoliciesHandler creates a new Policies handler -func NewPoliciesHandler(accountManager server.AccountManager, authCfg AuthCfg) *Policies { - return &Policies{ +func AddEndpoints(accountManager server.AccountManager, locationManager *geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { + policiesHandler := newHandler(accountManager, authCfg) + router.HandleFunc("/policies", policiesHandler.getAllPolicies).Methods("GET", "OPTIONS") + router.HandleFunc("/policies", policiesHandler.createPolicy).Methods("POST", "OPTIONS") + router.HandleFunc("/policies/{policyId}", policiesHandler.updatePolicy).Methods("PUT", "OPTIONS") + router.HandleFunc("/policies/{policyId}", policiesHandler.getPolicy).Methods("GET", "OPTIONS") + router.HandleFunc("/policies/{policyId}", policiesHandler.deletePolicy).Methods("DELETE", "OPTIONS") + addPostureCheckEndpoint(accountManager, locationManager, authCfg, router) +} + +// newHandler creates a new policies handler +func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { + return &handler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -32,8 +44,8 @@ func NewPoliciesHandler(accountManager server.AccountManager, authCfg AuthCfg) * } } -// GetAllPolicies list for the account -func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) { +// getAllPolicies list for the account +func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -66,8 +78,8 @@ func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, policies) } -// UpdatePolicy handles update to a policy identified by a given ID -func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) { +// updatePolicy handles update to a policy identified by a given ID +func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -91,8 +103,8 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) { h.savePolicy(w, r, accountID, userID, policyID) } -// CreatePolicy handles policy creation request -func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) { +// createPolicy handles policy creation request +func (h *handler) createPolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -104,7 +116,7 @@ func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) { } // savePolicy handles policy creation and update -func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID string, userID string, policyID string) { +func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID string, userID string, policyID string) { var req api.PutApiPoliciesPolicyIdJSONRequestBody if err := json.NewDecoder(r.Body).Decode(&req); err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) @@ -251,8 +263,8 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID util.WriteJSONObject(r.Context(), w, resp) } -// DeletePolicy handles policy deletion request -func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) { +// deletePolicy handles policy deletion request +func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -272,11 +284,11 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -// GetPolicy handles a group Get request identified by ID -func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) { +// getPolicy handles a group Get request identified by ID +func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { diff --git a/management/server/http/policies_handler_test.go b/management/server/http/handlers/policies/policies_handler_test.go similarity index 95% rename from management/server/http/policies_handler_test.go rename to management/server/http/handlers/policies/policies_handler_test.go index ad01f70d1..af3412f89 100644 --- a/management/server/http/policies_handler_test.go +++ b/management/server/http/handlers/policies/policies_handler_test.go @@ -1,4 +1,4 @@ -package http +package policies import ( "bytes" @@ -24,12 +24,12 @@ import ( "github.com/netbirdio/netbird/management/server/mock_server" ) -func initPoliciesTestData(policies ...*server.Policy) *Policies { +func initPoliciesTestData(policies ...*server.Policy) *handler { testPolicies := make(map[string]*server.Policy, len(policies)) for _, policy := range policies { testPolicies[policy.ID] = policy } - return &Policies{ + return &handler{ accountManager: &mock_server.MockAccountManager{ GetPolicyFunc: func(_ context.Context, _, policyID, _ string) (*server.Policy, error) { policy, ok := testPolicies[policyID] @@ -91,14 +91,14 @@ func TestPoliciesGetPolicy(t *testing.T) { requestBody io.Reader }{ { - name: "GetPolicy OK", + name: "getPolicy OK", expectedBody: true, requestType: http.MethodGet, requestPath: "/api/policies/idofthepolicy", expectedStatus: http.StatusOK, }, { - name: "GetPolicy not found", + name: "getPolicy not found", requestType: http.MethodGet, requestPath: "/api/policies/notexists", expectedStatus: http.StatusNotFound, @@ -121,7 +121,7 @@ func TestPoliciesGetPolicy(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/policies/{policyId}", p.GetPolicy).Methods("GET") + router.HandleFunc("/api/policies/{policyId}", p.getPolicy).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -272,8 +272,8 @@ func TestPoliciesWritePolicy(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/policies", p.CreatePolicy).Methods("POST") - router.HandleFunc("/api/policies/{policyId}", p.UpdatePolicy).Methods("PUT") + router.HandleFunc("/api/policies", p.createPolicy).Methods("POST") + router.HandleFunc("/api/policies/{policyId}", p.updatePolicy).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/posture_checks_handler.go b/management/server/http/handlers/policies/posture_checks_handler.go similarity index 70% rename from management/server/http/posture_checks_handler.go rename to management/server/http/handlers/policies/posture_checks_handler.go index 2c8204292..44917605b 100644 --- a/management/server/http/posture_checks_handler.go +++ b/management/server/http/handlers/policies/posture_checks_handler.go @@ -1,4 +1,4 @@ -package http +package policies import ( "encoding/json" @@ -9,22 +9,33 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" ) -// PostureChecksHandler is a handler that returns posture checks of the account. -type PostureChecksHandler struct { +// postureChecksHandler is a handler that returns posture checks of the account. +type postureChecksHandler struct { accountManager server.AccountManager geolocationManager *geolocation.Geolocation claimsExtractor *jwtclaims.ClaimsExtractor } -// NewPostureChecksHandler creates a new PostureChecks handler -func NewPostureChecksHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg AuthCfg) *PostureChecksHandler { - return &PostureChecksHandler{ +func addPostureCheckEndpoint(accountManager server.AccountManager, locationManager *geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { + postureCheckHandler := newPostureChecksHandler(accountManager, locationManager, authCfg) + router.HandleFunc("/posture-checks", postureCheckHandler.getAllPostureChecks).Methods("GET", "OPTIONS") + router.HandleFunc("/posture-checks", postureCheckHandler.createPostureCheck).Methods("POST", "OPTIONS") + router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.updatePostureCheck).Methods("PUT", "OPTIONS") + router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.getPostureCheck).Methods("GET", "OPTIONS") + router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.deletePostureCheck).Methods("DELETE", "OPTIONS") + addLocationsEndpoint(accountManager, locationManager, authCfg, router) +} + +// newPostureChecksHandler creates a new PostureChecks handler +func newPostureChecksHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg configs.AuthCfg) *postureChecksHandler { + return &postureChecksHandler{ accountManager: accountManager, geolocationManager: geolocationManager, claimsExtractor: jwtclaims.NewClaimsExtractor( @@ -34,8 +45,8 @@ func NewPostureChecksHandler(accountManager server.AccountManager, geolocationMa } } -// GetAllPostureChecks list for the account -func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *http.Request) { +// getAllPostureChecks list for the account +func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -57,8 +68,8 @@ func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *htt util.WriteJSONObject(r.Context(), w, postureChecks) } -// UpdatePostureCheck handles update to a posture check identified by a given ID -func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http.Request) { +// updatePostureCheck handles update to a posture check identified by a given ID +func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -82,8 +93,8 @@ func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http p.savePostureChecks(w, r, accountID, userID, postureChecksID) } -// CreatePostureCheck handles posture check creation request -func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http.Request) { +// createPostureCheck handles posture check creation request +func (p *postureChecksHandler) createPostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -94,8 +105,8 @@ func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http p.savePostureChecks(w, r, accountID, userID, "") } -// GetPostureCheck handles a posture check Get request identified by ID -func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Request) { +// getPostureCheck handles a posture check Get request identified by ID +func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -119,8 +130,8 @@ func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Re util.WriteJSONObject(r.Context(), w, postureChecks.ToAPIResponse()) } -// DeletePostureCheck handles posture check deletion request -func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http.Request) { +// deletePostureCheck handles posture check deletion request +func (p *postureChecksHandler) deletePostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -140,11 +151,11 @@ func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } // savePostureChecks handles posture checks create and update -func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.Request, accountID, userID, postureChecksID string) { +func (p *postureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.Request, accountID, userID, postureChecksID string) { var ( err error req api.PostureCheckUpdate diff --git a/management/server/http/posture_checks_handler_test.go b/management/server/http/handlers/policies/posture_checks_handler_test.go similarity index 96% rename from management/server/http/posture_checks_handler_test.go rename to management/server/http/handlers/policies/posture_checks_handler_test.go index f400cec81..e9a539e45 100644 --- a/management/server/http/posture_checks_handler_test.go +++ b/management/server/http/handlers/policies/posture_checks_handler_test.go @@ -1,4 +1,4 @@ -package http +package policies import ( "bytes" @@ -25,13 +25,13 @@ import ( var berlin = "Berlin" var losAngeles = "Los Angeles" -func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksHandler { +func initPostureChecksTestData(postureChecks ...*posture.Checks) *postureChecksHandler { testPostureChecks := make(map[string]*posture.Checks, len(postureChecks)) for _, postureCheck := range postureChecks { testPostureChecks[postureCheck.ID] = postureCheck } - return &PostureChecksHandler{ + return &postureChecksHandler{ accountManager: &mock_server.MockAccountManager{ GetPostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) { p, ok := testPostureChecks[postureChecksID] @@ -147,35 +147,35 @@ func TestGetPostureCheck(t *testing.T) { requestBody io.Reader }{ { - name: "GetPostureCheck NBVersion OK", + name: "getPostureCheck NBVersion OK", expectedBody: true, id: postureCheck.ID, checkName: postureCheck.Name, expectedStatus: http.StatusOK, }, { - name: "GetPostureCheck OSVersion OK", + name: "getPostureCheck OSVersion OK", expectedBody: true, id: osPostureCheck.ID, checkName: osPostureCheck.Name, expectedStatus: http.StatusOK, }, { - name: "GetPostureCheck GeoLocation OK", + name: "getPostureCheck GeoLocation OK", expectedBody: true, id: geoPostureCheck.ID, checkName: geoPostureCheck.Name, expectedStatus: http.StatusOK, }, { - name: "GetPostureCheck PrivateNetwork OK", + name: "getPostureCheck PrivateNetwork OK", expectedBody: true, id: privateNetworkCheck.ID, checkName: privateNetworkCheck.Name, expectedStatus: http.StatusOK, }, { - name: "GetPostureCheck Not Found", + name: "getPostureCheck Not Found", id: "not-exists", expectedStatus: http.StatusNotFound, }, @@ -189,7 +189,7 @@ func TestGetPostureCheck(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/api/posture-checks/"+tc.id, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/posture-checks/{postureCheckId}", p.GetPostureCheck).Methods("GET") + router.HandleFunc("/api/posture-checks/{postureCheckId}", p.getPostureCheck).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -231,7 +231,7 @@ func TestPostureCheckUpdate(t *testing.T) { requestType string requestPath string requestBody io.Reader - setupHandlerFunc func(handler *PostureChecksHandler) + setupHandlerFunc func(handler *postureChecksHandler) }{ { name: "Create Posture Checks NB version", @@ -286,7 +286,7 @@ func TestPostureCheckUpdate(t *testing.T) { }, }, }, - setupHandlerFunc: func(handler *PostureChecksHandler) { + setupHandlerFunc: func(handler *postureChecksHandler) { handler.geolocationManager = nil }, }, @@ -427,7 +427,7 @@ func TestPostureCheckUpdate(t *testing.T) { }`)), expectedStatus: http.StatusPreconditionFailed, expectedBody: false, - setupHandlerFunc: func(handler *PostureChecksHandler) { + setupHandlerFunc: func(handler *postureChecksHandler) { handler.geolocationManager = nil }, }, @@ -614,7 +614,7 @@ func TestPostureCheckUpdate(t *testing.T) { }, }, }, - setupHandlerFunc: func(handler *PostureChecksHandler) { + setupHandlerFunc: func(handler *postureChecksHandler) { handler.geolocationManager = nil }, }, @@ -677,7 +677,7 @@ func TestPostureCheckUpdate(t *testing.T) { }`)), expectedStatus: http.StatusPreconditionFailed, expectedBody: false, - setupHandlerFunc: func(handler *PostureChecksHandler) { + setupHandlerFunc: func(handler *postureChecksHandler) { handler.geolocationManager = nil }, }, @@ -842,8 +842,8 @@ func TestPostureCheckUpdate(t *testing.T) { } router := mux.NewRouter() - router.HandleFunc("/api/posture-checks", defaultHandler.CreatePostureCheck).Methods("POST") - router.HandleFunc("/api/posture-checks/{postureCheckId}", defaultHandler.UpdatePostureCheck).Methods("PUT") + router.HandleFunc("/api/posture-checks", defaultHandler.createPostureCheck).Methods("POST") + router.HandleFunc("/api/posture-checks/{postureCheckId}", defaultHandler.updatePostureCheck).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/routes_handler.go b/management/server/http/handlers/routes/routes_handler.go similarity index 85% rename from management/server/http/routes_handler.go rename to management/server/http/handlers/routes/routes_handler.go index f44a164e2..9d420066c 100644 --- a/management/server/http/routes_handler.go +++ b/management/server/http/handlers/routes/routes_handler.go @@ -1,4 +1,4 @@ -package http +package routes import ( "encoding/json" @@ -14,6 +14,7 @@ import ( "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" @@ -23,15 +24,24 @@ import ( const maxDomains = 32 const failedToConvertRoute = "failed to convert route to response: %v" -// RoutesHandler is the routes handler of the account -type RoutesHandler struct { +// handler is the routes handler of the account +type handler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewRoutesHandler returns a new instance of RoutesHandler handler -func NewRoutesHandler(accountManager server.AccountManager, authCfg AuthCfg) *RoutesHandler { - return &RoutesHandler{ +func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + routesHandler := newHandler(accountManager, authCfg) + router.HandleFunc("/routes", routesHandler.getAllRoutes).Methods("GET", "OPTIONS") + router.HandleFunc("/routes", routesHandler.createRoute).Methods("POST", "OPTIONS") + router.HandleFunc("/routes/{routeId}", routesHandler.updateRoute).Methods("PUT", "OPTIONS") + router.HandleFunc("/routes/{routeId}", routesHandler.getRoute).Methods("GET", "OPTIONS") + router.HandleFunc("/routes/{routeId}", routesHandler.deleteRoute).Methods("DELETE", "OPTIONS") +} + +// newHandler returns a new instance of routes handler +func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { + return &handler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -40,8 +50,8 @@ func NewRoutesHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ro } } -// GetAllRoutes returns the list of routes for the account -func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) { +// getAllRoutes returns the list of routes for the account +func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -67,8 +77,8 @@ func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, apiRoutes) } -// CreateRoute handles route creation request -func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) { +// createRoute handles route creation request +func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -139,7 +149,7 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, routes) } -func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) error { +func (h *handler) validateRoute(req api.PostApiRoutesJSONRequestBody) error { if req.Network != nil && req.Domains != nil { return status.Errorf(status.InvalidArgument, "only one of 'network' or 'domains' should be provided") } @@ -164,8 +174,8 @@ func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) erro return nil } -// UpdateRoute handles update to a route identified by a given ID -func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { +// updateRoute handles update to a route identified by a given ID +func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -257,8 +267,8 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, routes) } -// DeleteRoute handles route deletion request -func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) { +// deleteRoute handles route deletion request +func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -278,11 +288,11 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -// GetRoute handles a route Get request identified by ID -func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) { +// getRoute handles a route Get request identified by ID +func (h *handler) getRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { diff --git a/management/server/http/routes_handler_test.go b/management/server/http/handlers/routes/routes_handler_test.go similarity index 98% rename from management/server/http/routes_handler_test.go rename to management/server/http/handlers/routes/routes_handler_test.go index 83bd7004d..a25c899c9 100644 --- a/management/server/http/routes_handler_test.go +++ b/management/server/http/handlers/routes/routes_handler_test.go @@ -1,4 +1,4 @@ -package http +package routes import ( "bytes" @@ -87,8 +87,8 @@ var testingAccount = &server.Account{ }, } -func initRoutesTestData() *RoutesHandler { - return &RoutesHandler{ +func initRoutesTestData() *handler { + return &handler{ accountManager: &mock_server.MockAccountManager{ GetRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) (*route.Route, error) { if routeID == existingRouteID { @@ -152,7 +152,7 @@ func initRoutesTestData() *RoutesHandler { return nil }, GetAccountIDFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) { - //return testingAccount, testingAccount.Users["test_user"], nil + // return testingAccount, testingAccount.Users["test_user"], nil return testingAccount.Id, testingAccount.Users["test_user"].Id, nil }, }, @@ -521,10 +521,10 @@ func TestRoutesHandlers(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/routes/{routeId}", p.GetRoute).Methods("GET") - router.HandleFunc("/api/routes/{routeId}", p.DeleteRoute).Methods("DELETE") - router.HandleFunc("/api/routes", p.CreateRoute).Methods("POST") - router.HandleFunc("/api/routes/{routeId}", p.UpdateRoute).Methods("PUT") + router.HandleFunc("/api/routes/{routeId}", p.getRoute).Methods("GET") + router.HandleFunc("/api/routes/{routeId}", p.deleteRoute).Methods("DELETE") + router.HandleFunc("/api/routes", p.createRoute).Methods("POST") + router.HandleFunc("/api/routes/{routeId}", p.updateRoute).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/setupkeys_handler.go b/management/server/http/handlers/setup_keys/setupkeys_handler.go similarity index 78% rename from management/server/http/setupkeys_handler.go rename to management/server/http/handlers/setup_keys/setupkeys_handler.go index 9ba5977bb..9432d5549 100644 --- a/management/server/http/setupkeys_handler.go +++ b/management/server/http/handlers/setup_keys/setupkeys_handler.go @@ -1,4 +1,4 @@ -package http +package setup_keys import ( "context" @@ -10,20 +10,30 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" ) -// SetupKeysHandler is a handler that returns a list of setup keys of the account -type SetupKeysHandler struct { +// handler is a handler that returns a list of setup keys of the account +type handler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewSetupKeysHandler creates a new SetupKeysHandler HTTP handler -func NewSetupKeysHandler(accountManager server.AccountManager, authCfg AuthCfg) *SetupKeysHandler { - return &SetupKeysHandler{ +func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + keysHandler := newHandler(accountManager, authCfg) + router.HandleFunc("/setup-keys", keysHandler.getAllSetupKeys).Methods("GET", "OPTIONS") + router.HandleFunc("/setup-keys", keysHandler.createSetupKey).Methods("POST", "OPTIONS") + router.HandleFunc("/setup-keys/{keyId}", keysHandler.getSetupKey).Methods("GET", "OPTIONS") + router.HandleFunc("/setup-keys/{keyId}", keysHandler.updateSetupKey).Methods("PUT", "OPTIONS") + router.HandleFunc("/setup-keys/{keyId}", keysHandler.deleteSetupKey).Methods("DELETE", "OPTIONS") +} + +// newHandler creates a new setup key handler +func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { + return &handler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -32,8 +42,8 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authCfg AuthCfg) } } -// CreateSetupKey is a POST requests that creates a new SetupKey -func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request) { +// createSetupKey is a POST requests that creates a new SetupKey +func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -89,8 +99,8 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request util.WriteJSONObject(r.Context(), w, apiSetupKeys) } -// GetSetupKey is a GET request to get a SetupKey by ID -func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) { +// getSetupKey is a GET request to get a SetupKey by ID +func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -114,8 +124,8 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) { writeSuccess(r.Context(), w, key) } -// UpdateSetupKey is a PUT request to update server.SetupKey -func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request) { +// updateSetupKey is a PUT request to update server.SetupKey +func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -155,8 +165,8 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request writeSuccess(r.Context(), w, newKey) } -// GetAllSetupKeys is a GET request that returns a list of SetupKey -func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Request) { +// getAllSetupKeys is a GET request that returns a list of SetupKey +func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -178,7 +188,7 @@ func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Reques util.WriteJSONObject(r.Context(), w, apiSetupKeys) } -func (h *SetupKeysHandler) DeleteSetupKey(w http.ResponseWriter, r *http.Request) { +func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -199,7 +209,7 @@ func (h *SetupKeysHandler) DeleteSetupKey(w http.ResponseWriter, r *http.Request return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } func writeSuccess(ctx context.Context, w http.ResponseWriter, key *server.SetupKey) { diff --git a/management/server/http/setupkeys_handler_test.go b/management/server/http/handlers/setup_keys/setupkeys_handler_test.go similarity index 95% rename from management/server/http/setupkeys_handler_test.go rename to management/server/http/handlers/setup_keys/setupkeys_handler_test.go index 09256d0ea..516a2ab8b 100644 --- a/management/server/http/setupkeys_handler_test.go +++ b/management/server/http/handlers/setup_keys/setupkeys_handler_test.go @@ -1,4 +1,4 @@ -package http +package setup_keys import ( "bytes" @@ -26,12 +26,13 @@ const ( newSetupKeyName = "New Setup Key" updatedSetupKeyName = "KKKey" notFoundSetupKeyID = "notFoundSetupKeyID" + testAccountID = "test_id" ) func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.SetupKey, updatedSetupKey *server.SetupKey, user *server.User, -) *SetupKeysHandler { - return &SetupKeysHandler{ +) *handler { + return &handler{ accountManager: &mock_server.MockAccountManager{ GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return claims.AccountId, claims.UserId, nil @@ -178,11 +179,11 @@ func TestSetupKeysHandlers(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/setup-keys", handler.GetAllSetupKeys).Methods("GET", "OPTIONS") - router.HandleFunc("/api/setup-keys", handler.CreateSetupKey).Methods("POST", "OPTIONS") - router.HandleFunc("/api/setup-keys/{keyId}", handler.GetSetupKey).Methods("GET", "OPTIONS") - router.HandleFunc("/api/setup-keys/{keyId}", handler.UpdateSetupKey).Methods("PUT", "OPTIONS") - router.HandleFunc("/api/setup-keys/{keyId}", handler.DeleteSetupKey).Methods("DELETE", "OPTIONS") + router.HandleFunc("/api/setup-keys", handler.getAllSetupKeys).Methods("GET", "OPTIONS") + router.HandleFunc("/api/setup-keys", handler.createSetupKey).Methods("POST", "OPTIONS") + router.HandleFunc("/api/setup-keys/{keyId}", handler.getSetupKey).Methods("GET", "OPTIONS") + router.HandleFunc("/api/setup-keys/{keyId}", handler.updateSetupKey).Methods("PUT", "OPTIONS") + router.HandleFunc("/api/setup-keys/{keyId}", handler.deleteSetupKey).Methods("DELETE", "OPTIONS") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/pat_handler.go b/management/server/http/handlers/users/pat_handler.go similarity index 75% rename from management/server/http/pat_handler.go rename to management/server/http/handlers/users/pat_handler.go index dfa9563e3..2caf98ad8 100644 --- a/management/server/http/pat_handler.go +++ b/management/server/http/handlers/users/pat_handler.go @@ -1,4 +1,4 @@ -package http +package users import ( "encoding/json" @@ -9,20 +9,29 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" ) -// PATHandler is the nameserver group handler of the account -type PATHandler struct { +// patHandler is the nameserver group handler of the account +type patHandler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewPATsHandler creates a new PATHandler HTTP handler -func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATHandler { - return &PATHandler{ +func addUsersTokensEndpoint(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + tokenHandler := newPATsHandler(accountManager, authCfg) + router.HandleFunc("/users/{userId}/tokens", tokenHandler.getAllTokens).Methods("GET", "OPTIONS") + router.HandleFunc("/users/{userId}/tokens", tokenHandler.createToken).Methods("POST", "OPTIONS") + router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.getToken).Methods("GET", "OPTIONS") + router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.deleteToken).Methods("DELETE", "OPTIONS") +} + +// newPATsHandler creates a new patHandler HTTP handler +func newPATsHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *patHandler { + return &patHandler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -31,8 +40,8 @@ func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATH } } -// GetAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user -func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) { +// getAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user +func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -61,8 +70,8 @@ func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, patResponse) } -// GetToken is HTTP GET handler that returns a personal access token for the given user -func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) { +// getToken is HTTP GET handler that returns a personal access token for the given user +func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -92,8 +101,8 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, toPATResponse(pat)) } -// CreateToken is HTTP POST handler that creates a personal access token for the given user -func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { +// createToken is HTTP POST handler that creates a personal access token for the given user +func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -124,8 +133,8 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, toPATGeneratedResponse(pat)) } -// DeleteToken is HTTP DELETE handler that deletes a personal access token for the given user -func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) { +// deleteToken is HTTP DELETE handler that deletes a personal access token for the given user +func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -152,7 +161,7 @@ func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } func toPATResponse(pat *server.PersonalAccessToken) *api.PersonalAccessToken { diff --git a/management/server/http/pat_handler_test.go b/management/server/http/handlers/users/pat_handler_test.go similarity index 96% rename from management/server/http/pat_handler_test.go rename to management/server/http/handlers/users/pat_handler_test.go index c28228a50..ef6fb973e 100644 --- a/management/server/http/pat_handler_test.go +++ b/management/server/http/handlers/users/pat_handler_test.go @@ -1,4 +1,4 @@ -package http +package users import ( "bytes" @@ -61,8 +61,8 @@ var testAccount = &server.Account{ }, } -func initPATTestData() *PATHandler { - return &PATHandler{ +func initPATTestData() *patHandler { + return &patHandler{ accountManager: &mock_server.MockAccountManager{ CreatePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) { if accountID != existingAccountID { @@ -186,10 +186,10 @@ func TestTokenHandlers(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/users/{userId}/tokens", p.GetAllTokens).Methods("GET") - router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.GetToken).Methods("GET") - router.HandleFunc("/api/users/{userId}/tokens", p.CreateToken).Methods("POST") - router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.DeleteToken).Methods("DELETE") + router.HandleFunc("/api/users/{userId}/tokens", p.getAllTokens).Methods("GET") + router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.getToken).Methods("GET") + router.HandleFunc("/api/users/{userId}/tokens", p.createToken).Methods("POST") + router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.deleteToken).Methods("DELETE") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/users_handler.go b/management/server/http/handlers/users/users_handler.go similarity index 80% rename from management/server/http/users_handler.go rename to management/server/http/handlers/users/users_handler.go index 6e151a0da..c843bc52b 100644 --- a/management/server/http/users_handler.go +++ b/management/server/http/handlers/users/users_handler.go @@ -1,4 +1,4 @@ -package http +package users import ( "encoding/json" @@ -9,6 +9,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/status" @@ -16,15 +17,25 @@ import ( "github.com/netbirdio/netbird/management/server/jwtclaims" ) -// UsersHandler is a handler that returns users of the account -type UsersHandler struct { +// handler is a handler that returns users of the account +type handler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewUsersHandler creates a new UsersHandler HTTP handler -func NewUsersHandler(accountManager server.AccountManager, authCfg AuthCfg) *UsersHandler { - return &UsersHandler{ +func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + userHandler := newHandler(accountManager, authCfg) + router.HandleFunc("/users", userHandler.getAllUsers).Methods("GET", "OPTIONS") + router.HandleFunc("/users/{userId}", userHandler.updateUser).Methods("PUT", "OPTIONS") + router.HandleFunc("/users/{userId}", userHandler.deleteUser).Methods("DELETE", "OPTIONS") + router.HandleFunc("/users", userHandler.createUser).Methods("POST", "OPTIONS") + router.HandleFunc("/users/{userId}/invite", userHandler.inviteUser).Methods("POST", "OPTIONS") + addUsersTokensEndpoint(accountManager, authCfg, router) +} + +// newHandler creates a new UsersHandler HTTP handler +func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { + return &handler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -33,8 +44,8 @@ func NewUsersHandler(accountManager server.AccountManager, authCfg AuthCfg) *Use } } -// UpdateUser is a PUT requests to update User data -func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { +// updateUser is a PUT requests to update User data +func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPut { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return @@ -94,8 +105,8 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId)) } -// DeleteUser is a DELETE request to delete a user -func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) { +// deleteUser is a DELETE request to delete a user +func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodDelete { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return @@ -121,11 +132,11 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -// CreateUser creates a User in the system with a status "invited" (effectively this is a user invite). -func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) { +// createUser creates a User in the system with a status "invited" (effectively this is a user invite). +func (h *handler) createUser(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return @@ -175,9 +186,9 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId)) } -// GetAllUsers returns a list of users of the account this user belongs to. +// getAllUsers returns a list of users of the account this user belongs to. // It also gathers additional user data (like email and name) from the IDP manager. -func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) { +func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return @@ -222,9 +233,9 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, users) } -// InviteUser resend invitations to users who haven't activated their accounts, +// inviteUser resend invitations to users who haven't activated their accounts, // prior to the expiration period. -func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) { +func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return @@ -250,7 +261,7 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } func toUserResponse(user *server.UserInfo, currenUserID string) *api.User { diff --git a/management/server/http/users_handler_test.go b/management/server/http/handlers/users/users_handler_test.go similarity index 97% rename from management/server/http/users_handler_test.go rename to management/server/http/handlers/users/users_handler_test.go index f3d989da1..6f6a91236 100644 --- a/management/server/http/users_handler_test.go +++ b/management/server/http/handlers/users/users_handler_test.go @@ -1,4 +1,4 @@ -package http +package users import ( "bytes" @@ -61,8 +61,8 @@ var usersTestAccount = &server.Account{ }, } -func initUsersTestData() *UsersHandler { - return &UsersHandler{ +func initUsersTestData() *handler { + return &handler{ accountManager: &mock_server.MockAccountManager{ GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return usersTestAccount.Id, claims.UserId, nil @@ -147,7 +147,7 @@ func TestGetUsers(t *testing.T) { requestPath string expectedUserIDs []string }{ - {name: "GetAllUsers", requestType: http.MethodGet, requestPath: "/api/users", expectedStatus: http.StatusOK, expectedUserIDs: []string{existingUserID, regularUserID, serviceUserID}}, + {name: "getAllUsers", requestType: http.MethodGet, requestPath: "/api/users", expectedStatus: http.StatusOK, expectedUserIDs: []string{existingUserID, regularUserID, serviceUserID}}, {name: "GetOnlyServiceUsers", requestType: http.MethodGet, requestPath: "/api/users?service_user=true", expectedStatus: http.StatusOK, expectedUserIDs: []string{serviceUserID}}, {name: "GetOnlyRegularUsers", requestType: http.MethodGet, requestPath: "/api/users?service_user=false", expectedStatus: http.StatusOK, expectedUserIDs: []string{existingUserID, regularUserID}}, } @@ -159,7 +159,7 @@ func TestGetUsers(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) - userHandler.GetAllUsers(recorder, req) + userHandler.getAllUsers(recorder, req) res := recorder.Result() defer res.Body.Close() @@ -265,7 +265,7 @@ func TestUpdateUser(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/users/{userId}", userHandler.UpdateUser).Methods("PUT") + router.HandleFunc("/api/users/{userId}", userHandler.updateUser).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -356,7 +356,7 @@ func TestCreateUser(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) rr := httptest.NewRecorder() - userHandler.CreateUser(rr, req) + userHandler.createUser(rr, req) res := rr.Result() defer res.Body.Close() @@ -401,7 +401,7 @@ func TestInviteUser(t *testing.T) { req = mux.SetURLVars(req, tc.requestVars) rr := httptest.NewRecorder() - userHandler.InviteUser(rr, req) + userHandler.inviteUser(rr, req) res := rr.Result() defer res.Body.Close() @@ -454,7 +454,7 @@ func TestDeleteUser(t *testing.T) { req = mux.SetURLVars(req, tc.requestVars) rr := httptest.NewRecorder() - userHandler.DeleteUser(rr, req) + userHandler.deleteUser(rr, req) res := rr.Result() defer res.Body.Close() diff --git a/management/server/http/util/util.go b/management/server/http/util/util.go index 603c1c696..3d7eed498 100644 --- a/management/server/http/util/util.go +++ b/management/server/http/util/util.go @@ -14,6 +14,10 @@ import ( "github.com/netbirdio/netbird/management/server/status" ) +// EmptyObject is an empty struct used to return empty JSON object +type EmptyObject struct { +} + type ErrorResponse struct { Message string `json:"message"` Code int `json:"code"` diff --git a/management/server/peer.go b/management/server/peer.go index 761aa39a2..ba211be96 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -740,7 +740,6 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) // it means that the client has already checked if it needs login and had been through the SSO flow // so, we can skip this check and directly proceed with the login if login.UserID == "" { - log.Info("Peer needs login") err = am.checkIFPeerNeedsLoginWithoutLock(ctx, accountID, login) if err != nil { return nil, nil, nil, err diff --git a/util/semaphore-group/semaphore_group.go b/util/semaphore-group/semaphore_group.go new file mode 100644 index 000000000..ad74e1bfc --- /dev/null +++ b/util/semaphore-group/semaphore_group.go @@ -0,0 +1,48 @@ +package semaphoregroup + +import ( + "context" + "sync" +) + +// SemaphoreGroup is a custom type that combines sync.WaitGroup and a semaphore. +type SemaphoreGroup struct { + waitGroup sync.WaitGroup + semaphore chan struct{} +} + +// NewSemaphoreGroup creates a new SemaphoreGroup with the specified semaphore limit. +func NewSemaphoreGroup(limit int) *SemaphoreGroup { + return &SemaphoreGroup{ + semaphore: make(chan struct{}, limit), + } +} + +// Add increments the internal WaitGroup counter and acquires a semaphore slot. +func (sg *SemaphoreGroup) Add(ctx context.Context) { + sg.waitGroup.Add(1) + + // Acquire semaphore slot + select { + case <-ctx.Done(): + return + case sg.semaphore <- struct{}{}: + } +} + +// Done decrements the internal WaitGroup counter and releases a semaphore slot. +func (sg *SemaphoreGroup) Done(ctx context.Context) { + sg.waitGroup.Done() + + // Release semaphore slot + select { + case <-ctx.Done(): + return + case <-sg.semaphore: + } +} + +// Wait waits until the internal WaitGroup counter is zero. +func (sg *SemaphoreGroup) Wait() { + sg.waitGroup.Wait() +} diff --git a/util/semaphore-group/semaphore_group_test.go b/util/semaphore-group/semaphore_group_test.go new file mode 100644 index 000000000..d4491cf77 --- /dev/null +++ b/util/semaphore-group/semaphore_group_test.go @@ -0,0 +1,66 @@ +package semaphoregroup + +import ( + "context" + "testing" + "time" +) + +func TestSemaphoreGroup(t *testing.T) { + semGroup := NewSemaphoreGroup(2) + + for i := 0; i < 5; i++ { + semGroup.Add(context.Background()) + go func(id int) { + defer semGroup.Done(context.Background()) + + got := len(semGroup.semaphore) + if got == 0 { + t.Errorf("Expected semaphore length > 0 , got 0") + } + + time.Sleep(time.Millisecond) + t.Logf("Goroutine %d is running\n", id) + }(i) + } + + semGroup.Wait() + + want := 0 + got := len(semGroup.semaphore) + if got != want { + t.Errorf("Expected semaphore length %d, got %d", want, got) + } +} + +func TestSemaphoreGroupContext(t *testing.T) { + semGroup := NewSemaphoreGroup(1) + semGroup.Add(context.Background()) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + t.Cleanup(cancel) + rChan := make(chan struct{}) + + go func() { + semGroup.Add(ctx) + rChan <- struct{}{} + }() + select { + case <-rChan: + case <-time.NewTimer(2 * time.Second).C: + t.Error("Adding to semaphore group should not block when context is not done") + } + + semGroup.Done(context.Background()) + + ctxDone, cancelDone := context.WithTimeout(context.Background(), 1*time.Second) + t.Cleanup(cancelDone) + go func() { + semGroup.Done(ctxDone) + rChan <- struct{}{} + }() + select { + case <-rChan: + case <-time.NewTimer(2 * time.Second).C: + t.Error("Releasing from semaphore group should not block when context is not done") + } +}