Add initial support of device posture checks (#1540)

This PR implements the following posture checks:

* Agent minimum version allowed
* OS minimum version allowed
* Geo-location based on connection IP

For the geo-based location, we rely on GeoLite2 databases which are free IP geolocation databases. MaxMind was tested and we provide a script that easily allows to download of all necessary files, see infrastructure_files/download-geolite2.sh.

The OpenAPI spec should extensively cover the life cycle of current version posture checks.
This commit is contained in:
Yury Gargay
2024-02-20 09:59:56 +01:00
committed by GitHub
parent db3cba5e0f
commit 9bc7b9e897
61 changed files with 5162 additions and 348 deletions

View File

@ -1,4 +1,4 @@
openapi: 3.0.1
openapi: 3.1.0
servers:
- url: https://api.netbird.io
description: Default server
@ -21,6 +21,8 @@ tags:
description: Interact with and view information about rules.
- name: Policies
description: Interact with and view information about policies.
- name: Posture Checks
description: Interact with and view information about posture checks.
- name: Routes
description: Interact with and view information about routes.
- name: DNS
@ -245,6 +247,10 @@ components:
description: Peer's IP address
type: string
example: 10.64.0.1
connection_ip:
description: Peer's public connection IP address
type: string
example: 35.64.0.1
connected:
description: Peer to Management connection status
type: boolean
@ -258,6 +264,14 @@ components:
description: Peer's operating system and version
type: string
example: Darwin 13.2.1
kernel_version:
description: Peer's operating system kernel version
type: string
example: 23.2.0
geoname_id:
description: Unique identifier from the GeoNames database for a specific geographical location.
type: integer
example: 2643743
version:
description: Peer's daemon or cli version
type: string
@ -304,6 +318,10 @@ components:
description: (Cloud only) Indicates whether peer needs approval
type: boolean
example: true
country_code:
$ref: '#/components/schemas/CountryCode'
city_name:
$ref: '#/components/schemas/CityName'
required:
- ip
- connected
@ -774,6 +792,12 @@ components:
- $ref: '#/components/schemas/PolicyMinimum'
- type: object
properties:
source_posture_checks:
description: Posture checks ID's applied to policy source groups
type: array
items:
type: string
example: "chacdk86lnnboviihd70"
rules:
description: Policy rule object for policy UI editor
type: array
@ -786,6 +810,12 @@ components:
- $ref: '#/components/schemas/PolicyMinimum'
- type: object
properties:
source_posture_checks:
description: Posture checks ID's applied to policy source groups
type: array
items:
type: string
example: "chacdk86lnnboviihd70"
rules:
description: Policy rule object for policy UI editor
type: array
@ -793,6 +823,170 @@ components:
$ref: '#/components/schemas/PolicyRule'
required:
- rules
- source_posture_checks
PostureCheck:
type: object
properties:
id:
description: Posture check ID
type: string
example: ch8i4ug6lnn4g9hqv7mg
name:
description: Posture check unique name identifier
type: string
example: Default
description:
description: Posture check friendly description
type: string
example: This checks if the peer is running required NetBird's version
checks:
$ref: '#/components/schemas/Checks'
required:
- id
- name
- checks
Checks:
description: List of objects that perform the actual checks
type: object
properties:
nb_version_check:
$ref: '#/components/schemas/NBVersionCheck'
os_version_check:
$ref: '#/components/schemas/OSVersionCheck'
geo_location_check:
$ref: '#/components/schemas/GeoLocationCheck'
NBVersionCheck:
description: Posture check for the version of NetBird
type: object
$ref: '#/components/schemas/MinVersionCheck'
OSVersionCheck:
description: Posture check for the version of operating system
type: object
properties:
android:
description: Minimum version of Android
$ref: '#/components/schemas/MinVersionCheck'
darwin:
$ref: '#/components/schemas/MinVersionCheck'
ios:
description: Minimum version of iOS
$ref: '#/components/schemas/MinVersionCheck'
linux:
description: Minimum Linux kernel version
$ref: '#/components/schemas/MinKernelVersionCheck'
windows:
description: Minimum Windows kernel build version
$ref: '#/components/schemas/MinKernelVersionCheck'
example:
android:
min_version: "13"
ios:
min_version: "17.3.1"
darwin:
min_version: "14.2.1"
linux:
min_kernel_version: "5.3.3"
windows:
min_kernel_version: "10.0.1234"
MinVersionCheck:
description: Posture check for the version of operating system
type: object
properties:
min_version:
description: Minimum acceptable version
type: string
example: "14.3"
required:
- min_version
MinKernelVersionCheck:
description: Posture check with the kernel version
type: object
properties:
min_kernel_version:
description: Minimum acceptable version
type: string
example: "6.6.12"
required:
- min_kernel_version
GeoLocationCheck:
description: Posture check for geo location
type: object
properties:
locations:
description: List of geo locations to which the policy applies
type: array
items:
$ref: '#/components/schemas/Location'
action:
description: Action to take upon policy match
type: string
enum: [ "allow", "deny" ]
example: "allow"
required:
- locations
- action
Location:
description: Describe geographical location information
type: object
properties:
country_code:
$ref: '#/components/schemas/CountryCode'
city_name:
$ref: '#/components/schemas/CityName'
required:
- country_code
CountryCode:
description: 2-letter ISO 3166-1 alpha-2 code that represents the country
type: string
example: "DE"
CityName:
description: Commonly used English name of the city
type: string
example: "Berlin"
Country:
description: Describe country geographical location information
type: object
properties:
country_name:
description: Commonly used English name of the country
type: string
example: "Germany"
country_code:
$ref: '#/components/schemas/CountryCode'
required:
- country_name
- country_code
City:
description: Describe city geographical location information
type: object
properties:
geoname_id:
description: Integer ID of the record in GeoNames database
type: integer
example: 2950158
city_name:
description: Commonly used English name of the city
type: string
example: "Berlin"
required:
- geoname_id
- city_name
PostureCheckUpdate:
type: object
properties:
name:
description: Posture check name identifier
type: string
example: Default
description:
description: Posture check friendly description
type: string
example: This checks if the peer is running required NetBird's version
checks:
$ref: '#/components/schemas/Checks'
required:
- name
- description
RouteRequest:
type: object
properties:
@ -2144,7 +2338,6 @@ paths:
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/routes/{routeId}:
get:
summary: Retrieve a Route
@ -2289,7 +2482,6 @@ paths:
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/dns/nameservers/{nsgroupId}:
get:
summary: Retrieve a Nameserver Group
@ -2381,7 +2573,6 @@ paths:
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/dns/settings:
get:
summary: Retrieve DNS settings
@ -2459,3 +2650,194 @@ paths:
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/posture-checks:
get:
summary: List all Posture Checks
description: Returns a list of all posture checks
tags: [ "Posture Checks" ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
responses:
'200':
description: A JSON Array of posture checks
content:
application/json:
schema:
type: array
items:
$ref: '#/components/schemas/PostureCheck'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
post:
summary: Create a Posture Check
description: Creates a posture check
tags: [ "Posture Checks" ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
requestBody:
description: New posture check request
content:
'application/json':
schema:
$ref: '#/components/schemas/PostureCheckUpdate'
responses:
'200':
description: A posture check Object
content:
application/json:
schema:
$ref: '#/components/schemas/PostureCheck'
/api/posture-checks/{postureCheckId}:
get:
summary: Retrieve a Posture Check
description: Get information about a posture check
tags: [ "Posture Checks" ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: postureCheckId
required: true
schema:
type: string
description: The unique identifier of a posture check
responses:
'200':
description: A posture check object
content:
application/json:
schema:
$ref: '#/components/schemas/PostureCheck'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
put:
summary: Update a Posture Check
description: Update/Replace a posture check
tags: [ "Posture Checks" ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: postureCheckId
required: true
schema:
type: string
description: The unique identifier of a posture check
requestBody:
description: Update Rule request
content:
'application/json':
schema:
$ref: '#/components/schemas/PostureCheckUpdate'
responses:
'200':
description: A posture check object
content:
application/json:
schema:
$ref: '#/components/schemas/PostureCheck'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
delete:
summary: Delete a Posture Check
description: Delete a posture check
tags: [ "Posture Checks" ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: postureCheckId
required: true
schema:
type: string
description: The unique identifier of a posture check
responses:
'200':
description: Delete status code
content: { }
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/locations/countries:
get:
summary: List all country codes
description: Get list of all country in 2-letter ISO 3166-1 alpha-2 codes
tags: [ "Geo Locations" ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
responses:
'200':
description: List of country codes
content:
application/json:
schema:
type: array
items:
type: string
example: "DE"
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/locations/countries/{country}/cities:
get:
summary: List all city names by country
description: Get a list of all English city names for a given country code
tags: [ "Geo Locations" ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: country
required: true
schema:
$ref: '#/components/schemas/Country'
responses:
'200':
description: List of city names
content:
application/json:
schema:
$ref: '#/components/schemas/City'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"

View File

@ -63,6 +63,12 @@ const (
EventActivityCodeUserUnblock EventActivityCode = "user.unblock"
)
// Defines values for GeoLocationCheckAction.
const (
GeoLocationCheckActionAllow GeoLocationCheckAction = "allow"
GeoLocationCheckActionDeny GeoLocationCheckAction = "deny"
)
// Defines values for NameserverNsType.
const (
NameserverNsTypeUdp NameserverNsType = "udp"
@ -176,6 +182,40 @@ type AccountSettings struct {
PeerLoginExpirationEnabled bool `json:"peer_login_expiration_enabled"`
}
// Checks List of objects that perform the actual checks
type Checks struct {
// GeoLocationCheck Posture check for geo location
GeoLocationCheck *GeoLocationCheck `json:"geo_location_check,omitempty"`
NbVersionCheck *NBVersionCheck `json:"nb_version_check,omitempty"`
// OsVersionCheck Posture check for the version of operating system
OsVersionCheck *OSVersionCheck `json:"os_version_check,omitempty"`
}
// City Describe city geographical location information
type City struct {
// CityName Commonly used English name of the city
CityName string `json:"city_name"`
// GeonameId Integer ID of the record in GeoNames database
GeonameId int `json:"geoname_id"`
}
// CityName Commonly used English name of the city
type CityName = string
// Country Describe country geographical location information
type Country struct {
// CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country
CountryCode CountryCode `json:"country_code"`
// CountryName Commonly used English name of the country
CountryName string `json:"country_name"`
}
// CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country
type CountryCode = string
// DNSSettings defines model for DNSSettings.
type DNSSettings struct {
// DisabledManagementGroups Groups whose DNS management is disabled
@ -215,6 +255,18 @@ type Event struct {
// EventActivityCode The string code of the activity that occurred during the event
type EventActivityCode string
// GeoLocationCheck Posture check for geo location
type GeoLocationCheck struct {
// Action Action to take upon policy match
Action GeoLocationCheckAction `json:"action"`
// Locations List of geo locations to which the policy applies
Locations []Location `json:"locations"`
}
// GeoLocationCheckAction Action to take upon policy match
type GeoLocationCheckAction string
// Group defines model for Group.
type Group struct {
// Id Group ID
@ -257,6 +309,30 @@ type GroupRequest struct {
Peers *[]string `json:"peers,omitempty"`
}
// Location Describe geographical location information
type Location struct {
// CityName Commonly used English name of the city
CityName *CityName `json:"city_name,omitempty"`
// CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country
CountryCode CountryCode `json:"country_code"`
}
// MinKernelVersionCheck Posture check with the kernel version
type MinKernelVersionCheck struct {
// MinKernelVersion Minimum acceptable version
MinKernelVersion string `json:"min_kernel_version"`
}
// MinVersionCheck defines model for MinVersionCheck.
type MinVersionCheck struct {
// MinVersion Minimum acceptable version
MinVersion string `json:"min_version"`
}
// NBVersionCheck defines model for NBVersionCheck.
type NBVersionCheck = MinVersionCheck
// Nameserver defines model for Nameserver.
type Nameserver struct {
// Ip Nameserver IP
@ -329,6 +405,19 @@ type NameserverGroupRequest struct {
SearchDomainsEnabled bool `json:"search_domains_enabled"`
}
// OSVersionCheck Posture check for the version of operating system
type OSVersionCheck struct {
Android *MinVersionCheck `json:"android,omitempty"`
Darwin *MinVersionCheck `json:"darwin,omitempty"`
Ios *MinVersionCheck `json:"ios,omitempty"`
// Linux Posture check with the kernel version
Linux *MinKernelVersionCheck `json:"linux,omitempty"`
// Windows Posture check with the kernel version
Windows *MinKernelVersionCheck `json:"windows,omitempty"`
}
// Peer defines model for Peer.
type Peer struct {
// AccessiblePeers List of accessible peers
@ -337,12 +426,24 @@ type Peer struct {
// ApprovalRequired (Cloud only) Indicates whether peer needs approval
ApprovalRequired *bool `json:"approval_required,omitempty"`
// CityName Commonly used English name of the city
CityName *CityName `json:"city_name,omitempty"`
// Connected Peer to Management connection status
Connected bool `json:"connected"`
// ConnectionIp Peer's public connection IP address
ConnectionIp *string `json:"connection_ip,omitempty"`
// CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country
CountryCode *CountryCode `json:"country_code,omitempty"`
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
DnsLabel string `json:"dns_label"`
// GeonameId Unique identifier from the GeoNames database for a specific geographical location.
GeonameId *int `json:"geoname_id,omitempty"`
// Groups Groups that the peer belongs to
Groups []GroupMinimum `json:"groups"`
@ -355,6 +456,9 @@ type Peer struct {
// Ip Peer's IP address
Ip string `json:"ip"`
// KernelVersion Peer's operating system kernel version
KernelVersion *string `json:"kernel_version,omitempty"`
// LastLogin Last time this peer performed log in (authentication). E.g., user authenticated.
LastLogin time.Time `json:"last_login"`
@ -391,12 +495,24 @@ type PeerBase struct {
// ApprovalRequired (Cloud only) Indicates whether peer needs approval
ApprovalRequired *bool `json:"approval_required,omitempty"`
// CityName Commonly used English name of the city
CityName *CityName `json:"city_name,omitempty"`
// Connected Peer to Management connection status
Connected bool `json:"connected"`
// ConnectionIp Peer's public connection IP address
ConnectionIp *string `json:"connection_ip,omitempty"`
// CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country
CountryCode *CountryCode `json:"country_code,omitempty"`
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
DnsLabel string `json:"dns_label"`
// GeonameId Unique identifier from the GeoNames database for a specific geographical location.
GeonameId *int `json:"geoname_id,omitempty"`
// Groups Groups that the peer belongs to
Groups []GroupMinimum `json:"groups"`
@ -409,6 +525,9 @@ type PeerBase struct {
// Ip Peer's IP address
Ip string `json:"ip"`
// KernelVersion Peer's operating system kernel version
KernelVersion *string `json:"kernel_version,omitempty"`
// LastLogin Last time this peer performed log in (authentication). E.g., user authenticated.
LastLogin time.Time `json:"last_login"`
@ -448,12 +567,24 @@ type PeerBatch struct {
// ApprovalRequired (Cloud only) Indicates whether peer needs approval
ApprovalRequired *bool `json:"approval_required,omitempty"`
// CityName Commonly used English name of the city
CityName *CityName `json:"city_name,omitempty"`
// Connected Peer to Management connection status
Connected bool `json:"connected"`
// ConnectionIp Peer's public connection IP address
ConnectionIp *string `json:"connection_ip,omitempty"`
// CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country
CountryCode *CountryCode `json:"country_code,omitempty"`
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
DnsLabel string `json:"dns_label"`
// GeonameId Unique identifier from the GeoNames database for a specific geographical location.
GeonameId *int `json:"geoname_id,omitempty"`
// Groups Groups that the peer belongs to
Groups []GroupMinimum `json:"groups"`
@ -466,6 +597,9 @@ type PeerBatch struct {
// Ip Peer's IP address
Ip string `json:"ip"`
// KernelVersion Peer's operating system kernel version
KernelVersion *string `json:"kernel_version,omitempty"`
// LastLogin Last time this peer performed log in (authentication). E.g., user authenticated.
LastLogin time.Time `json:"last_login"`
@ -569,6 +703,9 @@ type Policy struct {
// Rules Policy rule object for policy UI editor
Rules []PolicyRule `json:"rules"`
// SourcePostureChecks Posture checks ID's applied to policy source groups
SourcePostureChecks []string `json:"source_posture_checks"`
}
// PolicyMinimum defines model for PolicyMinimum.
@ -713,6 +850,36 @@ type PolicyUpdate struct {
// Rules Policy rule object for policy UI editor
Rules []PolicyRuleUpdate `json:"rules"`
// SourcePostureChecks Posture checks ID's applied to policy source groups
SourcePostureChecks *[]string `json:"source_posture_checks,omitempty"`
}
// PostureCheck defines model for PostureCheck.
type PostureCheck struct {
// Checks List of objects that perform the actual checks
Checks Checks `json:"checks"`
// Description Posture check friendly description
Description *string `json:"description,omitempty"`
// Id Posture check ID
Id string `json:"id"`
// Name Posture check unique name identifier
Name string `json:"name"`
}
// PostureCheckUpdate defines model for PostureCheckUpdate.
type PostureCheckUpdate struct {
// Checks List of objects that perform the actual checks
Checks *Checks `json:"checks,omitempty"`
// Description Posture check friendly description
Description string `json:"description"`
// Name Posture check name identifier
Name string `json:"name"`
}
// Route defines model for Route.
@ -1012,6 +1179,12 @@ type PostApiPoliciesJSONRequestBody = PolicyUpdate
// PutApiPoliciesPolicyIdJSONRequestBody defines body for PutApiPoliciesPolicyId for application/json ContentType.
type PutApiPoliciesPolicyIdJSONRequestBody = PolicyUpdate
// PostApiPostureChecksJSONRequestBody defines body for PostApiPostureChecks for application/json ContentType.
type PostApiPostureChecksJSONRequestBody = PostureCheckUpdate
// PutApiPostureChecksPostureCheckIdJSONRequestBody defines body for PutApiPostureChecksPostureCheckId for application/json ContentType.
type PutApiPostureChecksPostureCheckIdJSONRequestBody = PostureCheckUpdate
// PostApiRoutesJSONRequestBody defines body for PostApiRoutes for application/json ContentType.
type PostApiRoutesJSONRequestBody = RouteRequest

View File

@ -0,0 +1,236 @@
package http
import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"path"
"testing"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/util"
)
func initGeolocationTestData(t *testing.T) *GeolocationsHandler {
t.Helper()
var (
mmdbPath = "../testdata/GeoLite2-City-Test.mmdb"
geonamesDBPath = "../testdata/geonames-test.db"
)
tempDir := t.TempDir()
err := util.CopyFileContents(mmdbPath, path.Join(tempDir, geolocation.MMDBFileName))
assert.NoError(t, err)
err = util.CopyFileContents(geonamesDBPath, path.Join(tempDir, geolocation.GeoSqliteDBFile))
assert.NoError(t, err)
geo, err := geolocation.NewGeolocation(tempDir)
assert.NoError(t, err)
t.Cleanup(func() { _ = geo.Stop() })
return &GeolocationsHandler{
accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
return &server.Account{
Id: claims.AccountId,
Users: map[string]*server.User{
"test_user": user,
},
}, user, nil
},
},
geolocationManager: geo,
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
}
}),
),
}
}
func TestGetCitiesByCountry(t *testing.T) {
tt := []struct {
name string
expectedStatus int
expectedBody bool
expectedCities []api.City
requestType string
requestPath string
}{
{
name: "Get cities with valid country iso code",
expectedStatus: http.StatusOK,
expectedBody: true,
expectedCities: []api.City{
{
CityName: "Souni",
GeonameId: 5819,
},
{
CityName: "Protaras",
GeonameId: 18918,
},
},
requestType: http.MethodGet,
requestPath: "/api/locations/countries/CY/cities",
},
{
name: "Get cities with valid country iso code but zero cities",
expectedStatus: http.StatusOK,
expectedBody: true,
expectedCities: make([]api.City, 0),
requestType: http.MethodGet,
requestPath: "/api/locations/countries/DE/cities",
},
{
name: "Get cities with invalid country iso code",
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
requestType: http.MethodGet,
requestPath: "/api/locations/countries/12ds/cities",
},
}
geolocationHandler := initGeolocationTestData(t)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
router := mux.NewRouter()
router.HandleFunc("/api/locations/countries/{country}/cities", geolocationHandler.GetCitiesByCountry).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("I don't know what I expected; %v", err)
return
}
if status := recorder.Code; status != tc.expectedStatus {
t.Errorf("handler returned wrong status code: got %v want %v, content: %s",
status, tc.expectedStatus, string(content))
return
}
if !tc.expectedBody {
return
}
cities := make([]api.City, 0)
if err = json.Unmarshal(content, &cities); err != nil {
t.Fatalf("unmarshal request cities response : %v", err)
return
}
assert.ElementsMatch(t, tc.expectedCities, cities)
})
}
}
func TestGetAllCountries(t *testing.T) {
tt := []struct {
name string
expectedStatus int
expectedBody bool
expectedCountries []api.Country
requestType string
requestPath string
}{
{
name: "Get all countries",
expectedStatus: http.StatusOK,
expectedBody: true,
expectedCountries: []api.Country{
{
CountryCode: "IR",
CountryName: "Iran",
},
{
CountryCode: "CY",
CountryName: "Cyprus",
},
{
CountryCode: "RW",
CountryName: "Rwanda",
},
{
CountryCode: "SO",
CountryName: "Somalia",
},
{
CountryCode: "YE",
CountryName: "Yemen",
},
{
CountryCode: "LY",
CountryName: "Libya",
},
{
CountryCode: "IQ",
CountryName: "Iraq",
},
},
requestType: http.MethodGet,
requestPath: "/api/locations/countries",
},
}
geolocationHandler := initGeolocationTestData(t)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
router := mux.NewRouter()
router.HandleFunc("/api/locations/countries", geolocationHandler.GetAllCountries).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("I don't know what I expected; %v", err)
return
}
if status := recorder.Code; status != tc.expectedStatus {
t.Errorf("handler returned wrong status code: got %v want %v, content: %s",
status, tc.expectedStatus, string(content))
return
}
if !tc.expectedBody {
return
}
countries := make([]api.Country, 0)
if err = json.Unmarshal(content, &countries); err != nil {
t.Fatalf("unmarshal request cities response : %v", err)
return
}
assert.ElementsMatch(t, tc.expectedCountries, countries)
})
}
}

View File

@ -0,0 +1,119 @@
package http
import (
"net/http"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
)
// GeolocationsHandler is a handler that returns locations.
type GeolocationsHandler struct {
accountManager server.AccountManager
geolocationManager *geolocation.Geolocation
claimsExtractor *jwtclaims.ClaimsExtractor
}
// NewGeolocationsHandlerHandler creates a new Geolocations handler
func NewGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg AuthCfg) *GeolocationsHandler {
return &GeolocationsHandler{
accountManager: accountManager,
geolocationManager: geolocationManager,
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
}
}
// GetAllCountries retrieves a list of all countries
func (l *GeolocationsHandler) GetAllCountries(w http.ResponseWriter, r *http.Request) {
if err := l.authenticateUser(r); err != nil {
util.WriteError(err, w)
return
}
if l.geolocationManager == nil {
// TODO: update error message to include geo db self hosted doc link when ready
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
return
}
allCountries, err := l.geolocationManager.GetAllCountries()
if err != nil {
util.WriteError(err, w)
return
}
countries := make([]api.Country, 0, len(allCountries))
for _, country := range allCountries {
countries = append(countries, toCountryResponse(country))
}
util.WriteJSONObject(w, countries)
}
// GetCitiesByCountry retrieves a list of cities based on the given country code
func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http.Request) {
if err := l.authenticateUser(r); err != nil {
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
countryCode := vars["country"]
if !countryCodeRegex.MatchString(countryCode) {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid country code"), w)
return
}
if l.geolocationManager == nil {
// TODO: update error message to include geo db self hosted doc link when ready
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
return
}
allCities, err := l.geolocationManager.GetCitiesByCountry(countryCode)
if err != nil {
util.WriteError(err, w)
return
}
cities := make([]api.City, 0, len(allCities))
for _, city := range allCities {
cities = append(cities, toCityResponse(city))
}
util.WriteJSONObject(w, cities)
}
func (l *GeolocationsHandler) authenticateUser(r *http.Request) error {
claims := l.claimsExtractor.FromRequestContext(r)
_, user, err := l.accountManager.GetAccountFromToken(claims)
if err != nil {
return err
}
if !user.HasAdminPower() {
return status.Errorf(status.PermissionDenied, "user is not allowed to perform this action")
}
return nil
}
func toCountryResponse(country geolocation.Country) api.Country {
return api.Country{
CountryName: country.CountryName,
CountryCode: country.CountryISOCode,
}
}
func toCityResponse(city geolocation.City) api.City {
return api.City{
CityName: city.CityName,
GeonameId: city.GeoNameID,
}
}

View File

@ -9,6 +9,7 @@ import (
"github.com/netbirdio/management-integrations/integrations"
s "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/telemetry"
@ -23,9 +24,10 @@ type AuthCfg struct {
}
type apiHandler struct {
Router *mux.Router
AccountManager s.AccountManager
AuthCfg AuthCfg
Router *mux.Router
AccountManager s.AccountManager
geolocationManager *geolocation.Geolocation
AuthCfg AuthCfg
}
// EmptyObject is an empty struct used to return empty JSON object
@ -33,7 +35,7 @@ type emptyObject struct {
}
// APIHandler creates the Management service HTTP API handler registering all the available endpoints.
func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) {
func APIHandler(accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) {
claimsExtractor := jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
@ -63,9 +65,10 @@ func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValid
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler, acMiddleware.Handler)
api := apiHandler{
Router: router,
AccountManager: accountManager,
AuthCfg: authCfg,
Router: router,
AccountManager: accountManager,
geolocationManager: LocationManager,
AuthCfg: authCfg,
}
integrations.RegisterHandlers(api.Router, accountManager, claimsExtractor)
@ -81,6 +84,8 @@ func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValid
api.addDNSNameserversEndpoint()
api.addDNSSettingEndpoint()
api.addEventsEndpoint()
api.addPostureCheckEndpoint()
api.addLocationsEndpoint()
err := api.Router.Walk(func(route *mux.Route, _ *mux.Router, _ []*mux.Route) error {
methods, err := route.GetMethods()
@ -200,3 +205,18 @@ func (apiHandler *apiHandler) addEventsEndpoint() {
eventsHandler := NewEventsHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/events", eventsHandler.GetAllEvents).Methods("GET", "OPTIONS")
}
func (apiHandler *apiHandler) addPostureCheckEndpoint() {
postureCheckHandler := NewPostureChecksHandler(apiHandler.AccountManager, apiHandler.geolocationManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/posture-checks", postureCheckHandler.GetAllPostureChecks).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/posture-checks", postureCheckHandler.CreatePostureCheck).Methods("POST", "OPTIONS")
apiHandler.Router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.UpdatePostureCheck).Methods("PUT", "OPTIONS")
apiHandler.Router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.GetPostureCheck).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.DeletePostureCheck).Methods("DELETE", "OPTIONS")
}
func (apiHandler *apiHandler) addLocationsEndpoint() {
locationHandler := NewGeolocationsHandlerHandler(apiHandler.AccountManager, apiHandler.geolocationManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/locations/countries", locationHandler.GetAllCountries).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/locations/countries/{country}/cities", locationHandler.GetCitiesByCountry).Methods("GET", "OPTIONS")
}

View File

@ -3,6 +3,7 @@ package http
import (
"encoding/json"
"fmt"
"net"
"net/http"
"github.com/gorilla/mux"
@ -230,14 +231,30 @@ func toGroupsInfo(groups map[string]*server.Group, peerID string) []api.GroupMin
return groupsInfo
}
func connectionIPoString(ip net.IP) *string {
publicIP := ""
if ip != nil {
publicIP = ip.String()
}
return &publicIP
}
func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeer []api.AccessiblePeer) *api.Peer {
osVersion := peer.Meta.OSVersion
if osVersion == "" {
osVersion = peer.Meta.Core
}
geonameID := int(peer.Location.GeoNameID)
return &api.Peer{
Id: peer.ID,
Name: peer.Name,
Ip: peer.IP.String(),
ConnectionIp: connectionIPoString(peer.Location.ConnectionIP),
Connected: peer.Status.Connected,
LastSeen: peer.Status.LastSeen,
Os: fmt.Sprintf("%s %s", peer.Meta.OS, peer.Meta.Core),
Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion),
KernelVersion: &peer.Meta.KernelVersion,
GeonameId: &geonameID,
Version: peer.Meta.WtVersion,
Groups: groupsInfo,
SshEnabled: peer.SSHEnabled,
@ -250,17 +267,27 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
LoginExpired: peer.Status.LoginExpired,
AccessiblePeers: accessiblePeer,
ApprovalRequired: &peer.Status.RequiresApproval,
CountryCode: &peer.Location.CountryCode,
CityName: &peer.Location.CityName,
}
}
func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeersCount int) *api.PeerBatch {
osVersion := peer.Meta.OSVersion
if osVersion == "" {
osVersion = peer.Meta.Core
}
geonameID := int(peer.Location.GeoNameID)
return &api.PeerBatch{
Id: peer.ID,
Name: peer.Name,
Ip: peer.IP.String(),
ConnectionIp: connectionIPoString(peer.Location.ConnectionIP),
Connected: peer.Status.Connected,
LastSeen: peer.Status.LastSeen,
Os: fmt.Sprintf("%s %s", peer.Meta.OS, peer.Meta.Core),
Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion),
KernelVersion: &peer.Meta.KernelVersion,
GeonameId: &geonameID,
Version: peer.Meta.WtVersion,
Groups: groupsInfo,
SshEnabled: peer.SSHEnabled,
@ -273,6 +300,8 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn
LoginExpired: peer.Status.LoginExpired,
AccessiblePeersCount: accessiblePeersCount,
ApprovalRequired: &peer.Status.RequiresApproval,
CountryCode: &peer.Location.CountryCode,
CityName: &peer.Location.CityName,
}
}

