1
0
mirror of https://github.com/netbirdio/netbird.git synced 2025-03-11 13:18:12 +01:00

Add account usage logic ()

---------

Co-authored-by: Yury Gargay <yury.gargay@gmail.com>
This commit is contained in:
Viktor Liu 2024-02-22 12:27:08 +01:00 committed by GitHub
parent e18bf565a2
commit b7a6cbfaa5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 576 additions and 30 deletions

4
go.mod
View File

@ -57,8 +57,8 @@ require (
github.com/miekg/dns v1.1.43 github.com/miekg/dns v1.1.43
github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0 github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/additions v0.0.0-20240118163419-8a7c87accb22 github.com/netbirdio/management-integrations/additions v0.0.0-20240212121739-8ea8c89a4552
github.com/netbirdio/management-integrations/integrations v0.0.0-20240118163419-8a7c87accb22 github.com/netbirdio/management-integrations/integrations v0.0.0-20240212121739-8ea8c89a4552
github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0 github.com/oschwald/maxminddb-golang v1.12.0
github.com/patrickmn/go-cache v2.1.0+incompatible github.com/patrickmn/go-cache v2.1.0+incompatible

8
go.sum
View File

@ -376,10 +376,10 @@ github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRW
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw= github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw=
github.com/nadoo/ipset v0.5.0 h1:5GJUAuZ7ITQQQGne5J96AmFjRtI8Avlbk6CabzYWVUc= github.com/nadoo/ipset v0.5.0 h1:5GJUAuZ7ITQQQGne5J96AmFjRtI8Avlbk6CabzYWVUc=
github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ= github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ=
github.com/netbirdio/management-integrations/additions v0.0.0-20240118163419-8a7c87accb22 h1:XTiNnVB6OEwung8WIiGJNzOTLVefuSzAA/cu+6Sst8A= github.com/netbirdio/management-integrations/additions v0.0.0-20240212121739-8ea8c89a4552 h1:yzcQKizAK9YufCHMMCIsr467Dw/OU/4xyHbWizGb1E4=
github.com/netbirdio/management-integrations/additions v0.0.0-20240118163419-8a7c87accb22/go.mod h1:31FhBNvQ+riHEIu6LSTmqr8IeuSIsGfQffqV4LFmbwA= github.com/netbirdio/management-integrations/additions v0.0.0-20240212121739-8ea8c89a4552/go.mod h1:31FhBNvQ+riHEIu6LSTmqr8IeuSIsGfQffqV4LFmbwA=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240118163419-8a7c87accb22 h1:FNc4p8RS/gFm5jlmvUFWC4/5YxPDWejYyqEBVziFZwo= github.com/netbirdio/management-integrations/integrations v0.0.0-20240212121739-8ea8c89a4552 h1:OFlzVZtkXCoJsfDKrMigFpuad8ZXTm8epq6x27K0irA=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240118163419-8a7c87accb22/go.mod h1:B0nMS3es77gOvPYhc0K91fAzTkQLi/jRq5TffUN3klM= github.com/netbirdio/management-integrations/integrations v0.0.0-20240212121739-8ea8c89a4552/go.mod h1:B0nMS3es77gOvPYhc0K91fAzTkQLi/jRq5TffUN3klM=
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g= github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g=
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM= github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM=

View File

@ -242,7 +242,10 @@ var (
UserIDClaim: config.HttpConfig.AuthUserIDClaim, UserIDClaim: config.HttpConfig.AuthUserIDClaim,
KeysLocation: config.HttpConfig.AuthKeysLocation, KeysLocation: config.HttpConfig.AuthKeysLocation,
} }
httpAPIHandler, err := httpapi.APIHandler(accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg)
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg)
if err != nil { if err != nil {
return fmt.Errorf("failed creating HTTP API handler: %v", err) return fmt.Errorf("failed creating HTTP API handler: %v", err)
} }
@ -264,8 +267,6 @@ var (
} }
if !disableMetrics { if !disableMetrics {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
idpManager := "disabled" idpManager := "disabled"
if config.IdpManagerConfig != nil && config.IdpManagerConfig.ManagerType != "" { if config.IdpManagerConfig != nil && config.IdpManagerConfig.ManagerType != "" {
idpManager = config.IdpManagerConfig.ManagerType idpManager = config.IdpManagerConfig.ManagerType

View File

@ -72,6 +72,7 @@ type AccountManager interface {
CheckUserAccessByJWTGroups(claims jwtclaims.AuthorizationClaims) error CheckUserAccessByJWTGroups(claims jwtclaims.AuthorizationClaims) error
GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error) GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error)
DeleteAccount(accountID, userID string) error DeleteAccount(accountID, userID string) error
GetUsage(ctx context.Context, accountID string, start time.Time, end time.Time) (*AccountUsageStats, error)
MarkPATUsed(tokenID string) error MarkPATUsed(tokenID string) error
GetUser(claims jwtclaims.AuthorizationClaims) (*User, error) GetUser(claims jwtclaims.AuthorizationClaims) (*User, error)
ListUsers(accountID string) ([]*User, error) ListUsers(accountID string) ([]*User, error)
@ -110,7 +111,7 @@ type AccountManager interface {
DeleteNameServerGroup(accountID, nsGroupID, userID string) error DeleteNameServerGroup(accountID, nsGroupID, userID string) error
ListNameServerGroups(accountID string, userID string) ([]*nbdns.NameServerGroup, error) ListNameServerGroups(accountID string, userID string) ([]*nbdns.NameServerGroup, error)
GetDNSDomain() string GetDNSDomain() string
StoreEvent(initiatorID, targetID, accountID string, activityID activity.Activity, meta map[string]any) StoreEvent(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
GetEvents(accountID, userID string) ([]*activity.Event, error) GetEvents(accountID, userID string) ([]*activity.Event, error)
GetDNSSettings(accountID string, userID string) (*DNSSettings, error) GetDNSSettings(accountID string, userID string) (*DNSSettings, error)
SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error
@ -230,6 +231,14 @@ type Account struct {
RulesG []Rule `json:"-" gorm:"-"` RulesG []Rule `json:"-" gorm:"-"`
} }
// AccountUsageStats represents the current usage statistics for an account
type AccountUsageStats struct {
ActiveUsers int64 `json:"active_users"`
TotalUsers int64 `json:"total_users"`
ActivePeers int64 `json:"active_peers"`
TotalPeers int64 `json:"total_peers"`
}
type UserInfo struct { type UserInfo struct {
ID string `json:"id"` ID string `json:"id"`
Email string `json:"email"` Email string `json:"email"`
@ -1105,8 +1114,20 @@ func (am *DefaultAccountManager) DeleteAccount(accountID, userID string) error {
return nil return nil
} }
// GetUsage returns the usage stats for the given account.
// This cannot be used to calculate usage stats for a period in the past as it relies on peers' last seen time.
func (am *DefaultAccountManager) GetUsage(ctx context.Context, accountID string, start time.Time, end time.Time) (*AccountUsageStats, error) {
usageStats, err := am.Store.CalculateUsageStats(ctx, accountID, start, end)
if err != nil {
return nil, fmt.Errorf("failed to calculate usage stats: %w", err)
}
return usageStats, nil
}
// GetAccountByUserOrAccountID looks for an account by user or accountID, if no account is provided and // GetAccountByUserOrAccountID looks for an account by user or accountID, if no account is provided and
// userID doesn't have an account associated with it, one account is created // userID doesn't have an account associated with it, one account is created
// domain is used to create a new account if no account is found
func (am *DefaultAccountManager) GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error) { func (am *DefaultAccountManager) GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error) {
if accountID != "" { if accountID != "" {
return am.Store.GetAccount(accountID) return am.Store.GetAccount(accountID)
@ -1791,7 +1812,7 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(claims jwtclaims.Aut
return nil return nil
} }
// addAllGroup to account object if it doesn't exists // addAllGroup to account object if it doesn't exist
func addAllGroup(account *Account) error { func addAllGroup(account *Account) error {
if len(account.Groups) == 0 { if len(account.Groups) == 0 {
allGroup := &Group{ allGroup := &Group{

View File

@ -1,12 +1,14 @@
package activity package activity
import "maps"
// Activity that triggered an Event // Activity that triggered an Event
type Activity int type Activity int
// Code is an activity string representation // Code is an activity string representation
type Code struct { type Code struct {
message string Message string
code string Code string
} }
const ( const (
@ -207,7 +209,7 @@ var activityMap = map[Activity]Code{
// StringCode returns a string code of the activity // StringCode returns a string code of the activity
func (a Activity) StringCode() string { func (a Activity) StringCode() string {
if code, ok := activityMap[a]; ok { if code, ok := activityMap[a]; ok {
return code.code return code.Code
} }
return "UNKNOWN_ACTIVITY" return "UNKNOWN_ACTIVITY"
} }
@ -215,7 +217,12 @@ func (a Activity) StringCode() string {
// Message returns a string representation of an activity // Message returns a string representation of an activity
func (a Activity) Message() string { func (a Activity) Message() string {
if code, ok := activityMap[a]; ok { if code, ok := activityMap[a]; ok {
return code.message return code.Message
} }
return "UNKNOWN_ACTIVITY" return "UNKNOWN_ACTIVITY"
} }
// RegisterActivityMap adds new codes to the activity map
func RegisterActivityMap(codes map[Activity]Code) {
maps.Copy(activityMap, codes)
}

View File

@ -8,12 +8,18 @@ const (
SystemInitiator = "sys" SystemInitiator = "sys"
) )
// ActivityDescriber is an interface that describes an activity
type ActivityDescriber interface {
StringCode() string
Message() string
}
// Event represents a network/system activity event. // Event represents a network/system activity event.
type Event struct { type Event struct {
// Timestamp of the event // Timestamp of the event
Timestamp time.Time Timestamp time.Time
// Activity that was performed during the event // Activity that was performed during the event
Activity Activity Activity ActivityDescriber
// ID of the event (can be empty, meaning that it wasn't yet generated) // ID of the event (can be empty, meaning that it wasn't yet generated)
ID uint64 ID uint64
// InitiatorID is the ID of an object that initiated the event (e.g., a user) // InitiatorID is the ID of an object that initiated the event (e.g., a user)

View File

@ -54,8 +54,7 @@ func (am *DefaultAccountManager) GetEvents(accountID, userID string) ([]*activit
return filtered, nil return filtered, nil
} }
func (am *DefaultAccountManager) StoreEvent(initiatorID, targetID, accountID string, activityID activity.Activity, func (am *DefaultAccountManager) StoreEvent(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
meta map[string]any) {
go func() { go func() {
_, err := am.eventStore.Save(&activity.Event{ _, err := am.eventStore.Save(&activity.Event{

View File

@ -1,6 +1,8 @@
package server package server
import ( import (
"context"
"fmt"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
@ -662,3 +664,40 @@ func (s *FileStore) Close() error {
func (s *FileStore) GetStoreEngine() StoreEngine { func (s *FileStore) GetStoreEngine() StoreEngine {
return FileStoreEngine return FileStoreEngine
} }
// CalculateUsageStats returns the usage stats for an account
// start and end are inclusive.
func (s *FileStore) CalculateUsageStats(_ context.Context, accountID string, start time.Time, end time.Time) (*AccountUsageStats, error) {
s.mux.Lock()
defer s.mux.Unlock()
account, exists := s.Accounts[accountID]
if !exists {
return nil, fmt.Errorf("account not found")
}
stats := &AccountUsageStats{
TotalUsers: 0,
TotalPeers: int64(len(account.Peers)),
}
for _, user := range account.Users {
if !user.IsServiceUser {
stats.TotalUsers++
}
}
activeUsers := make(map[string]bool)
for _, peer := range account.Peers {
lastSeen := peer.Status.LastSeen
if lastSeen.Compare(start) >= 0 && lastSeen.Compare(end) <= 0 {
if _, exists := account.Users[peer.UserID]; exists && !activeUsers[peer.UserID] {
activeUsers[peer.UserID] = true
stats.ActiveUsers++
}
stats.ActivePeers++
}
}
return stats, nil
}

View File

@ -1,6 +1,7 @@
package server package server
import ( import (
"context"
"crypto/sha256" "crypto/sha256"
"net" "net"
"path/filepath" "path/filepath"
@ -657,3 +658,32 @@ func newStore(t *testing.T) *FileStore {
return store return store
} }
func TestFileStore_CalculateUsageStats(t *testing.T) {
storeDir := t.TempDir()
err := util.CopyFileContents("testdata/store_stats.json", filepath.Join(storeDir, "store.json"))
require.NoError(t, err)
store, err := NewFileStore(storeDir, nil)
require.NoError(t, err)
startDate := time.Date(2024, time.February, 1, 0, 0, 0, 0, time.UTC)
endDate := startDate.AddDate(0, 1, 0).Add(-time.Nanosecond)
stats1, err := store.CalculateUsageStats(context.TODO(), "account-1", startDate, endDate)
require.NoError(t, err)
assert.Equal(t, int64(2), stats1.ActiveUsers)
assert.Equal(t, int64(4), stats1.TotalUsers)
assert.Equal(t, int64(3), stats1.ActivePeers)
assert.Equal(t, int64(7), stats1.TotalPeers)
stats2, err := store.CalculateUsageStats(context.TODO(), "account-2", startDate, endDate)
require.NoError(t, err)
assert.Equal(t, int64(1), stats2.ActiveUsers)
assert.Equal(t, int64(2), stats2.TotalUsers)
assert.Equal(t, int64(1), stats2.ActivePeers)
assert.Equal(t, int64(2), stats2.TotalPeers)
}

View File

@ -1,6 +1,8 @@
package http package http
import ( import (
"context"
"fmt"
"net/http" "net/http"
"github.com/gorilla/mux" "github.com/gorilla/mux"
@ -15,6 +17,8 @@ import (
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
) )
const apiPrefix = "/api"
// AuthCfg contains parameters for authentication middleware // AuthCfg contains parameters for authentication middleware
type AuthCfg struct { type AuthCfg struct {
Issuer string Issuer string
@ -35,7 +39,7 @@ type emptyObject struct {
} }
// APIHandler creates the Management service HTTP API handler registering all the available endpoints. // APIHandler creates the Management service HTTP API handler registering all the available endpoints.
func APIHandler(accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) { func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) {
claimsExtractor := jwtclaims.NewClaimsExtractor( claimsExtractor := jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience), jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
@ -61,7 +65,8 @@ func APIHandler(accountManager s.AccountManager, LocationManager *geolocation.Ge
rootRouter := mux.NewRouter() rootRouter := mux.NewRouter()
metricsMiddleware := appMetrics.HTTPMiddleware() metricsMiddleware := appMetrics.HTTPMiddleware()
router := rootRouter.PathPrefix("/api").Subrouter() prefix := apiPrefix
router := rootRouter.PathPrefix(prefix).Subrouter()
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler, acMiddleware.Handler) router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler, acMiddleware.Handler)
api := apiHandler{ api := apiHandler{
@ -71,7 +76,10 @@ func APIHandler(accountManager s.AccountManager, LocationManager *geolocation.Ge
AuthCfg: authCfg, AuthCfg: authCfg,
} }
integrations.RegisterHandlers(api.Router, accountManager, claimsExtractor) if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor); err != nil {
return nil, fmt.Errorf("register integrations endpoints: %w", err)
}
api.addAccountsEndpoint() api.addAccountsEndpoint()
api.addPeersEndpoint() api.addPeersEndpoint()
api.addUsersEndpoint() api.addUsersEndpoint()

View File

@ -7,6 +7,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
@ -36,9 +37,13 @@ func NewAccessControl(audience, userIDClaim string, getUser GetUser) *AccessCont
var tokenPathRegexp = regexp.MustCompile(`^.*/api/users/.*/tokens.*$`) var tokenPathRegexp = regexp.MustCompile(`^.*/api/users/.*/tokens.*$`)
// Handler method of the middleware which forbids all modify requests for non admin users // Handler method of the middleware which forbids all modify requests for non admin users
// It also adds
func (a *AccessControl) Handler(h http.Handler) http.Handler { func (a *AccessControl) Handler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if bypass.ShouldBypass(r.URL.Path, h, w, r) {
return
}
claims := a.claimsExtract.FromRequestContext(r) claims := a.claimsExtract.FromRequestContext(r)
user, err := a.getUser(claims) user, err := a.getUser(claims)

View File

@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
@ -66,6 +67,11 @@ func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParse
// Handler method of the middleware which authenticates a user either by JWT claims or by PAT // Handler method of the middleware which authenticates a user either by JWT claims or by PAT
func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if bypass.ShouldBypass(r.URL.Path, h, w, r) {
return
}
auth := strings.Split(r.Header.Get("Authorization"), " ") auth := strings.Split(r.Header.Get("Authorization"), " ")
authType := strings.ToLower(auth[0]) authType := strings.ToLower(auth[0])

View File

@ -10,6 +10,7 @@ import (
"github.com/golang-jwt/jwt" "github.com/golang-jwt/jwt"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
) )
@ -88,39 +89,68 @@ func mockCheckUserAccessByJWTGroups(claims jwtclaims.AuthorizationClaims) error
func TestAuthMiddleware_Handler(t *testing.T) { func TestAuthMiddleware_Handler(t *testing.T) {
tt := []struct { tt := []struct {
name string name string
path string
authHeader string authHeader string
expectedStatusCode int expectedStatusCode int
shouldBypassAuth bool
}{ }{
{ {
name: "Valid PAT Token", name: "Valid PAT Token",
path: "/test",
authHeader: "Token " + PAT, authHeader: "Token " + PAT,
expectedStatusCode: 200, expectedStatusCode: 200,
}, },
{ {
name: "Invalid PAT Token", name: "Invalid PAT Token",
path: "/test",
authHeader: "Token " + wrongToken, authHeader: "Token " + wrongToken,
expectedStatusCode: 401, expectedStatusCode: 401,
}, },
{ {
name: "Fallback to PAT Token", name: "Fallback to PAT Token",
path: "/test",
authHeader: "Bearer " + PAT, authHeader: "Bearer " + PAT,
expectedStatusCode: 200, expectedStatusCode: 200,
}, },
{ {
name: "Valid JWT Token", name: "Valid JWT Token",
path: "/test",
authHeader: "Bearer " + JWT, authHeader: "Bearer " + JWT,
expectedStatusCode: 200, expectedStatusCode: 200,
}, },
{ {
name: "Invalid JWT Token", name: "Invalid JWT Token",
path: "/test",
authHeader: "Bearer " + wrongToken, authHeader: "Bearer " + wrongToken,
expectedStatusCode: 401, expectedStatusCode: 401,
}, },
{ {
name: "Basic Auth", name: "Basic Auth",
path: "/test",
authHeader: "Basic " + PAT, authHeader: "Basic " + PAT,
expectedStatusCode: 401, expectedStatusCode: 401,
}, },
{
name: "Webhook Path Bypass",
path: "/webhook",
authHeader: "",
expectedStatusCode: 200,
shouldBypassAuth: true,
},
{
name: "Webhook Path Bypass with Subpath",
path: "/webhook/test",
authHeader: "",
expectedStatusCode: 200,
shouldBypassAuth: true,
},
{
name: "Different Webhook Path",
path: "/webhooktest",
authHeader: "",
expectedStatusCode: 401,
shouldBypassAuth: false,
},
} }
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -146,7 +176,11 @@ func TestAuthMiddleware_Handler(t *testing.T) {
for _, tc := range tt { for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "http://testing", nil) if tc.shouldBypassAuth {
bypass.AddBypassPath(tc.path)
}
req := httptest.NewRequest("GET", "http://testing"+tc.path, nil)
req.Header.Set("Authorization", tc.authHeader) req.Header.Set("Authorization", tc.authHeader)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
@ -159,5 +193,4 @@ func TestAuthMiddleware_Handler(t *testing.T) {
} }
}) })
} }
} }

View File

@ -0,0 +1,39 @@
package bypass
import (
"net/http"
"sync"
)
var byPassMutex sync.RWMutex
// bypassPaths is a set of paths that should bypass middleware.
var bypassPaths = make(map[string]struct{})
// AddBypassPath adds an exact path to the list of paths that bypass middleware.
func AddBypassPath(path string) {
byPassMutex.Lock()
defer byPassMutex.Unlock()
bypassPaths[path] = struct{}{}
}
// RemovePath removes a path from the list of paths that bypass middleware.
func RemovePath(path string) {
byPassMutex.Lock()
defer byPassMutex.Unlock()
delete(bypassPaths, path)
}
// ShouldBypass checks if the request path is one of the auth bypass paths and returns true if the middleware should be bypassed.
// This can be used to bypass authz/authn middlewares for certain paths, such as webhooks that implement their own authentication.
func ShouldBypass(requestPath string, h http.Handler, w http.ResponseWriter, r *http.Request) bool {
byPassMutex.RLock()
defer byPassMutex.RUnlock()
if _, ok := bypassPaths[requestPath]; ok {
h.ServeHTTP(w, r)
return true
}
return false
}

View File

@ -0,0 +1,103 @@
package bypass_test
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
)
func TestAuthBypass(t *testing.T) {
dummyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
tests := []struct {
name string
pathToAdd string
pathToRemove string
testPath string
expectBypass bool
expectHTTPCode int
}{
{
name: "Path added to bypass",
pathToAdd: "/bypass",
testPath: "/bypass",
expectBypass: true,
expectHTTPCode: http.StatusOK,
},
{
name: "Path not added to bypass",
testPath: "/no-bypass",
expectBypass: false,
expectHTTPCode: http.StatusOK,
},
{
name: "Path removed from bypass",
pathToAdd: "/remove-bypass",
pathToRemove: "/remove-bypass",
testPath: "/remove-bypass",
expectBypass: false,
expectHTTPCode: http.StatusOK,
},
{
name: "Exact path matches bypass",
pathToAdd: "/webhook",
testPath: "/webhook",
expectBypass: true,
expectHTTPCode: http.StatusOK,
},
{
name: "Subpath does not match bypass",
pathToAdd: "/webhook",
testPath: "/webhook/extra",
expectBypass: false,
expectHTTPCode: http.StatusOK,
},
{
name: "Similar path does not match bypass",
pathToAdd: "/webhook",
testPath: "/webhooking",
expectBypass: false,
expectHTTPCode: http.StatusOK,
},
{
name: "Prefix path does not match bypass",
pathToAdd: "/webhook",
testPath: "/web",
expectBypass: false,
expectHTTPCode: http.StatusOK,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if tc.pathToAdd != "" {
bypass.AddBypassPath(tc.pathToAdd)
defer bypass.RemovePath(tc.pathToAdd)
}
if tc.pathToRemove != "" {
bypass.RemovePath(tc.pathToRemove)
}
request, err := http.NewRequest("GET", tc.testPath, nil)
require.NoError(t, err, "Creating request should not fail")
recorder := httptest.NewRecorder()
bypassed := bypass.ShouldBypass(tc.testPath, dummyHandler, recorder, request)
assert.Equal(t, tc.expectBypass, bypassed, "Bypass check did not match expectation")
if tc.expectBypass {
assert.Equal(t, tc.expectHTTPCode, recorder.Code, "HTTP status code did not match expectation for bypassed path")
}
})
}
}

View File

@ -1,6 +1,7 @@
package mock_server package mock_server
import ( import (
"context"
"net" "net"
"time" "time"
@ -75,7 +76,7 @@ type MockAccountManager struct {
CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error
DeleteAccountFunc func(accountID, userID string) error DeleteAccountFunc func(accountID, userID string) error
GetDNSDomainFunc func() string GetDNSDomainFunc func() string
StoreEventFunc func(initiatorID, targetID, accountID string, activityID activity.Activity, meta map[string]any) StoreEventFunc func(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
GetEventsFunc func(accountID, userID string) ([]*activity.Event, error) GetEventsFunc func(accountID, userID string) ([]*activity.Event, error)
GetDNSSettingsFunc func(accountID, userID string) (*server.DNSSettings, error) GetDNSSettingsFunc func(accountID, userID string) (*server.DNSSettings, error)
SaveDNSSettingsFunc func(accountID, userID string, dnsSettingsToSave *server.DNSSettings) error SaveDNSSettingsFunc func(accountID, userID string, dnsSettingsToSave *server.DNSSettings) error
@ -91,6 +92,7 @@ type MockAccountManager struct {
SavePostureChecksFunc func(accountID, userID string, postureChecks *posture.Checks) error SavePostureChecksFunc func(accountID, userID string, postureChecks *posture.Checks) error
DeletePostureChecksFunc func(accountID, postureChecksID, userID string) error DeletePostureChecksFunc func(accountID, postureChecksID, userID string) error
ListPostureChecksFunc func(accountID, userID string) ([]*posture.Checks, error) ListPostureChecksFunc func(accountID, userID string) ([]*posture.Checks, error)
GetUsageFunc func(ctx context.Context, accountID string, start, end time.Time) (*server.AccountUsageStats, error)
} }
// GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface // GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface
@ -646,7 +648,7 @@ func (am *MockAccountManager) GetAllConnectedPeers() (map[string]struct{}, error
return nil, status.Errorf(codes.Unimplemented, "method GetAllConnectedPeers is not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetAllConnectedPeers is not implemented")
} }
// HasconnectedChannel mocks HasConnectedChannel of the AccountManager interface // HasConnectedChannel mocks HasConnectedChannel of the AccountManager interface
func (am *MockAccountManager) HasConnectedChannel(peerID string) bool { func (am *MockAccountManager) HasConnectedChannel(peerID string) bool {
if am.HasConnectedChannelFunc != nil { if am.HasConnectedChannelFunc != nil {
return am.HasConnectedChannelFunc(peerID) return am.HasConnectedChannelFunc(peerID)
@ -655,7 +657,7 @@ func (am *MockAccountManager) HasConnectedChannel(peerID string) bool {
} }
// StoreEvent mocks StoreEvent of the AccountManager interface // StoreEvent mocks StoreEvent of the AccountManager interface
func (am *MockAccountManager) StoreEvent(initiatorID, targetID, accountID string, activityID activity.Activity, meta map[string]any) { func (am *MockAccountManager) StoreEvent(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
if am.StoreEventFunc != nil { if am.StoreEventFunc != nil {
am.StoreEventFunc(initiatorID, targetID, accountID, activityID, meta) am.StoreEventFunc(initiatorID, targetID, accountID, activityID, meta)
} }
@ -702,3 +704,11 @@ func (am *MockAccountManager) ListPostureChecks(accountID, userID string) ([]*po
} }
return nil, status.Errorf(codes.Unimplemented, "method ListPostureChecks is not implemented") return nil, status.Errorf(codes.Unimplemented, "method ListPostureChecks is not implemented")
} }
// GetUsage mocks GetCurrentUsage of the AccountManager interface
func (am *MockAccountManager) GetUsage(ctx context.Context, accountID string, start time.Time, end time.Time) (*server.AccountUsageStats, error) {
if am.GetUsageFunc != nil {
return am.GetUsageFunc(ctx, accountID, start, end)
}
return nil, status.Errorf(codes.Unimplemented, "method GetUsage is not implemented")
}

View File

@ -1,6 +1,8 @@
package server package server
import ( import (
"context"
"fmt"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strings" "strings"
@ -483,11 +485,11 @@ func (s *SqliteStore) SaveUserLastLogin(accountID, userID string, lastLogin time
return s.db.Save(user).Error return s.db.Save(user).Error
} }
// Close is noop in Sqlite // Close closes the underlying DB connection
func (s *SqliteStore) Close() error { func (s *SqliteStore) Close() error {
sql, err := s.db.DB() sql, err := s.db.DB()
if err != nil { if err != nil {
return err return fmt.Errorf("get db: %w", err)
} }
return sql.Close() return sql.Close()
} }
@ -496,3 +498,48 @@ func (s *SqliteStore) Close() error {
func (s *SqliteStore) GetStoreEngine() StoreEngine { func (s *SqliteStore) GetStoreEngine() StoreEngine {
return SqliteStoreEngine return SqliteStoreEngine
} }
// CalculateUsageStats returns the usage stats for an account
// start and end are inclusive.
func (s *SqliteStore) CalculateUsageStats(ctx context.Context, accountID string, start time.Time, end time.Time) (*AccountUsageStats, error) {
stats := &AccountUsageStats{}
err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
err := tx.Model(&nbpeer.Peer{}).
Where("account_id = ? AND peer_status_last_seen BETWEEN ? AND ?", accountID, start, end).
Distinct("user_id").
Count(&stats.ActiveUsers).Error
if err != nil {
return fmt.Errorf("get active users: %w", err)
}
err = tx.Model(&User{}).
Where("account_id = ? AND is_service_user = ?", accountID, false).
Count(&stats.TotalUsers).Error
if err != nil {
return fmt.Errorf("get total users: %w", err)
}
err = tx.Model(&nbpeer.Peer{}).
Where("account_id = ? AND peer_status_last_seen BETWEEN ? AND ?", accountID, start, end).
Count(&stats.ActivePeers).Error
if err != nil {
return fmt.Errorf("get active peers: %w", err)
}
err = tx.Model(&nbpeer.Peer{}).
Where("account_id = ?", accountID).
Count(&stats.TotalPeers).Error
if err != nil {
return fmt.Errorf("get total peers: %w", err)
}
return nil
})
if err != nil {
return nil, fmt.Errorf("transaction: %w", err)
}
return stats, nil
}

View File

@ -1,6 +1,7 @@
package server package server
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"path/filepath" "path/filepath"
@ -346,3 +347,29 @@ func newAccount(store Store, id int) error {
return store.SaveAccount(account) return store.SaveAccount(account)
} }
func TestSqliteStore_CalculateUsageStats(t *testing.T) {
store := newSqliteStoreFromFile(t, "testdata/store_stats.json")
t.Cleanup(func() {
require.NoError(t, store.Close())
})
startDate := time.Date(2024, time.February, 1, 0, 0, 0, 0, time.UTC)
endDate := startDate.AddDate(0, 1, 0).Add(-time.Nanosecond)
stats1, err := store.CalculateUsageStats(context.TODO(), "account-1", startDate, endDate)
require.NoError(t, err)
assert.Equal(t, int64(2), stats1.ActiveUsers)
assert.Equal(t, int64(4), stats1.TotalUsers)
assert.Equal(t, int64(3), stats1.ActivePeers)
assert.Equal(t, int64(7), stats1.TotalPeers)
stats2, err := store.CalculateUsageStats(context.TODO(), "account-2", startDate, endDate)
require.NoError(t, err)
assert.Equal(t, int64(1), stats2.ActiveUsers)
assert.Equal(t, int64(2), stats2.TotalUsers)
assert.Equal(t, int64(1), stats2.ActivePeers)
assert.Equal(t, int64(2), stats2.TotalPeers)
}

View File

@ -1,6 +1,7 @@
package status package status
import ( import (
"errors"
"fmt" "fmt"
) )
@ -68,7 +69,8 @@ func FromError(err error) (s *Error, ok bool) {
if err == nil { if err == nil {
return nil, true return nil, true
} }
if e, ok := err.(*Error); ok { var e *Error
if errors.As(err, &e) {
return e, true return e, true
} }
return nil, false return nil, false

View File

@ -1,6 +1,7 @@
package server package server
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
@ -41,6 +42,7 @@ type Store interface {
// GetStoreEngine should return StoreEngine of the current store implementation. // GetStoreEngine should return StoreEngine of the current store implementation.
// This is also a method of metrics.DataSource interface. // This is also a method of metrics.DataSource interface.
GetStoreEngine() StoreEngine GetStoreEngine() StoreEngine
CalculateUsageStats(ctx context.Context, accountID string, start time.Time, end time.Time) (*AccountUsageStats, error)
} }
type StoreEngine string type StoreEngine string

View File

@ -0,0 +1,161 @@
{
"Accounts": {
"account-1": {
"Id": "account-1",
"Domain": "example.com",
"Network": {
"Id": "af1c8024-ha40-4ce2-9418-34653101fc3c",
"Net": {
"IP": "100.64.0.0",
"Mask": "//8AAA=="
},
"Dns": null
},
"Users": {
"user-1-account-1": {
"Id": "user-1-account-1"
},
"user-2-account-1": {
"Id": "user-2-account-1"
},
"user-3-account-1": {
"Id": "user-3-account-1"
},
"user-4-account-1": {
"Id": "user-4-account-1"
},
"user-5-account-1": {
"Id": "user-5-account-1",
"IsServiceUser": true
}
},
"Peers": {
"peer-1-account-1": {
"ID": "peer-1-account-1",
"UserID": "user-1-account-1",
"Status": {
"LastSeen": "2024-01-01T00:00:00Z"
},
"Name": "Peer One",
"Meta": {
"Hostname": "peer1-host"
}
},
"peer-2-account-1": {
"ID": "peer-2-account-1",
"UserID": "user-2-account-1",
"Status": {
"LastSeen": "2024-02-29T23:59:59Z"
},
"Name": "Peer Two",
"Meta": {
"Hostname": "peer2-host"
}
},
"peer-3-account-1": {
"ID": "peer-3-account-1",
"UserID": "user-2-account-1",
"Status": {
"LastSeen": "2024-02-01T12:00:00Z"
},
"Name": "Peer Three",
"Meta": {
"Hostname": "peer3-host"
}
},
"peer-4-account-1": {
"ID": "peer-4-account-1",
"UserID": "user-3-account-1",
"Status": {
"LastSeen": "2024-02-08T12:00:00Z"
},
"Name": "Peer Four",
"Meta": {
"Hostname": "peer4-host"
}
},
"peer-5-account-1": {
"ID": "peer-5-account-1",
"UserID": "user-3-account-1",
"Status": {
"LastSeen": "2023-06-01T12:00:00Z"
},
"Name": "Peer Five",
"Meta": {
"Hostname": "peer5-host"
}
},
"peer-6-account-1": {
"ID": "peer-6-account-1",
"UserID": "user-4-account-1",
"Status": {
"LastSeen": "2024-01-31T23:59:59Z"
},
"Name": "Peer Six",
"Meta": {
"Hostname": "peer6-host"
}
},
"peer-7-account-1": {
"ID": "peer-7-account-1",
"UserID": "user-4-account-1",
"Status": {
"LastSeen": "2024-03-01T00:00:00Z"
},
"Name": "Peer Seven",
"Meta": {
"Hostname": "peer7-host"
}
}
}
},
"account-2": {
"Id": "account-2",
"Domain": "example.org",
"Network": {
"Id": "af1c8024-ha40-4ce2-9418-34653101fc3c",
"Net": {
"IP": "100.64.0.0",
"Mask": "//8AAA=="
},
"Dns": null
},
"Users": {
"user-1-account-2": {
"Id": "user-1-account-2"
},
"user-2-account-2": {
"Id": "user-1-account-2"
},
"user-3-account-2": {
"Id": "user-3-account-2",
"IsServiceUser": true
}
},
"Peers": {
"peer-1-account-2": {
"ID": "peer-1-account-2",
"UserID": "user-1-account-2",
"Status": {
"LastSeen": "2023-08-30T12:00:00Z"
},
"Name": "Peer One",
"Meta": {
"Hostname": "peer1-host"
}
},
"peer-2-account-2": {
"ID": "peer-2-account-2",
"UserID": "user-1-account-2",
"Status": {
"LastSeen": "2024-02-08T12:00:00Z"
},
"Name": "Peer Two",
"Meta": {
"Hostname": "peer2-host"
}
}
}
}
}
}