View File

@ -206,6 +206,10 @@ func (h *Policies) savePolicy(
policy.Rules = append(policy.Rules, &pr)
}
if req.SourcePostureChecks != nil {
policy.SourcePostureChecks = sourcePostureChecksToStrings(account, *req.SourcePostureChecks)
}
if err := h.accountManager.SavePolicy(account.Id, user.Id, &policy); err != nil {
util.WriteError(err, w)
return
@ -284,10 +288,11 @@ func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) {
func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Policy {
cache := make(map[string]api.GroupMinimum)
ap := &api.Policy{
Id: &policy.ID,
Name: policy.Name,
Description: policy.Description,
Enabled: policy.Enabled,
Id: &policy.ID,
Name: policy.Name,
Description: policy.Description,
Enabled: policy.Enabled,
SourcePostureChecks: policy.SourcePostureChecks,
}
for _, r := range policy.Rules {
rID := r.ID
@ -351,3 +356,17 @@ func groupMinimumsToStrings(account *server.Account, gm []string) []string {
}
return result
}
func sourcePostureChecksToStrings(account *server.Account, postureChecksIds []string) []string {
result := make([]string, 0, len(postureChecksIds))
for _, id := range postureChecksIds {
for _, postureCheck := range account.PostureChecks {
if id == postureCheck.ID {
result = append(result, id)
continue
}
}
}
return result
}

View File

@ -0,0 +1,344 @@
package http
import (
"encoding/json"
"net/http"
"regexp"
"slices"
"github.com/gorilla/mux"
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
)
var (
countryCodeRegex = regexp.MustCompile("^[a-zA-Z]{2}$")
)
// PostureChecksHandler is a handler that returns posture checks of the account.
type PostureChecksHandler struct {
accountManager server.AccountManager
geolocationManager *geolocation.Geolocation
claimsExtractor *jwtclaims.ClaimsExtractor
}
// NewPostureChecksHandler creates a new PostureChecks handler
func NewPostureChecksHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg AuthCfg) *PostureChecksHandler {
return &PostureChecksHandler{
accountManager: accountManager,
geolocationManager: geolocationManager,
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
}
}
// GetAllPostureChecks list for the account
func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r)
account, user, err := p.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
return
}
accountPostureChecks, err := p.accountManager.ListPostureChecks(account.Id, user.Id)
if err != nil {
util.WriteError(err, w)
return
}
postureChecks := []*api.PostureCheck{}
for _, postureCheck := range accountPostureChecks {
postureChecks = append(postureChecks, toPostureChecksResponse(postureCheck))
}
util.WriteJSONObject(w, postureChecks)
}
// UpdatePostureCheck handles update to a posture check identified by a given ID
func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r)
account, user, err := p.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
postureChecksID := vars["postureCheckId"]
if len(postureChecksID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w)
return
}
postureChecksIdx := -1
for i, postureCheck := range account.PostureChecks {
if postureCheck.ID == postureChecksID {
postureChecksIdx = i
break
}
}
if postureChecksIdx < 0 {
util.WriteError(status.Errorf(status.NotFound, "couldn't find posture checks id %s", postureChecksID), w)
return
}
p.savePostureChecks(w, r, account, user, postureChecksID)
}
// CreatePostureCheck handles posture check creation request
func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r)
account, user, err := p.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
return
}
p.savePostureChecks(w, r, account, user, "")
}
// GetPostureCheck handles a posture check Get request identified by ID
func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r)
account, user, err := p.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
postureChecksID := vars["postureCheckId"]
if len(postureChecksID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w)
return
}
postureChecks, err := p.accountManager.GetPostureChecks(account.Id, postureChecksID, user.Id)
if err != nil {
util.WriteError(err, w)
return
}
util.WriteJSONObject(w, toPostureChecksResponse(postureChecks))
}
// DeletePostureCheck handles posture check deletion request
func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r)
account, user, err := p.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
postureChecksID := vars["postureCheckId"]
if len(postureChecksID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid posture checks ID"), w)
return
}
if err = p.accountManager.DeletePostureChecks(account.Id, postureChecksID, user.Id); err != nil {
util.WriteError(err, w)
return
}
util.WriteJSONObject(w, emptyObject{})
}
// savePostureChecks handles posture checks create and update
func (p *PostureChecksHandler) savePostureChecks(
w http.ResponseWriter,
r *http.Request,
account *server.Account,
user *server.User,
postureChecksID string,
) {
var req api.PostureCheckUpdate
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
err := validatePostureChecksUpdate(req)
if err != nil {
util.WriteErrorResponse(err.Error(), http.StatusBadRequest, w)
return
}
if postureChecksID == "" {
postureChecksID = xid.New().String()
}
postureChecks := posture.Checks{
ID: postureChecksID,
Name: req.Name,
Description: req.Description,
}
if nbVersionCheck := req.Checks.NbVersionCheck; nbVersionCheck != nil {
postureChecks.Checks.NBVersionCheck = &posture.NBVersionCheck{
MinVersion: nbVersionCheck.MinVersion,
}
}
if osVersionCheck := req.Checks.OsVersionCheck; osVersionCheck != nil {
postureChecks.Checks.OSVersionCheck = &posture.OSVersionCheck{
Android: (*posture.MinVersionCheck)(osVersionCheck.Android),
Darwin: (*posture.MinVersionCheck)(osVersionCheck.Darwin),
Ios: (*posture.MinVersionCheck)(osVersionCheck.Ios),
Linux: (*posture.MinKernelVersionCheck)(osVersionCheck.Linux),
Windows: (*posture.MinKernelVersionCheck)(osVersionCheck.Windows),
}
}
if geoLocationCheck := req.Checks.GeoLocationCheck; geoLocationCheck != nil {
if p.geolocationManager == nil {
// TODO: update error message to include geo db self hosted doc link when ready
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
return
}
postureChecks.Checks.GeoLocationCheck = toPostureGeoLocationCheck(geoLocationCheck)
}
if err := p.accountManager.SavePostureChecks(account.Id, user.Id, &postureChecks); err != nil {
util.WriteError(err, w)
return
}
util.WriteJSONObject(w, toPostureChecksResponse(&postureChecks))
}
func validatePostureChecksUpdate(req api.PostureCheckUpdate) error {
if req.Name == "" {
return status.Errorf(status.InvalidArgument, "posture checks name shouldn't be empty")
}
if req.Checks == nil || (req.Checks.NbVersionCheck == nil && req.Checks.OsVersionCheck == nil &&
req.Checks.GeoLocationCheck == nil) {
return status.Errorf(status.InvalidArgument, "posture checks shouldn't be empty")
}
if req.Checks.NbVersionCheck != nil && req.Checks.NbVersionCheck.MinVersion == "" {
return status.Errorf(status.InvalidArgument, "minimum version for NetBird's version check shouldn't be empty")
}
if osVersionCheck := req.Checks.OsVersionCheck; osVersionCheck != nil {
emptyOS := osVersionCheck.Android == nil && osVersionCheck.Darwin == nil && osVersionCheck.Ios == nil &&
osVersionCheck.Linux == nil && osVersionCheck.Windows == nil
emptyMinVersion := osVersionCheck.Android != nil && osVersionCheck.Android.MinVersion == "" ||
osVersionCheck.Darwin != nil && osVersionCheck.Darwin.MinVersion == "" ||
osVersionCheck.Ios != nil && osVersionCheck.Ios.MinVersion == "" ||
osVersionCheck.Linux != nil && osVersionCheck.Linux.MinKernelVersion == "" ||
osVersionCheck.Windows != nil && osVersionCheck.Windows.MinKernelVersion == ""
if emptyOS || emptyMinVersion {
return status.Errorf(status.InvalidArgument,
"minimum version for at least one OS in the OS version check shouldn't be empty")
}
}
if geoLocationCheck := req.Checks.GeoLocationCheck; geoLocationCheck != nil {
if geoLocationCheck.Action == "" {
return status.Errorf(status.InvalidArgument, "action for geolocation check shouldn't be empty")
}
allowedActions := []api.GeoLocationCheckAction{api.GeoLocationCheckActionAllow, api.GeoLocationCheckActionDeny}
if !slices.Contains(allowedActions, geoLocationCheck.Action) {
return status.Errorf(status.InvalidArgument, "action for geolocation check is not valid value")
}
if len(geoLocationCheck.Locations) == 0 {
return status.Errorf(status.InvalidArgument, "locations for geolocation check shouldn't be empty")
}
for _, loc := range geoLocationCheck.Locations {
if loc.CountryCode == "" {
return status.Errorf(status.InvalidArgument, "country code for geolocation check shouldn't be empty")
}
if !countryCodeRegex.MatchString(loc.CountryCode) {
return status.Errorf(status.InvalidArgument, "country code must be 2 letters (ISO 3166-1 alpha-2 format)")
}
}
}
return nil
}
func toPostureChecksResponse(postureChecks *posture.Checks) *api.PostureCheck {
var checks api.Checks
if postureChecks.Checks.NBVersionCheck != nil {
checks.NbVersionCheck = &api.NBVersionCheck{
MinVersion: postureChecks.Checks.NBVersionCheck.MinVersion,
}
}
if postureChecks.Checks.OSVersionCheck != nil {
checks.OsVersionCheck = &api.OSVersionCheck{
Android: (*api.MinVersionCheck)(postureChecks.Checks.OSVersionCheck.Android),
Darwin: (*api.MinVersionCheck)(postureChecks.Checks.OSVersionCheck.Darwin),
Ios: (*api.MinVersionCheck)(postureChecks.Checks.OSVersionCheck.Ios),
Linux: (*api.MinKernelVersionCheck)(postureChecks.Checks.OSVersionCheck.Linux),
Windows: (*api.MinKernelVersionCheck)(postureChecks.Checks.OSVersionCheck.Windows),
}
}
if postureChecks.Checks.GeoLocationCheck != nil {
checks.GeoLocationCheck = toGeoLocationCheckResponse(postureChecks.Checks.GeoLocationCheck)
}
return &api.PostureCheck{
Id: postureChecks.ID,
Name: postureChecks.Name,
Description: &postureChecks.Description,
Checks: checks,
}
}
func toGeoLocationCheckResponse(geoLocationCheck *posture.GeoLocationCheck) *api.GeoLocationCheck {
locations := make([]api.Location, 0, len(geoLocationCheck.Locations))
for _, loc := range geoLocationCheck.Locations {
l := loc // make G601 happy
var cityName *string
if loc.CityName != "" {
cityName = &l.CityName
}
locations = append(locations, api.Location{
CityName: cityName,
CountryCode: loc.CountryCode,
})
}
return &api.GeoLocationCheck{
Action: api.GeoLocationCheckAction(geoLocationCheck.Action),
Locations: locations,
}
}
func toPostureGeoLocationCheck(apiGeoLocationCheck *api.GeoLocationCheck) *posture.GeoLocationCheck {
locations := make([]posture.Location, 0, len(apiGeoLocationCheck.Locations))
for _, loc := range apiGeoLocationCheck.Locations {
cityName := ""
if loc.CityName != nil {
cityName = *loc.CityName
}
locations = append(locations, posture.Location{
CountryCode: loc.CountryCode,
CityName: cityName,
})
}
return &posture.GeoLocationCheck{
Action: string(apiGeoLocationCheck.Action),
Locations: locations,
}
}

View File

@ -0,0 +1,796 @@
package http
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
)
var berlin = "Berlin"
var losAngeles = "Los Angeles"
func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksHandler {
testPostureChecks := make(map[string]*posture.Checks, len(postureChecks))
for _, postureCheck := range postureChecks {
testPostureChecks[postureCheck.ID] = postureCheck
}
return &PostureChecksHandler{
accountManager: &mock_server.MockAccountManager{
GetPostureChecksFunc: func(accountID, postureChecksID, userID string) (*posture.Checks, error) {
p, ok := testPostureChecks[postureChecksID]
if !ok {
return nil, status.Errorf(status.NotFound, "posture checks not found")
}
return p, nil
},
SavePostureChecksFunc: func(accountID, userID string, postureChecks *posture.Checks) error {
postureChecks.ID = "postureCheck"
testPostureChecks[postureChecks.ID] = postureChecks
return nil
},
DeletePostureChecksFunc: func(accountID, postureChecksID, userID string) error {
_, ok := testPostureChecks[postureChecksID]
if !ok {
return status.Errorf(status.NotFound, "posture checks not found")
}
delete(testPostureChecks, postureChecksID)
return nil
},
ListPostureChecksFunc: func(accountID, userID string) ([]*posture.Checks, error) {
accountPostureChecks := make([]*posture.Checks, len(testPostureChecks))
for _, p := range testPostureChecks {
accountPostureChecks = append(accountPostureChecks, p)
}
return accountPostureChecks, nil
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
return &server.Account{
Id: claims.AccountId,
Users: map[string]*server.User{
"test_user": user,
},
PostureChecks: postureChecks,
}, user, nil
},
},
geolocationManager: &geolocation.Geolocation{},
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
}
}),
),
}
}
func TestGetPostureCheck(t *testing.T) {
postureCheck := &posture.Checks{
ID: "postureCheck",
Name: "nbVersion",
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "1.0.0",
},
},
}
osPostureCheck := &posture.Checks{
ID: "osPostureCheck",
Name: "osVersion",
Checks: posture.ChecksDefinition{
OSVersionCheck: &posture.OSVersionCheck{
Linux: &posture.MinKernelVersionCheck{
MinKernelVersion: "6.0.0",
},
Darwin: &posture.MinVersionCheck{
MinVersion: "14",
},
Ios: &posture.MinVersionCheck{
MinVersion: "",
},
},
},
}
geoPostureCheck := &posture.Checks{
ID: "geoPostureCheck",
Name: "geoLocation",
Checks: posture.ChecksDefinition{
GeoLocationCheck: &posture.GeoLocationCheck{
Locations: []posture.Location{
{
CountryCode: "DE",
CityName: "Berlin",
},
},
Action: posture.GeoLocationActionAllow,
},
},
}
tt := []struct {
name string
id string
checkName string
expectedStatus int
expectedBody bool
requestBody io.Reader
}{
{
name: "GetPostureCheck NBVersion OK",
expectedBody: true,
id: postureCheck.ID,
checkName: postureCheck.Name,
expectedStatus: http.StatusOK,
},
{
name: "GetPostureCheck OSVersion OK",
expectedBody: true,
id: osPostureCheck.ID,
checkName: osPostureCheck.Name,
expectedStatus: http.StatusOK,
},
{
name: "GetPostureCheck GeoLocation OK",
expectedBody: true,
id: geoPostureCheck.ID,
checkName: geoPostureCheck.Name,
expectedStatus: http.StatusOK,
},
{
name: "GetPostureCheck Not Found",
id: "not-exists",
expectedStatus: http.StatusNotFound,
},
}
p := initPostureChecksTestData(postureCheck, osPostureCheck, geoPostureCheck)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/posture-checks/"+tc.id, tc.requestBody)
router := mux.NewRouter()
router.HandleFunc("/api/posture-checks/{postureCheckId}", p.GetPostureCheck).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
if status := recorder.Code; status != tc.expectedStatus {
t.Errorf("handler returned wrong status code: got %v want %v",
status, tc.expectedStatus)
return
}
if !tc.expectedBody {
return
}
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("I don't know what I expected; %v", err)
}
var got api.PostureCheck
if err = json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, got.Id, tc.id)
assert.Equal(t, got.Name, tc.checkName)
})
}
}
func TestPostureCheckUpdate(t *testing.T) {
str := func(s string) *string { return &s }
tt := []struct {
name string
expectedStatus int
expectedBody bool
expectedPostureCheck *api.PostureCheck
requestType string
requestPath string
requestBody io.Reader
setupHandlerFunc func(handler *PostureChecksHandler)
}{
{
name: "Create Posture Checks NB version",
requestType: http.MethodPost,
requestPath: "/api/posture-checks",
requestBody: bytes.NewBuffer(
[]byte(`{
"name": "default",
"description": "default",
"checks": {
"nb_version_check": {
"min_version": "1.2.3"
}
}
}`)),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedPostureCheck: &api.PostureCheck{
Id: "postureCheck",
Name: "default",
Description: str("default"),
Checks: api.Checks{
NbVersionCheck: &api.NBVersionCheck{
MinVersion: "1.2.3",
},
},
},
},
{
name: "Create Posture Checks NB version with No geolocation DB",
requestType: http.MethodPost,
requestPath: "/api/posture-checks",
requestBody: bytes.NewBuffer(
[]byte(`{
"name": "default",
"description": "default",
"checks": {
"nb_version_check": {
"min_version": "1.2.3"
}
}
}`)),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedPostureCheck: &api.PostureCheck{
Id: "postureCheck",
Name: "default",
Description: str("default"),
Checks: api.Checks{
NbVersionCheck: &api.NBVersionCheck{
MinVersion: "1.2.3",
},
},
},
setupHandlerFunc: func(handler *PostureChecksHandler) {
handler.geolocationManager = nil
},
},
{
name: "Create Posture Checks OS version",
requestType: http.MethodPost,
requestPath: "/api/posture-checks",
requestBody: bytes.NewBuffer(
[]byte(`{
"name": "default",
"description": "default",
"checks": {
"os_version_check": {
"android": {
"min_version": "9.0.0"
},
"ios": {
"min_version": "17.0"
},
"linux": {
"min_kernel_version": "6.0.0"
}
}
}
}`)),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedPostureCheck: &api.PostureCheck{
Id: "postureCheck",
Name: "default",
Description: str("default"),
Checks: api.Checks{
OsVersionCheck: &api.OSVersionCheck{
Android: &api.MinVersionCheck{
MinVersion: "9.0.0",
},
Ios: &api.MinVersionCheck{
MinVersion: "17.0",
},
Linux: &api.MinKernelVersionCheck{
MinKernelVersion: "6.0.0",
},
},
},
},
},
{
name: "Create Posture Checks Geo Location",
requestType: http.MethodPost,
requestPath: "/api/posture-checks",
requestBody: bytes.NewBuffer(
[]byte(`{
"name": "default",
"description": "default",
"checks": {
"geo_location_check": {
"locations": [
{
"city_name": "Berlin",
"country_code": "DE"
}
],
"action": "allow"
}
}
}`)),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedPostureCheck: &api.PostureCheck{
Id: "postureCheck",
Name: "default",
Description: str("default"),
Checks: api.Checks{
GeoLocationCheck: &api.GeoLocationCheck{
Locations: []api.Location{
{
CityName: &berlin,
CountryCode: "DE",
},
},
Action: api.GeoLocationCheckActionAllow,
},
},
},
},
{
name: "Create Posture Checks Geo Location with No geolocation DB",
requestType: http.MethodPost,
requestPath: "/api/posture-checks",
requestBody: bytes.NewBuffer(
[]byte(`{
"name": "default",
"description": "default",
"checks": {
"geo_location_check": {
"locations": [
{
"city_name": "Berlin",
"country_code": "DE"
}
],
"action": "allow"
}
}
}`)),
expectedStatus: http.StatusPreconditionFailed,
expectedBody: false,
setupHandlerFunc: func(handler *PostureChecksHandler) {
handler.geolocationManager = nil
},
},
{
name: "Create Posture Checks Invalid Check",
requestType: http.MethodPost,
requestPath: "/api/posture-checks",
requestBody: bytes.NewBuffer(
[]byte(`{
"name": "default",
"checks": {
"non_existing_check": {
"min_version": "1.2.0"
}
}
}`)),
expectedStatus: http.StatusBadRequest,
expectedBody: false,
},
{
name: "Create Posture Checks Invalid Name",
requestType: http.MethodPost,
requestPath: "/api/posture-checks",
requestBody: bytes.NewBuffer(
[]byte(`{
"checks": {
"nb_version_check": {
"min_version": "1.2.0"
}
}
}`)),
expectedStatus: http.StatusBadRequest,
expectedBody: false,
},
{
name: "Create Posture Checks Invalid NetBird's Min Version",
requestType: http.MethodPost,
requestPath: "/api/posture-checks",
requestBody: bytes.NewBuffer(
[]byte(`{
"name": "default",
"checks": {
"nb_version_check": {}
}
}`)),
expectedStatus: http.StatusBadRequest,
expectedBody: false,
},
{
name: "Create Posture Checks Invalid Geo Location",
requestType: http.MethodPost,
requestPath: "/api/posture-checks",
requestBody: bytes.NewBuffer(
[]byte(`{
"name": "default",
"checks": {
"geo_location_check": {}
}
}`)),
expectedStatus: http.StatusBadRequest,
expectedBody: false,
},
{
name: "Update Posture Checks NB Version",
requestType: http.MethodPut,
requestPath: "/api/posture-checks/postureCheck",
requestBody: bytes.NewBuffer(
[]byte(`{
"name": "default",
"checks": {
"nb_version_check": {
"min_version": "1.9.0"
}
}
}`)),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedPostureCheck: &api.PostureCheck{
Id: "postureCheck",
Name: "default",
Description: str(""),
Checks: api.Checks{
NbVersionCheck: &api.NBVersionCheck{
MinVersion: "1.9.0",
},
},
},
},
{
name: "Update Posture Checks OS Version",
requestType: http.MethodPut,
requestPath: "/api/posture-checks/osPostureCheck",
requestBody: bytes.NewBuffer(
[]byte(`{
"name": "default",
"checks": {
"os_version_check": {
"linux": {
"min_kernel_version": "6.9.0"
}
}
}
}`)),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedPostureCheck: &api.PostureCheck{
Id: "postureCheck",
Name: "default",
Description: str(""),
Checks: api.Checks{
OsVersionCheck: &api.OSVersionCheck{
Linux: &api.MinKernelVersionCheck{
MinKernelVersion: "6.9.0",
},
},
},
},
},
{
name: "Update Posture Checks OS Version with No geolocation DB",
requestType: http.MethodPut,
requestPath: "/api/posture-checks/osPostureCheck",
requestBody: bytes.NewBuffer(
[]byte(`{
"name": "default",
"checks": {
"os_version_check": {
"linux": {
"min_kernel_version": "6.9.0"
}
}
}
}`)),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedPostureCheck: &api.PostureCheck{
Id: "postureCheck",
Name: "default",
Description: str(""),
Checks: api.Checks{
OsVersionCheck: &api.OSVersionCheck{
Linux: &api.MinKernelVersionCheck{
MinKernelVersion: "6.9.0",
},
},
},
},
setupHandlerFunc: func(handler *PostureChecksHandler) {
handler.geolocationManager = nil
},
},
{
name: "Update Posture Checks Geo Location",
requestType: http.MethodPut,
requestPath: "/api/posture-checks/geoPostureCheck",
requestBody: bytes.NewBuffer(
[]byte(`{
"name": "default",
"checks": {
"geo_location_check": {
"locations": [
{
"city_name": "Los Angeles",
"country_code": "US"
}
],
"action": "allow"
}
}
}`)),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedPostureCheck: &api.PostureCheck{
Id: "postureCheck",
Name: "default",
Description: str(""),
Checks: api.Checks{
GeoLocationCheck: &api.GeoLocationCheck{
Locations: []api.Location{
{
CityName: &losAngeles,
CountryCode: "US",
},
},
Action: api.GeoLocationCheckActionAllow,
},
},
},
},
{
name: "Update Posture Checks Geo Location with No geolocation DB",
requestType: http.MethodPut,
requestPath: "/api/posture-checks/geoPostureCheck",
requestBody: bytes.NewBuffer(
[]byte(`{
"name": "default",
"checks": {
"geo_location_check": {
"locations": [
{
"city_name": "Los Angeles",
"country_code": "US"
}
],
"action": "allow"
}
}
}`)),
expectedStatus: http.StatusPreconditionFailed,
expectedBody: false,
setupHandlerFunc: func(handler *PostureChecksHandler) {
handler.geolocationManager = nil
},
},
{
name: "Update Posture Checks Geo Location with not valid action",
requestType: http.MethodPut,
requestPath: "/api/posture-checks/geoPostureCheck",
requestBody: bytes.NewBuffer(
[]byte(`{
"name": "default",
"checks": {
"geo_location_check": {
"locations": [
{
"city_name": "Los Angeles",
"country_code": "US"
}
],
"action": "not-valid"
}
}
}`)),
expectedStatus: http.StatusBadRequest,
expectedBody: false,
setupHandlerFunc: func(handler *PostureChecksHandler) {
handler.geolocationManager = nil
},
},
{
name: "Update Posture Checks Invalid Check",
requestType: http.MethodPut,
requestPath: "/api/posture-checks/postureCheck",
requestBody: bytes.NewBuffer(
[]byte(`{
"name": "default",
"checks": {
"non_existing_check": {
"min_version": "1.2.0"
}
}
}`)),
expectedStatus: http.StatusBadRequest,
expectedBody: false,
},
{
name: "Update Posture Checks Invalid Name",
requestType: http.MethodPut,
requestPath: "/api/posture-checks/postureCheck",
requestBody: bytes.NewBuffer(
[]byte(`{
"checks": {
"nb_version_check": {
"min_version": "1.2.0"
}
}
}`)),
expectedStatus: http.StatusBadRequest,
expectedBody: false,
},
{
name: "Update Posture Checks Invalid NetBird's Min Version",
requestType: http.MethodPut,
requestPath: "/api/posture-checks/postureCheck",
requestBody: bytes.NewBuffer(
[]byte(`{
"name": "default",
"checks": {
"nb_version_check": {}
}
}`)),
expectedStatus: http.StatusBadRequest,
expectedBody: false,
},
}
p := initPostureChecksTestData(&posture.Checks{
ID: "postureCheck",
Name: "postureCheck",
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "1.0.0",
},
},
},
&posture.Checks{
ID: "osPostureCheck",
Name: "osPostureCheck",
Checks: posture.ChecksDefinition{
OSVersionCheck: &posture.OSVersionCheck{
Linux: &posture.MinKernelVersionCheck{
MinKernelVersion: "5.0.0",
},
},
},
},
&posture.Checks{
ID: "geoPostureCheck",
Name: "geoLocation",
Checks: posture.ChecksDefinition{
GeoLocationCheck: &posture.GeoLocationCheck{
Locations: []posture.Location{
{
CountryCode: "DE",
CityName: "Berlin",
},
},
Action: posture.GeoLocationActionDeny,
},
},
},
)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
defaultHandler := *p
if tc.setupHandlerFunc != nil {
tc.setupHandlerFunc(&defaultHandler)
}
router := mux.NewRouter()
router.HandleFunc("/api/posture-checks", defaultHandler.CreatePostureCheck).Methods("POST")
router.HandleFunc("/api/posture-checks/{postureCheckId}", defaultHandler.UpdatePostureCheck).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("I don't know what I expected; %v", err)
return
}
if status := recorder.Code; status != tc.expectedStatus {
t.Errorf("handler returned wrong status code: got %v want %v, content: %s",
status, tc.expectedStatus, string(content))
return
}
if !tc.expectedBody {
return
}
expected, err := json.Marshal(tc.expectedPostureCheck)
if err != nil {
t.Fatalf("marshal expected posture check: %v", err)
return
}
assert.Equal(t, strings.Trim(string(content), " \n"), string(expected), "content mismatch")
})
}
}
func TestPostureCheck_validatePostureChecksUpdate(t *testing.T) {
// empty name
err := validatePostureChecksUpdate(api.PostureCheckUpdate{})
assert.Error(t, err)
// empty checks
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default"})
assert.Error(t, err)
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{}})
assert.Error(t, err)
// not valid NbVersionCheck
nbVersionCheck := api.NBVersionCheck{}
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{NbVersionCheck: &nbVersionCheck}})
assert.Error(t, err)
// valid NbVersionCheck
nbVersionCheck = api.NBVersionCheck{MinVersion: "1.0"}
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{NbVersionCheck: &nbVersionCheck}})
assert.NoError(t, err)
// not valid OsVersionCheck
osVersionCheck := api.OSVersionCheck{}
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{OsVersionCheck: &osVersionCheck}})
assert.Error(t, err)
// not valid OsVersionCheck
osVersionCheck = api.OSVersionCheck{Linux: &api.MinKernelVersionCheck{}}
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{OsVersionCheck: &osVersionCheck}})
assert.Error(t, err)
// not valid OsVersionCheck
osVersionCheck = api.OSVersionCheck{Linux: &api.MinKernelVersionCheck{}, Darwin: &api.MinVersionCheck{MinVersion: "14.2"}}
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{OsVersionCheck: &osVersionCheck}})
assert.Error(t, err)
// valid OsVersionCheck
osVersionCheck = api.OSVersionCheck{Linux: &api.MinKernelVersionCheck{MinKernelVersion: "6.0"}}
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{OsVersionCheck: &osVersionCheck}})
assert.NoError(t, err)
// valid OsVersionCheck
osVersionCheck = api.OSVersionCheck{
Linux: &api.MinKernelVersionCheck{MinKernelVersion: "6.0"},
Darwin: &api.MinVersionCheck{MinVersion: "14.2"},
}
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{OsVersionCheck: &osVersionCheck}})
assert.NoError(t, err)
}